You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/06/06 18:22:07 UTC

[beam] branch master updated: [BEAM-14068]Add Pytorch inference IT test and example (#17462)

This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 7445b9535a9 [BEAM-14068]Add Pytorch inference IT test and example (#17462)
7445b9535a9 is described below

commit 7445b9535a935757fc6d84d4138c4a7d772d5637
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Mon Jun 6 18:22:01 2022 +0000

    [BEAM-14068]Add Pytorch inference IT test and example (#17462)
    
    Co-authored-by: tvalentyn <tv...@users.noreply.github.com>
---
 build.gradle.kts                                   |   7 +-
 .../apache_beam/examples/inference/__init__.py     |  16 +++
 .../inference/pytorch_image_classification.py      | 146 +++++++++++++++++++++
 .../ml/inference/pytorch_inference_it_test.py      |  95 ++++++++++++++
 .../ml/inference/torch_tests_requirements.txt      |  20 +++
 sdks/python/scripts/run_integration_test.sh        |  19 ++-
 sdks/python/test-suites/dataflow/build.gradle      |   2 +-
 sdks/python/test-suites/direct/common.gradle       |  33 +++++
 8 files changed, 327 insertions(+), 11 deletions(-)

diff --git a/build.gradle.kts b/build.gradle.kts
index fd8d8714177..c6c0059f62a 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -320,6 +320,7 @@ tasks.register("python37PostCommit") {
   dependsOn(":sdks:python:test-suites:dataflow:py37:spannerioIT")
   dependsOn(":sdks:python:test-suites:direct:py37:spannerioIT")
   dependsOn(":sdks:python:test-suites:portable:py37:xlangSpannerIOIT")
+  dependsOn(":sdks:python:test-suites:direct:py37:inferencePostCommitIT") // RunInference framework ITs (direct runner suite)
 }
 
 tasks.register("python38PostCommit") {
@@ -334,11 +335,7 @@ tasks.register("python39PostCommit") {
   dependsOn(":sdks:python:test-suites:direct:py39:postCommitIT")
   dependsOn(":sdks:python:test-suites:direct:py39:hdfsIntegrationTest")
   dependsOn(":sdks:python:test-suites:portable:py39:postCommitPy39")
-}
-
-
-task("python36SickbayPostCommit") {
-  dependsOn(":sdks:python:test-suites:dataflow:py36:postCommitSickbay")
+  dependsOn(":sdks:python:test-suites:direct:py39:inferencePostCommitIT") // RunInference framework ITs (direct runner suite)
 }
 
 task("python37SickbayPostCommit") {
diff --git a/sdks/python/apache_beam/examples/inference/__init__.py b/sdks/python/apache_beam/examples/inference/__init__.py
new file mode 100644
index 00000000000..cce3acad34a
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
new file mode 100644
index 00000000000..a3ea84ad01b
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -0,0 +1,177 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A pipeline that uses RunInference API to perform image classification."""
+
+import argparse
+import io
+import os
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+import torch
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.api import PredictionResult
+from apache_beam.ml.inference.api import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelLoader
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from PIL import Image
+from torchvision import transforms
+from torchvision.models.mobilenetv2 import MobileNetV2
+
+
+def read_image(image_file_name: str,
+               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
+  """Reads an image file and decodes it as an RGB ``PIL.Image``.
+
+  Args:
+    image_file_name: Name of the image file, or its full path.
+    path_to_dir: If not None, joined in front of ``image_file_name``.
+
+  Returns:
+    A ``(path, image)`` tuple, where ``path`` is the path that was read.
+  """
+  if path_to_dir is not None:
+    image_file_name = os.path.join(path_to_dir, image_file_name)
+  # The second positional parameter of FileSystems.open() is a MIME type,
+  # not a file mode such as 'r', so it is left at its default.
+  with FileSystems().open(image_file_name) as file:
+    data = Image.open(io.BytesIO(file.read())).convert('RGB')
+  return image_file_name, data
+
+
+def preprocess_image(data: Image.Image) -> torch.Tensor:
+  """Resizes and normalizes an image into a tensor for the model.
+
+  For an RGB input the result is a float tensor of shape (3, 224, 224).
+  """
+  image_size = (224, 224)
+  # Pre-trained PyTorch models expect input images normalized with the
+  # below values (see: https://pytorch.org/vision/stable/models.html)
+  normalize = transforms.Normalize(
+      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+  transform = transforms.Compose([
+      transforms.Resize(image_size),
+      transforms.ToTensor(),
+      normalize,
+  ])
+  return transform(data)
+
+
+class PostProcessor(beam.DoFn):
+  """Formats ``(filename, PredictionResult)`` pairs as text lines.
+
+  Each output line is ``<filename>,<class index>``, where the class index is
+  the argmax of the model output along dimension 0.
+  """
+  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
+    filename, prediction_result = element
+    prediction = torch.argmax(prediction_result.inference, dim=0)
+    yield filename + ',' + str(prediction.item())
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow.
+
+  Returns:
+    A ``(known_args, remaining_args)`` tuple; ``run`` uses the remaining
+    args as pipeline options.
+  """
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      default='gs://apache-beam-ml/testing/inputs/'
+      'it_mobilenetv2_imagenet_validation_inputs.txt',
+      help='Path to the text file containing image names.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      help='Path of the text file where output predictions are saved.')
+  parser.add_argument(
+      '--model_state_dict_path',
+      dest='model_state_dict_path',
+      default='gs://apache-beam-ml/'
+      'models/imagenet_classification_mobilenet_v2.pt',
+      help="Path to the model's state_dict. "
+      "The default state_dict is for MobileNetV2.")
+  parser.add_argument(
+      '--images_dir',
+      default=None,
+      help='Path to the directory where images are stored. '
+      'Not required if image names in the input file have absolute paths.')
+  return parser.parse_known_args(argv)
+
+
+def run(argv=None, model_class=None, model_params=None, save_main_session=True):
+  """Builds and runs the image classification pipeline.
+
+  Args:
+    argv: Command line arguments defined for this example.
+    model_class: Reference to the class definition of the model.
+                If None, MobileNetV2 is used as the default.
+    model_params: Parameters passed to the constructor of the model_class.
+                  These will be used to instantiate the model object in the
+                  RunInference API. When both this and model_class are None,
+                  ``{'num_classes': 1000}`` is used.
+    save_main_session: Value for the ``save_main_session`` pipeline option.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  if model_class is None:
+    model_class = MobileNetV2
+    # Don't discard constructor args explicitly passed by the caller.
+    if model_params is None:
+      model_params = {'num_classes': 1000}
+
+  model_loader = PytorchModelLoader(
+      state_dict_path=known_args.model_state_dict_path,
+      model_class=model_class,
+      model_params=model_params)
+
+  with beam.Pipeline(options=pipeline_options) as p:
+    filename_value_pair = (
+        p
+        # The first line of the input file is skipped as a header.
+        | 'ReadImageNames' >> beam.io.ReadFromText(
+            known_args.input, skip_header_lines=1)
+        | 'ReadImageData' >> beam.Map(
+            lambda image_name: read_image(
+                image_file_name=image_name, path_to_dir=known_args.images_dir))
+        | 'PreprocessImages' >> beam.MapTuple(
+            lambda file_name, data: (file_name, preprocess_image(data))))
+    predictions = (
+        filename_value_pair
+        | 'PyTorchRunInference' >> RunInference(model_loader).with_output_types(
+            Tuple[str, PredictionResult])
+        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
+
+    # Predictions are only written when --output is given.
+    if known_args.output:
+      _ = predictions | "WriteOutputToGCS" >> beam.io.WriteToText(
+          known_args.output,
+          shard_name_template='',
+          append_trailing_newlines=True)
+
+
+if __name__ == '__main__':
+  run()
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
new file mode 100644
index 00000000000..066d667b786
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
@@ -0,0 +1,108 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: skip-file
+
+"""End-to-End test for Pytorch Inference"""
+
+import logging
+import os
+import unittest
+import uuid
+
+import pytest
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# The example module needs torch, torchvision and pillow. If any of them is
+# missing, the test below is skipped unless FORCE_TORCH_IT is set.
+try:
+  import torch
+  from apache_beam.examples.inference import pytorch_image_classification
+except ImportError:
+  torch = None
+
+# Class index the pipeline is expected to write for each input image.
+_EXPECTED_OUTPUTS = {
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005001.JPEG': '681',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005002.JPEG': '333',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005003.JPEG': '711',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005004.JPEG': '286',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005005.JPEG': '433',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005006.JPEG': '290',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005007.JPEG': '890',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005008.JPEG': '592',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005009.JPEG': '406',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005010.JPEG': '996',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005011.JPEG': '327',
+    'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005012.JPEG': '573'
+}
+
+
+def process_outputs(filepath):
+  """Reads a text file through Beam FileSystems.
+
+  Returns:
+    The file's lines decoded as UTF-8, with trailing newlines stripped.
+  """
+  with FileSystems().open(filepath) as f:
+    lines = f.readlines()
+  return [line.decode('utf-8').strip('\n') for line in lines]
+
+
+@unittest.skipIf(
+    os.getenv('FORCE_TORCH_IT') is None and torch is None,
+    'Missing dependencies. '
+    'Test depends on torch, torchvision and pillow')
+class PyTorchInference(unittest.TestCase):
+  @pytest.mark.uses_pytorch
+  @pytest.mark.it_postcommit
+  def test_torch_run_inference_imagenet_mobilenetv2(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    # Text file containing absolute paths to the ImageNet validation data on GCS.
+    file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_mobilenetv2_imagenet_validation_inputs.txt'
+    output_file_dir = 'gs://apache-beam-ml/testing/predictions'
+    # A unique sub-directory per run keeps outputs of different runs apart.
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+
+    model_state_dict_path = 'gs://apache-beam-ml/models/imagenet_classification_mobilenet_v2.pt'
+    extra_opts = {
+        'input': file_of_image_names,
+        'output': output_file,
+        'model_state_dict_path': model_state_dict_path,
+    }
+    pytorch_image_classification.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+
+    self.assertTrue(FileSystems().exists(output_file))
+    predictions = process_outputs(filepath=output_file)
+    # Guards against a vacuous pass: with an empty output file the loop
+    # below would assert nothing.
+    self.assertTrue(predictions, 'The pipeline wrote no predictions.')
+
+    for line in predictions:
+      # The class index is the last comma-separated field; rsplit keeps a
+      # comma inside the image path from breaking the unpacking.
+      filename, prediction = line.rsplit(',', 1)
+      self.assertEqual(_EXPECTED_OUTPUTS[filename], prediction)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.DEBUG)
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt
new file mode 100644
index 00000000000..b1ac63c314c
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+torch>=1.7.1
+torchvision>=0.8.2
+pillow>=8.0.0
diff --git a/sdks/python/scripts/run_integration_test.sh b/sdks/python/scripts/run_integration_test.sh
index 7b42e676129..da942bd30e6 100755
--- a/sdks/python/scripts/run_integration_test.sh
+++ b/sdks/python/scripts/run_integration_test.sh
@@ -79,6 +79,7 @@ WORKER_JAR=""
 KMS_KEY_NAME="projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test"
 SUITE=""
 COLLECT_MARKERS=
+REQUIREMENTS_FILE=""  # set via --requirements_file
 
 # Default test (pytest) options.
 # Run WordCountIT.test_wordcount_it by default if no test options are
@@ -114,6 +115,11 @@ case $key in
         shift # past argument
         shift # past value
         ;;
+    --requirements_file)
+        REQUIREMENTS_FILE="$2"
+        shift # past argument
+        shift # past value
+        ;;
     --num_workers)
         NUM_WORKERS="$2"
         shift # past argument
@@ -202,7 +208,6 @@ fi
 # Build pipeline options if not provided in --pipeline_opts from commandline
 
 if [[ -z $PIPELINE_OPTS ]]; then
-
   # Get tar ball path
   if [[ $(find ${SDK_LOCATION} 2> /dev/null) ]]; then
     SDK_LOCATION=$(find ${SDK_LOCATION} | tail -n1)
@@ -213,9 +218,13 @@ if [[ -z $PIPELINE_OPTS ]]; then
   # Install test dependencies for ValidatesRunner tests.
   # pyhamcrest==1.10.0 doesn't work on Py2.
   # See: https://github.com/hamcrest/PyHamcrest/issues/131.
-  echo "pyhamcrest!=1.10.0,<2.0.0" > postcommit_requirements.txt
-  echo "mock<3.0.0" >> postcommit_requirements.txt
-  echo "parameterized>=0.7.1,<0.8.0" >> postcommit_requirements.txt
+  if [[ -z $REQUIREMENTS_FILE ]]; then
+    echo "pyhamcrest!=1.10.0,<2.0.0" > postcommit_requirements.txt
+    echo "mock<3.0.0" >> postcommit_requirements.txt
+    echo "parameterized>=0.7.1,<0.8.0" >> postcommit_requirements.txt
+  else  # a --requirements_file replaces the default requirements above
+    cp "$REQUIREMENTS_FILE" postcommit_requirements.txt
+  fi
 
   # Options used to run testing pipeline on Cloud Dataflow Service. Also used for
   # running on DirectRunner (some options ignored).
@@ -283,4 +292,4 @@ if [ -z "$COLLECT_MARKERS" ]; then
   pytest $ARGS --test-pipeline-options="$PIPELINE_OPTS"
 else
   pytest $ARGS --test-pipeline-options="$PIPELINE_OPTS" "$COLLECT_MARKERS"
-fi
\ No newline at end of file
+fi
diff --git a/sdks/python/test-suites/dataflow/build.gradle b/sdks/python/test-suites/dataflow/build.gradle
index 036159eede3..d16111679da 100644
--- a/sdks/python/test-suites/dataflow/build.gradle
+++ b/sdks/python/test-suites/dataflow/build.gradle
@@ -83,4 +83,4 @@ task examplesPostCommit {
   getVersionsAsList('dataflow_examples_postcommit_py_versions').each {
     dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:examples")
   }
-}
\ No newline at end of file
+}
diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle
index 15df9696272..f08655e0281 100644
--- a/sdks/python/test-suites/direct/common.gradle
+++ b/sdks/python/test-suites/direct/common.gradle
@@ -185,3 +185,37 @@ tasks.register("hdfsIntegrationTest") {
     }
   }
 }
+
+// Pytorch RunInference IT tests
+task torchTests {
+  dependsOn 'installGcpTest'
+  dependsOn ':sdks:python:sdist'
+  def requirementsFile = "${rootDir}/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt"
+  doFirst {
+    // Install the torch test dependencies into the suite's virtualenv.
+    exec {
+      executable 'sh'
+      args '-c', ". ${envdir}/bin/activate && pip install -r '${requirementsFile}'"
+    }
+  }
+  doLast {
+    def argMap = [
+        "test_opts": basicTestOpts,
+        "suite": "postCommitIT-direct-py${pythonVersionSuffix}",
+        "collect": "uses_pytorch and it_postcommit",
+        "runner": "TestDirectRunner"
+    ]
+    def cmdArgs = mapToArgString(argMap)
+    exec {
+      executable 'sh'
+      // FORCE_TORCH_IT keeps the IT from being skipped when torch can't be imported.
+      args '-c', ". ${envdir}/bin/activate && export FORCE_TORCH_IT=1 && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
+    }
+  }
+}
+
+// Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite.
+// TODO(anandinguva): Add sklearn IT test here
+project.tasks.register("inferencePostCommitIT") {
+  dependsOn 'torchTests'
+}