Posted to commits@beam.apache.org by ro...@apache.org on 2022/06/10 21:31:33 UTC

[beam] branch master updated: Make keying of examples explicit. (#21777)

This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b8e2e85ab1f Make keying of examples explicit. (#21777)
b8e2e85ab1f is described below

commit b8e2e85ab1fb37a2f89ed20d88730e591ea3bf7e
Author: Robert Bradshaw <ro...@gmail.com>
AuthorDate: Fri Jun 10 14:31:28 2022 -0700

    Make keying of examples explicit. (#21777)
    
    This decouples the keying logic from the DoFn and helps with type inference.
    
    There are both a KeyedModelHandler that expects keys and a MaybeKeyedModelHandler that preserves the old behavior.
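    
    A rough usage sketch of the keyed path (illustrative only; the
    AddOneModelHandler below is a hypothetical stand-in for any unkeyed
    handler, along the lines of the FakeModelHandler used in base_test.py):
    
        import apache_beam as beam
        from apache_beam.ml.inference import base
    
        class AddOneModelHandler(base.ModelHandler[int, int, object]):
          """Hypothetical handler whose 'model' just adds one to each example."""
          def load_model(self):
            return None  # No real model is needed for this sketch.
    
          def run_inference(self, batch, model, **kwargs):
            return [example + 1 for example in batch]
    
        with beam.Pipeline() as p:
          keyed = p | beam.Create([('a', 1), ('b', 5), ('c', 3)])
          # Keys pass through unchanged; each prediction stays paired with its key.
          predictions = keyed | base.RunInference(
              base.KeyedModelHandler(AddOneModelHandler()))
          # Yields ('a', 2), ('b', 6), ('c', 4).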
---
 sdks/python/apache_beam/ml/inference/base.py       | 134 ++++++++++++++++++---
 sdks/python/apache_beam/ml/inference/base_test.py  |  24 +++-
 .../apache_beam/ml/inference/pytorch_inference.py  |   6 +-
 .../apache_beam/ml/inference/sklearn_inference.py  |  12 +-
 .../ml/inference/sklearn_inference_test.py         |   2 +-
 5 files changed, 145 insertions(+), 33 deletions(-)
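
The MaybeKeyedModelHandler mentioned in the commit message can serve keyed and
unkeyed inputs with the same handler instance, as long as a given PCollection
is entirely one or the other and ExampleT itself is not a tuple type. A
companion sketch (again illustrative only, reusing the hypothetical
AddOneModelHandler from the sketch above; it mirrors the new
test_run_inference_impl_with_maybe_keyed_examples test):

    import apache_beam as beam
    from apache_beam.ml.inference import base

    # AddOneModelHandler is the hypothetical unkeyed handler sketched above.
    handler = base.MaybeKeyedModelHandler(AddOneModelHandler())

    with beam.Pipeline() as p:
      # Unkeyed elements come back as bare predictions: 2, 6, 4.
      unkeyed = (
          p | 'Unkeyed' >> beam.Create([1, 5, 3])
            | 'RunUnkeyed' >> base.RunInference(handler))
      # Keyed elements come back as (key, prediction) pairs: ('a', 2), ('b', 6).
      keyed = (
          p | 'Keyed' >> beam.Create([('a', 1), ('b', 5)])
            | 'RunKeyed' >> base.RunInference(handler))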

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 534512a44f3..ae07ac0531e 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,9 +35,11 @@ import time
 from typing import Any
 from typing import Generic
 from typing import Iterable
-from typing import List
 from typing import Mapping
+from typing import Sequence
+from typing import Tuple
 from typing import TypeVar
+from typing import Union
 
 import apache_beam as beam
 from apache_beam.utils import shared
@@ -54,6 +56,7 @@ _NANOSECOND_TO_MICROSECOND = 1_000
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
 PredictionT = TypeVar('PredictionT')
+KeyT = TypeVar('KeyT')
 
 
 def _to_milliseconds(time_ns: int) -> int:
@@ -70,13 +73,13 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     """Loads and initializes a model for processing."""
     raise NotImplementedError(type(self))
 
-  def run_inference(self, batch: List[ExampleT], model: ModelT,
+  def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
                     **kwargs) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples and
     returns an Iterable of Predictions."""
     raise NotImplementedError(type(self))
 
-  def get_num_bytes(self, batch: List[ExampleT]) -> int:
+  def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
     """Returns the number of bytes of data for a batch."""
     return len(pickle.dumps(batch))
 
@@ -93,6 +96,111 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     return {}
 
 
+class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                        ModelHandler[Tuple[KeyT, ExampleT],
+                                     Tuple[KeyT, PredictionT],
+                                     ModelT]):
+  """A ModelHandler that takes keyed examples and returns keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take a
+  PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], allowing one to
+  associate the outputs with the inputs based on the key.
+  """
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def run_inference(
+      self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
+      **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
+    keys, unkeyed_batch = zip(*batch)
+    return zip(
+        keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
+
+  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
+    keys, unkeyed_batch = zip(*batch)
+    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
+
+
+class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                             ModelHandler[Union[ExampleT, Tuple[KeyT,
+                                                                ExampleT]],
+                                          Union[PredictionT,
+                                                Tuple[KeyT, PredictionT]],
+                                          ModelT]):
+  """A ModelHandler that takes possibly keyed examples and returns possibly
+  keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take either PCollection[E] to a
+  PCollection[P] or PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]],
+  depending on whether the elements happen to be tuples, allowing one to
+  associate the outputs with the inputs based on the key.
+
+  Note that this cannot be used if E happens to be a tuple type.  In addition,
+  either all examples should be keyed, or none of them.
+  """
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      **kwargs
+  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+    # Really the input should be
+    #    Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
+    # but there's not a good way to express (or check) that.
+    if isinstance(batch[0], tuple):
+      is_keyed = True
+      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
+    else:
+      is_keyed = False
+      unkeyed_batch = batch  # type: ignore[assignment]
+    unkeyed_results = self._unkeyed.run_inference(
+        unkeyed_batch, model, **kwargs)
+    if is_keyed:
+      return zip(keys, unkeyed_results)
+    else:
+      return unkeyed_results
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+    # MyPy can't follow the branching logic.
+    if isinstance(batch[0], tuple):
+      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
+      return len(
+          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    else:
+      return self._unkeyed.get_num_bytes(batch)  # type: ignore[arg-type]
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
+
+
 class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                                    beam.PCollection[PredictionT]]):
   """An extensible transform for running inferences.
@@ -205,32 +313,18 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     self._model = self._load_model()
 
   def process(self, batch, **kwargs):
-    # Process supports both keyed data, and example only data.
-    # First keys and samples are separated (if there are keys)
-    has_keys = isinstance(batch[0], tuple)
-    if has_keys:
-      examples = [example for _, example in batch]
-      keys = [key for key, _ in batch]
-    else:
-      examples = batch
-      keys = None
-
     start_time = _to_microseconds(self._clock.time_ns())
     result_generator = self._model_handler.run_inference(
-        examples, self._model, **kwargs)
+        batch, self._model, **kwargs)
     predictions = list(result_generator)
 
     end_time = _to_microseconds(self._clock.time_ns())
     inference_latency = end_time - start_time
-    num_bytes = self._model_handler.get_num_bytes(examples)
+    num_bytes = self._model_handler.get_num_bytes(batch)
     num_elements = len(batch)
     self._metrics_collector.update(num_elements, num_bytes, inference_latency)
 
-    # Keys are recombined with predictions in the RunInference PTransform.
-    if has_keys:
-      yield from zip(keys, predictions)
-    else:
-      yield from predictions
+    return predictions
 
   def finish_bundle(self):
     # TODO(BEAM-13970): Figure out why there is a cache.
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 3ea2a9db12b..52f8f883203 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -20,7 +20,7 @@
 import pickle
 import unittest
 from typing import Iterable
-from typing import List
+from typing import Sequence
 
 import apache_beam as beam
 from apache_beam.metrics.metric import MetricsFilter
@@ -44,7 +44,7 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
       self._fake_clock.current_time_ns += 500_000_000  # 500ms
     return FakeModel()
 
-  def run_inference(self, batch: List[int], model: FakeModel,
+  def run_inference(self, batch: Sequence[int], model: FakeModel,
                     **kwargs) -> Iterable[int]:
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
@@ -98,9 +98,27 @@ class RunInferenceBaseTest(unittest.TestCase):
       keyed_examples = [(i, example) for i, example in enumerate(examples)]
       expected = [(i, example + 1) for i, example in enumerate(examples)]
       pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
-      actual = pcoll | base.RunInference(FakeModelHandler())
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler()))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_with_maybe_keyed_examples(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_impl_kwargs(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 3a4fb2926f8..1a1afaaf1c8 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -22,7 +22,7 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Iterable
-from typing import List
+from typing import Sequence
 from typing import Union
 
 import torch
@@ -87,7 +87,7 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
 
   def run_inference(
       self,
-      batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+      batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
       model: torch.nn.Module,
       **kwargs) -> Iterable[PredictionResult]:
     """
@@ -119,7 +119,7 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
       predictions = model(batched_tensors, **prediction_params)
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
-  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """Returns the number of bytes of data for a batch of Tensors."""
     # If elements in `batch` are provided as dictionaries from key to Tensors
     if isinstance(batch[0], dict):
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 3c8eddfd7d3..5ca6a18b5d1 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -20,7 +20,7 @@ import pickle
 import sys
 from typing import Any
 from typing import Iterable
-from typing import List
+from typing import Sequence
 from typing import Union
 
 import numpy
@@ -75,7 +75,7 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
 
   def run_inference(
       self,
-      batch: List[Union[numpy.ndarray, pandas.DataFrame]],
+      batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]],
       model: BaseEstimator,
       **kwargs) -> Iterable[PredictionResult]:
     # TODO(github.com/apache/beam/issues/21769): Use supplied input type hint.
@@ -86,7 +86,7 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
     raise ValueError('Unsupported data type.')
 
   @staticmethod
-  def _predict_np_array(batch: List[numpy.ndarray],
+  def _predict_np_array(batch: Sequence[numpy.ndarray],
                         model: Any) -> Iterable[PredictionResult]:
     # vectorize data for better performance
     vectorized_batch = numpy.stack(batch, axis=0)
@@ -94,7 +94,7 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   @staticmethod
-  def _predict_pandas_dataframe(batch: List[pandas.DataFrame],
+  def _predict_pandas_dataframe(batch: Sequence[pandas.DataFrame],
                                 model: Any) -> Iterable[PredictionResult]:
     # sklearn_inference currently only supports single-row dataframes.
     for dataframe in batch:
@@ -113,11 +113,11 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
     ]
 
   def get_num_bytes(
-      self, batch: List[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
+      self, batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
     """Returns the number of bytes of data for a batch."""
     if isinstance(batch[0], numpy.ndarray):
       return sum(sys.getsizeof(element) for element in batch)
     elif isinstance(batch[0], pandas.DataFrame):
-      data_frames: List[pandas.DataFrame] = batch
+      data_frames: Sequence[pandas.DataFrame] = batch
       return sum(df.memory_usage(deep=True).sum() for df in data_frames)
     raise ValueError('Unsupported data type.')
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 91eb86e2de4..0d7294eb406 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -265,7 +265,7 @@ class SkLearnRunInferenceTest(unittest.TestCase):
 
       pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
       actual = pcoll | api.RunInference(
-          SklearnModelHandler(model_uri=temp_file_name))
+          base.KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
       expected = [
           ('0', api.PredictionResult(splits[0], 5)),
           ('1', api.PredictionResult(splits[1], 8)),