You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/06/15 12:25:12 UTC

[beam] branch master updated: Sickbay Pytorch example IT test (#21857)

This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b1a313e0380 Sickbay Pytorch example IT test (#21857)
b1a313e0380 is described below

commit b1a313e03807882e1bb178c689e4f935246d5534
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Wed Jun 15 12:25:02 2022 +0000

    Sickbay Pytorch example IT test (#21857)
    
    * Remove defaults for the args
    
    * Sickbay the IT test
    
    Sickbay the IT test until the licensing for the model and datasets is approved
    
    * Add TODO linking to the GitHub issue
    
    Co-authored-by: tvalentyn <tv...@users.noreply.github.com>
---
 .../examples/inference/pytorch_image_classification.py     | 14 ++++----------
 .../apache_beam/ml/inference/pytorch_inference_it_test.py  |  2 ++
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index 070fc80dd76..f1f19d09299 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-""""A pipeline that uses RunInference API to perform image classification."""
+"""A pipeline that uses RunInference API to perform image classification."""
 
 import argparse
 import io
@@ -30,7 +30,7 @@ from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import KeyedModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.ml.inference.base import RunInference
-from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import SetupOptions
 from PIL import Image
@@ -74,8 +74,6 @@ def parse_known_args(argv):
   parser.add_argument(
       '--input',
       dest='input',
-      default='gs://apache-beam-ml/testing/inputs/'
-      'it_mobilenetv2_imagenet_validation_inputs.txt',
       help='Path to the text file containing image names.')
   parser.add_argument(
       '--output',
@@ -85,10 +83,7 @@ def parse_known_args(argv):
   parser.add_argument(
       '--model_state_dict_path',
       dest='model_state_dict_path',
-      default='gs://apache-beam-ml/'
-      'models/imagenet_classification_mobilenet_v2.pt',
-      help="Path to the model's state_dict. "
-      "Default state_dict would be MobilenetV2.")
+      help="Path to the model's state_dict.")
   parser.add_argument(
       '--images_dir',
       default=None,
@@ -102,7 +97,6 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
   Args:
     argv: Command line arguments defined for this example.
     model_class: Reference to the class definition of the model.
-                If None, MobilenetV2 will be used as default .
     model_params: Parameters passed to the constructor of the model_class.
                   These will be used to instantiate the model object in the
                   RunInference API.
@@ -118,7 +112,7 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
   # In this example we pass keyed inputs to RunInference transform.
   # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
   model_handler = KeyedModelHandler(
-      PytorchModelHandler(
+      PytorchModelHandlerTensor(
           state_dict_path=known_args.model_state_dict_path,
           model_class=model_class,
           model_params=model_params))
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
index fb0a2789be6..a4231f40434 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
@@ -63,6 +63,8 @@ def process_outputs(filepath):
     os.getenv('FORCE_TORCH_IT') is None and torch is None,
     'Missing dependencies. '
     'Test depends on torch, torchvision, pillow, and transformers')
+# TODO: https://github.com/apache/beam/issues/21859
+@pytest.mark.skip
 class PyTorchInference(unittest.TestCase):
   @pytest.mark.uses_pytorch
   @pytest.mark.it_postcommit