You are viewing a plain-text version of this content. The canonical link was a hyperlink on the original page and is not preserved in this text version.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/08/05 20:55:39 UTC
[beam] branch master updated: [21894] Validates inference_args early (#22282)
This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 6a021844fe4 [21894] Validates inference_args early (#22282)
6a021844fe4 is described below
commit 6a021844fe4bce75a0371ba1ea97f11606afd756
Author: Ryan Thompson <ry...@gmail.com>
AuthorDate: Fri Aug 5 16:55:25 2022 -0400
[21894] Validates inference_args early (#22282)
Co-authored-by: Andy Ye <an...@gmail.com>
---
sdks/python/apache_beam/ml/inference/base.py | 18 ++++++++++++++++
sdks/python/apache_beam/ml/inference/base_test.py | 25 ++++++++++++++++++++--
.../apache_beam/ml/inference/pytorch_inference.py | 6 ++++++
.../apache_beam/ml/inference/sklearn_inference.py | 16 --------------
.../ml/inference/sklearn_inference_test.py | 9 --------
5 files changed, 47 insertions(+), 27 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 5bb45d77787..075260ec3ec 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -134,6 +134,17 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""
return {}
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ """Validates inference_args passed in the inference call.
+
+ Most frameworks do not need extra arguments in their predict() call so the
+ default behavior is to error out if inference_args are present.
+ """
+ if inference_args:
+ raise ValueError(
+ 'inference_args were provided, but should be None because this '
+ 'framework does not expect extra arguments on inferences.')
+
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Tuple[KeyT, ExampleT],
@@ -178,6 +189,9 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
def batch_elements_kwargs(self):
return self._unkeyed.batch_elements_kwargs()
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ return self._unkeyed.validate_inference_args(inference_args)
+
class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -248,6 +262,9 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
def batch_elements_kwargs(self):
return self._unkeyed.batch_elements_kwargs()
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ return self._unkeyed.validate_inference_args(inference_args)
+
class RunInference(beam.PTransform[beam.PCollection[ExampleT],
beam.PCollection[PredictionT]]):
@@ -297,6 +314,7 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
# handled.
def expand(
self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
+ self._model_handler.validate_inference_args(self._inference_args)
resource_hints = self._model_handler.get_resource_hints()
return (
pcoll
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 98fc2523b6d..ca79a3cd3a3 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -79,12 +79,22 @@ class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
return {'min_batch_size': 9999}
-class FakeModelHandlerExtraInferenceArgs(FakeModelHandler):
+class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
+ def run_inference(self, batch, unused_model, inference_args=None):
+ raise ValueError(
+ 'run_inference should not be called because error should already be '
+ 'thrown from the validate_inference_args check.')
+
+
+class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
def run_inference(self, batch, unused_model, inference_args=None):
if not inference_args:
raise ValueError('inference_args should exist')
return batch
+ def validate_inference_args(self, inference_args):
+ pass
+
class RunInferenceBaseTest(unittest.TestCase):
def test_run_inference_impl_simple_examples(self):
@@ -128,9 +138,20 @@ class RunInferenceBaseTest(unittest.TestCase):
pcoll = pipeline | 'start' >> beam.Create(examples)
inference_args = {'key': True}
actual = pcoll | base.RunInference(
- FakeModelHandlerExtraInferenceArgs(), inference_args=inference_args)
+ FakeModelHandlerExpectedInferenceArgs(),
+ inference_args=inference_args)
assert_that(actual, equal_to(examples), label='assert:inferences')
+ def test_unexpected_inference_args_passed(self):
+ with self.assertRaisesRegex(ValueError, r'inference_args were provided'):
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ inference_args = {'key': True}
+ _ = pcoll | base.RunInference(
+ FakeModelHandlerFailsOnInferenceArgs(),
+ inference_args=inference_args)
+
def test_counted_metrics(self):
pipeline = TestPipeline()
examples = [1, 5, 3, 10]
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index e25048e140d..945ad7e2ec2 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -150,6 +150,9 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
"""
return 'RunInferencePytorch'
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ pass
+
@experimental(extra_message="No backwards-compatibility guarantees.")
class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
@@ -257,3 +260,6 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
A namespace for metrics collected by the RunInference transform.
"""
return 'RunInferencePytorch'
+
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ pass
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 63038ec46ce..1338a5abe33 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -66,20 +66,6 @@ def _load_model(model_uri, file_type):
raise AssertionError('Unsupported serialization type.')
-def _validate_inference_args(inference_args):
- """Confirms that inference_args is None.
-
- scikit-learn models do not need extra arguments in their predict() call.
- However, since inference_args is an argument in the RunInference interface,
- we want to make sure it is not passed here in Sklearn's implementation of
- RunInference.
- """
- if inference_args:
- raise ValueError(
- 'inference_args were provided, but should be None because scikit-learn '
- 'models do not need extra arguments in their predict() call.')
-
-
class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
PredictionResult,
BaseEstimator]):
@@ -124,7 +110,6 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
Returns:
An Iterable of type PredictionResult.
"""
- _validate_inference_args(inference_args)
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
@@ -187,7 +172,6 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
Returns:
An Iterable of type PredictionResult.
"""
- _validate_inference_args(inference_args)
# sklearn_inference currently only supports single rowed dataframes.
for dataframe in iter(batch):
if dataframe.shape[0] != 1:
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 978c3a8934d..72c513049ac 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -317,15 +317,6 @@ class SkLearnRunInferenceTest(unittest.TestCase):
inference_runner = SklearnModelHandlerPandas(model_uri='unused')
inference_runner.run_inference([data_frame_too_many_rows], fake_model)
- def test_inference_args_passed(self):
- with self.assertRaisesRegex(ValueError, r'inference_args were provided'):
- data_frame = pandas_dataframe()
- fake_model = FakeModel()
- inference_runner = SklearnModelHandlerPandas(model_uri='unused')
- inference_runner.run_inference([data_frame],
- fake_model,
- inference_args={'key1': 'value1'})
-
if __name__ == '__main__':
unittest.main()