You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/06/06 18:22:07 UTC
[beam] branch master updated: [BEAM-14068]Add Pytorch inference IT test and example (#17462)
This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 7445b9535a9 [BEAM-14068]Add Pytorch inference IT test and example (#17462)
7445b9535a9 is described below
commit 7445b9535a935757fc6d84d4138c4a7d772d5637
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Mon Jun 6 18:22:01 2022 +0000
[BEAM-14068]Add Pytorch inference IT test and example (#17462)
Co-authored-by: tvalentyn <tv...@users.noreply.github.com>
---
build.gradle.kts | 7 +-
.../apache_beam/examples/inference/__init__.py | 16 +++
.../inference/pytorch_image_classification.py | 146 +++++++++++++++++++++
.../ml/inference/pytorch_inference_it_test.py | 95 ++++++++++++++
.../ml/inference/torch_tests_requirements.txt | 20 +++
sdks/python/scripts/run_integration_test.sh | 19 ++-
sdks/python/test-suites/dataflow/build.gradle | 2 +-
sdks/python/test-suites/direct/common.gradle | 33 +++++
8 files changed, 327 insertions(+), 11 deletions(-)
diff --git a/build.gradle.kts b/build.gradle.kts
index fd8d8714177..c6c0059f62a 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -320,6 +320,7 @@ tasks.register("python37PostCommit") {
dependsOn(":sdks:python:test-suites:dataflow:py37:spannerioIT")
dependsOn(":sdks:python:test-suites:direct:py37:spannerioIT")
dependsOn(":sdks:python:test-suites:portable:py37:xlangSpannerIOIT")
+ dependsOn(":sdks:python:test-suites:direct:py37:inferencePostCommitIT")
}
tasks.register("python38PostCommit") {
@@ -334,11 +335,7 @@ tasks.register("python39PostCommit") {
dependsOn(":sdks:python:test-suites:direct:py39:postCommitIT")
dependsOn(":sdks:python:test-suites:direct:py39:hdfsIntegrationTest")
dependsOn(":sdks:python:test-suites:portable:py39:postCommitPy39")
-}
-
-
-task("python36SickbayPostCommit") {
- dependsOn(":sdks:python:test-suites:dataflow:py36:postCommitSickbay")
+ dependsOn(":sdks:python:test-suites:direct:py39:inferencePostCommitIT")
}
task("python37SickbayPostCommit") {
diff --git a/sdks/python/apache_beam/examples/inference/__init__.py b/sdks/python/apache_beam/examples/inference/__init__.py
new file mode 100644
index 00000000000..cce3acad34a
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
new file mode 100644
index 00000000000..a3ea84ad01b
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""A pipeline that uses RunInference API to perform image classification."""
+
+import argparse
+import io
+import os
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+import torch
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.api import PredictionResult
+from apache_beam.ml.inference.api import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelLoader
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from PIL import Image
+from torchvision import transforms
+from torchvision.models.mobilenetv2 import MobileNetV2
+
+
def read_image(image_file_name: str,
               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
  """Reads an image file and decodes it as an RGB PIL image.

  Args:
    image_file_name: Name or path of the image file.
    path_to_dir: Optional directory that is joined in front of
      ``image_file_name`` before the file is opened.

  Returns:
    A ``(path, image)`` tuple, where ``path`` is the (possibly joined) path
    that was read and ``image`` is the decoded image converted to RGB.
  """
  if path_to_dir is not None:
    image_file_name = os.path.join(path_to_dir, image_file_name)
  # FileSystems.open is a static method whose second positional parameter is
  # the MIME type, not a file mode, so no 'r' is passed here; the content is
  # read as raw bytes and decoded by PIL.
  with FileSystems.open(image_file_name) as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
  return image_file_name, data
+
+
def preprocess_image(data: Image.Image) -> torch.Tensor:
  """Turns a PIL image into a normalized tensor for the classifier.

  The image is resized to 224x224, converted to a tensor with
  ``transforms.ToTensor`` and then normalized per channel.
  """
  # Channel statistics that pre-trained PyTorch models expect their inputs
  # to be normalized with (see: https://pytorch.org/vision/stable/models.html).
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]
  pipeline = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Normalize(mean=channel_mean, std=channel_std),
  ])
  return pipeline(data)
+
+
class PostProcessor(beam.DoFn):
  """Formats each ``(filename, PredictionResult)`` pair as a CSV line.

  The emitted line is ``"<filename>,<index>"``, where ``index`` is the argmax
  of the inference output along dimension 0.
  """
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    image_name, result = element
    class_index = torch.argmax(result.inference, dim=0).item()
    yield ','.join([image_name, str(class_index)])
+
+
def parse_known_args(argv):
  """Parses args for the workflow.

  Args:
    argv: Command line arguments, or None to let argparse read ``sys.argv``.

  Returns:
    The ``(known_args, remaining_args)`` tuple produced by
    ``ArgumentParser.parse_known_args``. Arguments not defined here (for
    example pipeline options) end up in ``remaining_args``.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      default='gs://apache-beam-ml/testing/inputs/'
      'it_mobilenetv2_imagenet_validation_inputs.txt',
      help='Path to the text file containing image names.')
  parser.add_argument(
      '--output',
      dest='output',
      help='Path to the text file where output predictions are saved.')
  parser.add_argument(
      '--model_state_dict_path',
      dest='model_state_dict_path',
      default='gs://apache-beam-ml/'
      'models/imagenet_classification_mobilenet_v2.pt',
      help="Path to the model's state_dict. "
      "Defaults to the MobileNetV2 state_dict.")
  parser.add_argument(
      '--images_dir',
      default=None,
      # Note the trailing space: adjacent literals are concatenated verbatim.
      help='Path to the directory where images are stored. '
      'Not required if image names in the input file have absolute paths.')
  return parser.parse_known_args(argv)
+
+
def run(argv=None, model_class=None, model_params=None, save_main_session=True):
  """Builds and runs the PyTorch image classification pipeline.

  Reads image names from ``--input``, loads and preprocesses each image, runs
  it through the model with RunInference and, if ``--output`` is set, writes
  one ``filename,class_index`` line per image.

  Args:
    argv: Command line arguments defined for this example, plus any pipeline
      options.
    model_class: Reference to the class definition of the model.
      If None, MobileNetV2 is used as the default.
    model_params: Parameters passed to the constructor of the model_class.
      These will be used to instantiate the model object in the
      RunInference API. If both model_class and model_params are None,
      ``{'num_classes': 1000}`` is used for the default MobileNetV2.
    save_main_session: Value assigned to the ``save_main_session`` pipeline
      option.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  if not model_class:
    model_class = MobileNetV2
    # Keep caller-supplied constructor arguments; only fall back to the
    # 1000-class default when none were given.
    if model_params is None:
      model_params = {'num_classes': 1000}

  model_loader = PytorchModelLoader(
      state_dict_path=known_args.model_state_dict_path,
      model_class=model_class,
      model_params=model_params)

  with beam.Pipeline(options=pipeline_options) as p:
    filename_value_pair = (
        p
        # NOTE(review): the first line of the input file is skipped,
        # presumably a header line -- confirm against the input file format.
        | 'ReadImageNames' >> beam.io.ReadFromText(
            known_args.input, skip_header_lines=1)
        | 'ReadImageData' >> beam.Map(
            lambda image_name: read_image(
                image_file_name=image_name, path_to_dir=known_args.images_dir))
        | 'PreprocessImages' >> beam.MapTuple(
            lambda file_name, data: (file_name, preprocess_image(data))))
    predictions = (
        filename_value_pair
        | 'PyTorchRunInference' >> RunInference(model_loader).with_output_types(
            Tuple[str, PredictionResult])
        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

    if known_args.output:
      predictions | "WriteOutputToGCS" >> beam.io.WriteToText(  # pylint: disable=expression-not-assigned
          known_args.output,
          shard_name_template='',
          append_trailing_newlines=True)
+
+
if __name__ == '__main__':
  # Script entry point: run() is called with argv=None, so the command line
  # is parsed by argparse inside parse_known_args.
  run()
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
new file mode 100644
index 00000000000..066d667b786
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
@@ -0,0 +1,95 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: skip-file
+
+"""End-to-End test for Pytorch Inference"""
+
+import logging
+import os
+import unittest
+import uuid
+
+import pytest
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
try:
  import torch
  from apache_beam.examples.inference import pytorch_image_classification
except ImportError:
  # torch, or a dependency of the example module (torchvision, pillow), is
  # missing. The test class checks ``torch is None`` to decide whether to skip.
  torch = None
+
# Expected predicted class index (as a string, matching the example's
# "filename,index" output lines) keyed by the full GCS path of each image.
_IMAGENET_VALIDATION_DIR = (
    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/')
_EXPECTED_OUTPUTS = {
    _IMAGENET_VALIDATION_DIR + image_name: label
    for image_name, label in (
        ('ILSVRC2012_val_00005001.JPEG', '681'),
        ('ILSVRC2012_val_00005002.JPEG', '333'),
        ('ILSVRC2012_val_00005003.JPEG', '711'),
        ('ILSVRC2012_val_00005004.JPEG', '286'),
        ('ILSVRC2012_val_00005005.JPEG', '433'),
        ('ILSVRC2012_val_00005006.JPEG', '290'),
        ('ILSVRC2012_val_00005007.JPEG', '890'),
        ('ILSVRC2012_val_00005008.JPEG', '592'),
        ('ILSVRC2012_val_00005009.JPEG', '406'),
        ('ILSVRC2012_val_00005010.JPEG', '996'),
        ('ILSVRC2012_val_00005011.JPEG', '327'),
        ('ILSVRC2012_val_00005012.JPEG', '573'),
    )
}
+
+
def process_outputs(filepath):
  """Returns the lines of ``filepath`` decoded as UTF-8, newlines stripped."""
  with FileSystems().open(filepath) as handle:
    raw_lines = handle.readlines()
  return [raw.decode('utf-8').strip('\n') for raw in raw_lines]
+
+
@unittest.skipIf(
    os.getenv('FORCE_TORCH_IT') is None and torch is None,
    'Missing dependencies. '
    'Test depends on torch, torchvision and pillow')
class PyTorchInference(unittest.TestCase):
  """Runs the PyTorch image classification example end to end."""
  @pytest.mark.uses_pytorch
  @pytest.mark.it_postcommit
  def test_torch_run_inference_imagenet_mobilenetv2(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    # text files containing absolute path to the imagenet validation data on GCS
    file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_mobilenetv2_imagenet_validation_inputs.txt'  # pylint: disable=line-too-long
    output_file_dir = 'gs://apache-beam-ml/testing/predictions'
    # A unique directory per run so concurrent runs don't overwrite each other.
    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])

    model_state_dict_path = 'gs://apache-beam-ml/models/imagenet_classification_mobilenet_v2.pt'  # pylint: disable=line-too-long
    extra_opts = {
        'input': file_of_image_names,
        'output': output_file,
        'model_state_dict_path': model_state_dict_path,
    }
    pytorch_image_classification.run(
        test_pipeline.get_full_options_as_args(**extra_opts),
        save_main_session=False)

    self.assertTrue(FileSystems().exists(output_file))
    predictions = process_outputs(filepath=output_file)
    # Without this, an empty output file would make the loop below pass
    # vacuously.
    self.assertTrue(predictions, 'No predictions were written.')

    for line in predictions:
      # Split on the last comma: the example writes "<filename>,<index>".
      filename, predicted_class = line.rsplit(',', 1)
      self.assertIn(filename, _EXPECTED_OUTPUTS)
      self.assertEqual(_EXPECTED_OUTPUTS[filename], predicted_class)
+
+
if __name__ == '__main__':
  # Allows running this test module directly, with DEBUG-level root logging.
  logging.getLogger().setLevel(logging.DEBUG)
  unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt
new file mode 100644
index 00000000000..b1ac63c314c
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+torch>=1.7.1
+torchvision>=0.8.2
+pillow>=8.0.0
diff --git a/sdks/python/scripts/run_integration_test.sh b/sdks/python/scripts/run_integration_test.sh
index 7b42e676129..da942bd30e6 100755
--- a/sdks/python/scripts/run_integration_test.sh
+++ b/sdks/python/scripts/run_integration_test.sh
@@ -79,6 +79,7 @@ WORKER_JAR=""
KMS_KEY_NAME="projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test"
SUITE=""
COLLECT_MARKERS=
+REQUIREMENTS_FILE=""
# Default test (pytest) options.
# Run WordCountIT.test_wordcount_it by default if no test options are
@@ -114,6 +115,11 @@ case $key in
shift # past argument
shift # past value
;;
+ --requirements_file)
+ REQUIREMENTS_FILE="$2"
+ shift # past argument
+ shift # past value
+ ;;
--num_workers)
NUM_WORKERS="$2"
shift # past argument
@@ -202,7 +208,6 @@ fi
# Build pipeline options if not provided in --pipeline_opts from commandline
if [[ -z $PIPELINE_OPTS ]]; then
-
# Get tar ball path
if [[ $(find ${SDK_LOCATION} 2> /dev/null) ]]; then
SDK_LOCATION=$(find ${SDK_LOCATION} | tail -n1)
@@ -213,9 +218,13 @@ if [[ -z $PIPELINE_OPTS ]]; then
# Install test dependencies for ValidatesRunner tests.
# pyhamcrest==1.10.0 doesn't work on Py2.
# See: https://github.com/hamcrest/PyHamcrest/issues/131.
- echo "pyhamcrest!=1.10.0,<2.0.0" > postcommit_requirements.txt
- echo "mock<3.0.0" >> postcommit_requirements.txt
- echo "parameterized>=0.7.1,<0.8.0" >> postcommit_requirements.txt
if [[ -z $REQUIREMENTS_FILE ]]; then
  # No --requirements_file given: write the default post-commit test
  # dependencies.
  echo "pyhamcrest!=1.10.0,<2.0.0" > postcommit_requirements.txt
  echo "mock<3.0.0" >> postcommit_requirements.txt
  echo "parameterized>=0.7.1,<0.8.0" >> postcommit_requirements.txt
else
  # Use the caller-supplied requirements instead of the defaults. Quote the
  # path so it is not subject to word splitting or glob expansion.
  cp "$REQUIREMENTS_FILE" postcommit_requirements.txt
fi
# Options used to run testing pipeline on Cloud Dataflow Service. Also used for
# running on DirectRunner (some options ignored).
@@ -283,4 +292,4 @@ if [ -z "$COLLECT_MARKERS" ]; then
pytest $ARGS --test-pipeline-options="$PIPELINE_OPTS"
else
pytest $ARGS --test-pipeline-options="$PIPELINE_OPTS" "$COLLECT_MARKERS"
-fi
\ No newline at end of file
+fi
diff --git a/sdks/python/test-suites/dataflow/build.gradle b/sdks/python/test-suites/dataflow/build.gradle
index 036159eede3..d16111679da 100644
--- a/sdks/python/test-suites/dataflow/build.gradle
+++ b/sdks/python/test-suites/dataflow/build.gradle
@@ -83,4 +83,4 @@ task examplesPostCommit {
getVersionsAsList('dataflow_examples_postcommit_py_versions').each {
dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:examples")
}
-}
\ No newline at end of file
+}
diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle
index 15df9696272..f08655e0281 100644
--- a/sdks/python/test-suites/direct/common.gradle
+++ b/sdks/python/test-suites/direct/common.gradle
@@ -185,3 +185,36 @@ tasks.register("hdfsIntegrationTest") {
}
}
}
+
// PyTorch RunInference IT tests.
// Installs the torch test requirements into the test virtualenv, then runs the
// tests marked "uses_pytorch and it_postcommit" with the TestDirectRunner.
task torchTests {
  dependsOn 'installGcpTest'
  dependsOn ':sdks:python:sdist'
  def requirementsFile = "${rootDir}/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt"
  doFirst {
    // Install the packages listed in requirementsFile (torch, torchvision,
    // pillow) into the virtualenv before the tests run.
    exec {
      executable 'sh'
      args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile"
    }
  }
  doLast {
    def testOpts = basicTestOpts
    def argMap = [
        "test_opts": testOpts,
        "suite": "postCommitIT-direct-py${pythonVersionSuffix}",
        "collect": "uses_pytorch and it_postcommit",
        "runner": "TestDirectRunner"
    ]
    def cmdArgs = mapToArgString(argMap)
    // FORCE_TORCH_IT is set so the IT test's skipIf (which skips only when
    // FORCE_TORCH_IT is unset and torch failed to import) does not skip;
    // missing dependencies then surface as failures rather than silent skips.
    exec {
      executable 'sh'
      args '-c', ". ${envdir}/bin/activate && export FORCE_TORCH_IT=1 && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
    }
  }
}
+
// Aggregate task for the RunInference framework IT tests that run in the
// Direct Runner PostCommit suite; add each framework's IT task here.
// Currently its only dependency is torchTests.
// TODO(anandinguva): Add sklearn IT test here
project.tasks.register("inferencePostCommitIT") {
  dependsOn = ['torchTests']
}