You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original archive page and is not preserved in this plain-text copy.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/06/13 15:07:18 UTC

[beam] branch master updated: Refactor code according to keyedModelHandler changes (#21819)

This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 63cd54e2e2b Refactor code according to keyedModelHandler changes (#21819)
63cd54e2e2b is described below

commit 63cd54e2e2b18d6d673adeae72fe4f60a3d8732f
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Mon Jun 13 15:07:08 2022 +0000

    Refactor code according to keyedModelHandler changes (#21819)
    
    * Refactor code according to keyedModelHandler changes
    
    * Add comments on why keyedModelHandler is used.
---
 .../examples/inference/pytorch_image_classification.py  | 17 +++++++++++------
 sdks/python/apache_beam/ml/inference/base.py            |  1 -
 .../apache_beam/ml/inference/pytorch_inference_test.py  | 16 ++++++++--------
 3 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index a3ea84ad01b..0509950b114 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -27,9 +27,10 @@ from typing import Tuple
 import apache_beam as beam
 import torch
 from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import KeyedModelHandler
 from apache_beam.ml.inference.api import PredictionResult
 from apache_beam.ml.inference.api import RunInference
-from apache_beam.ml.inference.pytorch_inference import PytorchModelLoader
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import SetupOptions
 from PIL import Image
@@ -114,10 +115,13 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
     model_class = MobileNetV2
     model_params = {'num_classes': 1000}
 
-  model_loader = PytorchModelLoader(
-      state_dict_path=known_args.model_state_dict_path,
-      model_class=model_class,
-      model_params=model_params)
+  # In this example we pass keyed inputs to RunInference transform.
+  # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
+  model_handler = KeyedModelHandler(
+      PytorchModelHandler(
+          state_dict_path=known_args.model_state_dict_path,
+          model_class=model_class,
+          model_params=model_params))
 
   with beam.Pipeline(options=pipeline_options) as p:
     filename_value_pair = (
@@ -131,7 +135,8 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
             lambda file_name, data: (file_name, preprocess_image(data))))
     predictions = (
         filename_value_pair
-        | 'PyTorchRunInference' >> RunInference(model_loader).with_output_types(
+        |
+        'PyTorchRunInference' >> RunInference(model_handler).with_output_types(
             Tuple[str, PredictionResult])
         | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
 
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index ae07ac0531e..6d4d54c911d 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -132,7 +132,6 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
 
   def batch_elements_kwargs(self):
     return self._unkeyed.batch_elements_kwargs()
-    return {}
 
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index 7f563d7cf4c..604307c3d32 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -274,7 +274,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
       path = os.path.join(self.tmpdir, 'my_state_dict_path')
       torch.save(state_dict, path)
 
-      model_loader = PytorchModelHandler(
+      model_handler = PytorchModelHandler(
           state_dict_path=path,
           model_class=PytorchLinearRegression,
           model_params={
@@ -282,7 +282,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
           })
 
       pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
-      predictions = pcoll | RunInference(model_loader)
+      predictions = pcoll | RunInference(model_handler)
       assert_that(
           predictions,
           equal_to(
@@ -301,7 +301,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
       path = os.path.join(self.tmpdir, 'my_state_dict_path')
       torch.save(state_dict, path)
 
-      model_loader = PytorchModelHandler(
+      model_handler = PytorchModelHandler(
           state_dict_path=path,
           model_class=PytorchLinearRegressionKwargsPredictionParams,
           model_params={
@@ -312,7 +312,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
       prediction_params_side_input = (
           pipeline | 'create side' >> beam.Create(prediction_params))
       predictions = pcoll | RunInference(
-          model_loader=model_loader,
+          model_handler=model_handler,
           prediction_params=beam.pvalue.AsDict(prediction_params_side_input))
       assert_that(
           predictions,
@@ -334,7 +334,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
 
       gs_pth = 'gs://apache-beam-ml/models/' \
           'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
-      model_loader = PytorchModelHandler(
+      model_handler = PytorchModelHandler(
           state_dict_path=gs_pth,
           model_class=PytorchLinearRegression,
           model_params={
@@ -342,7 +342,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
           })
 
       pcoll = pipeline | 'start' >> beam.Create(examples)
-      predictions = pcoll | RunInference(model_loader)
+      predictions = pcoll | RunInference(model_handler)
       assert_that(
           predictions,
           equal_to(expected_predictions, equals_fn=_compare_prediction_result))
@@ -357,7 +357,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
         path = os.path.join(self.tmpdir, 'my_state_dict_path')
         torch.save(state_dict, path)
 
-        model_loader = PytorchModelHandler(
+        model_handler = PytorchModelHandler(
             state_dict_path=path,
             model_class=PytorchLinearRegression,
             model_params={
@@ -366,7 +366,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
 
         pcoll = pipeline | 'start' >> beam.Create(examples)
         # pylint: disable=expression-not-assigned
-        pcoll | RunInference(model_loader)
+        pcoll | RunInference(model_handler)
 
 
 if __name__ == '__main__':