You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/06/15 12:25:12 UTC
[beam] branch master updated: Sickbay Pytorch example IT test (#21857)
This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new b1a313e0380 Sickbay Pytorch example IT test (#21857)
b1a313e0380 is described below
commit b1a313e03807882e1bb178c689e4f935246d5534
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Wed Jun 15 12:25:02 2022 +0000
Sickbay Pytorch example IT test (#21857)
* Remove defaults for the args
* Sickbay the IT test
Sickbay the IT test until the licensing for the model and datasets is approved
* Add TODO linking to the GitHub issue
Co-authored-by: tvalentyn <tv...@users.noreply.github.com>
---
.../examples/inference/pytorch_image_classification.py | 14 ++++----------
.../apache_beam/ml/inference/pytorch_inference_it_test.py | 2 ++
2 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index 070fc80dd76..f1f19d09299 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-""""A pipeline that uses RunInference API to perform image classification."""
+"""A pipeline that uses RunInference API to perform image classification."""
import argparse
import io
@@ -30,7 +30,7 @@ from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
-from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from PIL import Image
@@ -74,8 +74,6 @@ def parse_known_args(argv):
parser.add_argument(
'--input',
dest='input',
- default='gs://apache-beam-ml/testing/inputs/'
- 'it_mobilenetv2_imagenet_validation_inputs.txt',
help='Path to the text file containing image names.')
parser.add_argument(
'--output',
@@ -85,10 +83,7 @@ def parse_known_args(argv):
parser.add_argument(
'--model_state_dict_path',
dest='model_state_dict_path',
- default='gs://apache-beam-ml/'
- 'models/imagenet_classification_mobilenet_v2.pt',
- help="Path to the model's state_dict. "
- "Default state_dict would be MobilenetV2.")
+ help="Path to the model's state_dict.")
parser.add_argument(
'--images_dir',
default=None,
@@ -102,7 +97,6 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
Args:
argv: Command line arguments defined for this example.
model_class: Reference to the class definition of the model.
- If None, MobilenetV2 will be used as default .
model_params: Parameters passed to the constructor of the model_class.
These will be used to instantiate the model object in the
RunInference API.
@@ -118,7 +112,7 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
# In this example we pass keyed inputs to RunInference transform.
# Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
model_handler = KeyedModelHandler(
- PytorchModelHandler(
+ PytorchModelHandlerTensor(
state_dict_path=known_args.model_state_dict_path,
model_class=model_class,
model_params=model_params))
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
index fb0a2789be6..a4231f40434 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
@@ -63,6 +63,8 @@ def process_outputs(filepath):
os.getenv('FORCE_TORCH_IT') is None and torch is None,
'Missing dependencies. '
'Test depends on torch, torchvision, pillow, and transformers')
+# TODO: https://github.com/apache/beam/issues/21859
+@pytest.mark.skip
class PyTorchInference(unittest.TestCase):
@pytest.mark.uses_pytorch
@pytest.mark.it_postcommit