You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/08/05 20:55:39 UTC

[beam] branch master updated: [21894] Validates inference_args early (#22282)

This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a021844fe4 [21894] Validates inference_args early (#22282)
6a021844fe4 is described below

commit 6a021844fe4bce75a0371ba1ea97f11606afd756
Author: Ryan Thompson <ry...@gmail.com>
AuthorDate: Fri Aug 5 16:55:25 2022 -0400

    [21894] Validates inference_args early (#22282)
    
    Co-authored-by: Andy Ye <an...@gmail.com>
---
 sdks/python/apache_beam/ml/inference/base.py       | 18 ++++++++++++++++
 sdks/python/apache_beam/ml/inference/base_test.py  | 25 ++++++++++++++++++++--
 .../apache_beam/ml/inference/pytorch_inference.py  |  6 ++++++
 .../apache_beam/ml/inference/sklearn_inference.py  | 16 --------------
 .../ml/inference/sklearn_inference_test.py         |  9 --------
 5 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 5bb45d77787..075260ec3ec 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -134,6 +134,17 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     """
     return {}
 
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    """Validates inference_args passed in the inference call.
+
+    Most frameworks do not need extra arguments in their predict() call, so the
+    default behavior is to error out if inference_args are present.
+    """
+    if inference_args:
+      raise ValueError(
+          'inference_args were provided, but should be None because this '
+          'framework does not expect extra arguments on inferences.')
+
 
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
@@ -178,6 +189,9 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def batch_elements_kwargs(self):
     return self._unkeyed.batch_elements_kwargs()
 
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    return self._unkeyed.validate_inference_args(inference_args)
+
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                              ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -248,6 +262,9 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def batch_elements_kwargs(self):
     return self._unkeyed.batch_elements_kwargs()
 
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    return self._unkeyed.validate_inference_args(inference_args)
+
 
 class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                                    beam.PCollection[PredictionT]]):
@@ -297,6 +314,7 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
   # handled.
   def expand(
       self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
+    self._model_handler.validate_inference_args(self._inference_args)
     resource_hints = self._model_handler.get_resource_hints()
     return (
         pcoll
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 98fc2523b6d..ca79a3cd3a3 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -79,12 +79,22 @@ class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
     return {'min_batch_size': 9999}
 
 
-class FakeModelHandlerExtraInferenceArgs(FakeModelHandler):
+class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
+  def run_inference(self, batch, unused_model, inference_args=None):
+    raise ValueError(
+        'run_inference should not be called because error should already be '
+        'thrown from the validate_inference_args check.')
+
+
+class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
   def run_inference(self, batch, unused_model, inference_args=None):
     if not inference_args:
       raise ValueError('inference_args should exist')
     return batch
 
+  def validate_inference_args(self, inference_args):
+    pass
+
 
 class RunInferenceBaseTest(unittest.TestCase):
   def test_run_inference_impl_simple_examples(self):
@@ -128,9 +138,20 @@ class RunInferenceBaseTest(unittest.TestCase):
       pcoll = pipeline | 'start' >> beam.Create(examples)
       inference_args = {'key': True}
       actual = pcoll | base.RunInference(
-          FakeModelHandlerExtraInferenceArgs(), inference_args=inference_args)
+          FakeModelHandlerExpectedInferenceArgs(),
+          inference_args=inference_args)
       assert_that(actual, equal_to(examples), label='assert:inferences')
 
+  def test_unexpected_inference_args_passed(self):
+    with self.assertRaisesRegex(ValueError, r'inference_args were provided'):
+      with TestPipeline() as pipeline:
+        examples = [1, 5, 3, 10]
+        pcoll = pipeline | 'start' >> beam.Create(examples)
+        inference_args = {'key': True}
+        _ = pcoll | base.RunInference(
+            FakeModelHandlerFailsOnInferenceArgs(),
+            inference_args=inference_args)
+
   def test_counted_metrics(self):
     pipeline = TestPipeline()
     examples = [1, 5, 3, 10]
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index e25048e140d..945ad7e2ec2 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -150,6 +150,9 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
     """
     return 'RunInferencePytorch'
 
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    pass
+
 
 @experimental(extra_message="No backwards-compatibility guarantees.")
 class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
@@ -257,3 +260,6 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
        A namespace for metrics collected by the RunInference transform.
     """
     return 'RunInferencePytorch'
+
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    pass
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 63038ec46ce..1338a5abe33 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -66,20 +66,6 @@ def _load_model(model_uri, file_type):
   raise AssertionError('Unsupported serialization type.')
 
 
-def _validate_inference_args(inference_args):
-  """Confirms that inference_args is None.
-
-  scikit-learn models do not need extra arguments in their predict() call.
-  However, since inference_args is an argument in the RunInference interface,
-  we want to make sure it is not passed here in Sklearn's implementation of
-  RunInference.
-  """
-  if inference_args:
-    raise ValueError(
-        'inference_args were provided, but should be None because scikit-learn '
-        'models do not need extra arguments in their predict() call.')
-
-
 class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                             PredictionResult,
                                             BaseEstimator]):
@@ -124,7 +110,6 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
     Returns:
       An Iterable of type PredictionResult.
     """
-    _validate_inference_args(inference_args)
     # vectorize data for better performance
     vectorized_batch = numpy.stack(batch, axis=0)
     predictions = model.predict(vectorized_batch)
@@ -187,7 +172,6 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
     Returns:
       An Iterable of type PredictionResult.
     """
-    _validate_inference_args(inference_args)
     # sklearn_inference currently only supports single rowed dataframes.
     for dataframe in iter(batch):
       if dataframe.shape[0] != 1:
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 978c3a8934d..72c513049ac 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -317,15 +317,6 @@ class SkLearnRunInferenceTest(unittest.TestCase):
       inference_runner = SklearnModelHandlerPandas(model_uri='unused')
       inference_runner.run_inference([data_frame_too_many_rows], fake_model)
 
-  def test_inference_args_passed(self):
-    with self.assertRaisesRegex(ValueError, r'inference_args were provided'):
-      data_frame = pandas_dataframe()
-      fake_model = FakeModel()
-      inference_runner = SklearnModelHandlerPandas(model_uri='unused')
-      inference_runner.run_inference([data_frame],
-                                     fake_model,
-                                     inference_args={'key1': 'value1'})
-
 
 if __name__ == '__main__':
   unittest.main()