Posted to commits@beam.apache.org by da...@apache.org on 2023/02/13 15:23:25 UTC

[beam] branch master updated: Add WatchFilePattern (#25393)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bb5e200df70 Add WatchFilePattern  (#25393)
bb5e200df70 is described below

commit bb5e200df70a875379ff5a6dfe72325172c3eb77
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Mon Feb 13 10:23:16 2023 -0500

    Add WatchFilePattern  (#25393)
    
    * Add WatchFilePattern transform
    
    * Remove defaults and update instructions
    
    * Add batch args
    
    * Refactor example based on comments
    
    * Changes based on  PR comments
    
    * Fix typo
    
    * Fix up lint
    
    * Fix doc precommit
    
    * Fix up pydocs
    
    * Fixup lint
    
    * Fix up test
    
    * Update docstring as per comments
    
    * Add unittest.main
    
    * Update sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py
    
    Co-authored-by: Danny McCormick <da...@google.com>
    
    ---------
    
    Co-authored-by: Danny McCormick <da...@google.com>
---
 CHANGES.md                                         |   1 +
 ...ytorch_image_classification_with_side_inputs.py | 218 +++++++++++++++++++++
 sdks/python/apache_beam/io/fileio.py               |   6 +-
 sdks/python/apache_beam/ml/inference/base_test.py  |  22 ++-
 sdks/python/apache_beam/ml/inference/utils.py      | 116 +++++++++++
 sdks/python/apache_beam/ml/inference/utils_test.py | 103 ++++++++++
 6 files changed, 455 insertions(+), 11 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 10310c6cbef..55a106f3513 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
 * Add UDF metrics support for Samza portable mode.
 * Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)).
   This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`.
+* Add `WatchFilePattern` transform, which can be used as a side input to the RunInference PTransform to watch for model updates using a file pattern. ([#24042](https://github.com/apache/beam/issues/24042))
 * Add support for loading TorchScript models with `PytorchModelHandler`. The TorchScript model path can be
   passed to PytorchModelHandler using `torch_script_model_path=<path_to_model>`. ([#25321](https://github.com/apache/beam/pull/25321))
 
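A minimal sketch of the pattern this entry describes, assuming a streaming pipeline; `model_handler`, `preprocess`, the bucket path, and the topic below are hypothetical placeholders, while the wiring mirrors the example added by this commit:

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.utils import WatchFilePattern
    from apache_beam.options.pipeline_options import PipelineOptions

    # `model_handler` stands in for any configured handler, e.g. a
    # KeyedModelHandler over PytorchModelHandlerTensor as in the example
    # below; `preprocess` is a user-defined function.
    options = PipelineOptions(streaming=True)
    with beam.Pipeline(options=options) as p:
      model_metadata = p | 'WatchForModels' >> WatchFilePattern(
          file_pattern='gs://my-bucket/models/*.pth', interval=60)
      _ = (
          p
          | 'ReadExamples' >> beam.io.ReadFromPubSub(
              topic='projects/my-project/topics/examples')
          | 'Preprocess' >> beam.Map(preprocess)
          | 'RunInference' >> RunInference(
              model_handler, model_metadata_pcoll=model_metadata))
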
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py
new file mode 100644
index 00000000000..2a4e6e9a9bc
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py
@@ -0,0 +1,218 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A pipeline that uses the RunInference PTransform to perform image
+classification and uses WatchFilePattern as a side input to the
+RunInference PTransform. WatchFilePattern watches for file updates
+matching the file_pattern based on timestamps and emits the latest model
+metadata, which the RunInference API uses for dynamic model updates
+without stopping the Beam pipeline.
+
+This pipeline follows the pattern from
+https://beam.apache.org/documentation/patterns/side-inputs/
+
+To use Pub/Sub as the pipeline source, publish the absolute paths of the
+images to classify to a Pub/Sub topic, then pass that topic via the
+command line arg --topic. The published path (str) should be UTF-8
+encoded and must be accessible by the pipeline.
+
+To run the example on DataflowRunner,
+
+python apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py # pylint: disable=line-too-long
+  --project=<your-project>
+  --region=<your-region>
+  --temp_location=<your-tmp-location>
+  --staging_location=<your-staging-location>
+  --runner=DataflowRunner
+  --streaming
+  --interval=10
+  --num_workers=5
+  --requirements_file=apache_beam/ml/inference/torch_tests_requirements.txt
+  --topic=<pubsub_topic>
+  --file_pattern=<glob_pattern>
+
+file_pattern is a path (which can contain glob characters) that is passed
+to the MatchContinuously transform to watch for model updates.
+MatchContinuously watches the file_pattern and emits the latest file
+path, sorted by timestamp. A file that was read before and later updated
+under the same name is ignored as an update; always upload updated
+models under a new, unique name.
+
+The pipeline expects at least one file matching the file_pattern to be
+present before pipeline startup; typically this is the
+`initial_model_path`. If no file matches before startup time, the
+pipeline will fail.
+"""
+
+import argparse
+import io
+import logging
+import os
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+import torch
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+from apache_beam.ml.inference.utils import WatchFilePattern
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+from PIL import Image
+from torchvision import models
+from torchvision import transforms
+
+
+def read_image(image_file_name: str,
+               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
+  if path_to_dir is not None:
+    image_file_name = os.path.join(path_to_dir, image_file_name)
+  with FileSystems().open(image_file_name, 'r') as file:
+    data = Image.open(io.BytesIO(file.read())).convert('RGB')
+    return image_file_name, data
+
+
+def preprocess_image(data: Image.Image) -> torch.Tensor:
+  image_size = (224, 224)
+  # Pre-trained PyTorch models expect input images normalized with the
+  # below values (see: https://pytorch.org/vision/stable/models.html)
+  normalize = transforms.Normalize(
+      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+  transform = transforms.Compose([
+      transforms.Resize(image_size),
+      transforms.ToTensor(),
+      normalize,
+  ])
+  return transform(data)
+
+
+def filter_empty_lines(text: str) -> Iterator[str]:
+  if len(text.strip()) > 0:
+    yield text
+
+
+class PostProcessor(beam.DoFn):
+  """
+  Return filename, prediction and the model id used to perform the
+  prediction
+  """
+  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
+    filename, prediction_result = element
+    prediction = torch.argmax(prediction_result.inference, dim=0)
+    yield filename, prediction, prediction_result.model_id
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--topic',
+      dest='topic',
+      help='PubSub topic emitting absolute path to the images. '
+      'Path must be accessible by the pipeline.')
+  parser.add_argument(
+      '--model_path',
+      '--initial_model_path',
+      dest='model_path',
+      default='gs://apache-beam-samples/run_inference/resnet152.pth',
+      help="Path to the initial model's state_dict. "
+      "This will be used until the first model update occurs.")
+  parser.add_argument(
+      '--file_pattern', help='Glob pattern to watch for an update.')
+  parser.add_argument(
+      '--interval',
+      default=10,
+      type=int,
+      help='Interval used to check for file updates.')
+
+  return parser.parse_known_args(argv)
+
+
+def run(
+    argv=None,
+    model_class=None,
+    model_params=None,
+    save_main_session=True,
+    device='CPU',
+    test_pipeline=None) -> PipelineResult:
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    model_class: Reference to the class definition of the model.
+    model_params: Parameters passed to the constructor of the model_class.
+                  These will be used to instantiate the model object in the
+                  RunInference PTransform.
+    save_main_session: Used for internal testing.
+    device: Device to be used on the Runner. Choices are (CPU, GPU).
+    test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  if not model_class:
+    model_class = models.resnet152
+    model_params = {'num_classes': 1000}
+
+  # In this example we pass keyed inputs to the RunInference transform.
+  # Therefore, we use the KeyedModelHandler wrapper over
+  # PytorchModelHandlerTensor.
+  model_handler = KeyedModelHandler(
+      PytorchModelHandlerTensor(
+          state_dict_path=known_args.model_path,
+          model_class=model_class,
+          model_params=model_params,
+          device=device,
+          min_batch_size=10,
+          max_batch_size=100))
+
+  pipeline = test_pipeline
+  if not test_pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  side_input = pipeline | WatchFilePattern(
+      interval=known_args.interval, file_pattern=known_args.file_pattern)
+
+  filename_value_pair = (
+      pipeline
+      | 'ReadImageNamesFromPubSub' >> beam.io.ReadFromPubSub(known_args.topic)
+      | 'DecodeBytes' >> beam.Map(lambda x: x.decode('utf-8'))
+      | 'ReadImageData' >>
+      beam.Map(lambda image_name: read_image(image_file_name=image_name))
+      | 'PreprocessImages' >> beam.MapTuple(
+          lambda file_name, data: (file_name, preprocess_image(data))))
+  predictions = (
+      filename_value_pair
+      | 'PyTorchRunInference' >> RunInference(
+          model_handler, model_metadata_pcoll=side_input)
+      | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
+
+  _ = predictions | beam.Map(logging.info)
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
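
To drive the example above, each Pub/Sub message must carry a UTF-8
encoded image path. A hypothetical publisher snippet, assuming the
google-cloud-pubsub client and made-up project, topic, and bucket names:

    from google.cloud import pubsub_v1

    publisher = pubsub_v1.PublisherClient()
    topic_path = publisher.topic_path('my-project', 'image-paths')
    # The pipeline decodes the raw message bytes as UTF-8, so publish
    # the path itself as the message body.
    future = publisher.publish(
        topic_path, data='gs://my-bucket/images/cat.jpg'.encode('utf-8'))
    future.result()  # block until the publish succeeds

Model updates, by contrast, are triggered by uploading a new model file
(under a fresh, unique name) to a location matching --file_pattern.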
diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
index da6e86e2cf3..2be5d06a026 100644
--- a/sdks/python/apache_beam/io/fileio.py
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -277,7 +277,8 @@ class MatchContinuously(beam.PTransform):
       start_timestamp=Timestamp.now(),
       stop_timestamp=MAX_TIMESTAMP,
       match_updated_files=False,
-      apply_windowing=False):
+      apply_windowing=False,
+      empty_match_treatment=EmptyMatchTreatment.ALLOW):
     """Initializes a MatchContinuously transform.
 
     Args:
@@ -299,6 +300,7 @@ class MatchContinuously(beam.PTransform):
     self.stop_ts = stop_timestamp
     self.match_upd = match_updated_files
     self.apply_windowing = apply_windowing
+    self.empty_match_treatment = empty_match_treatment
 
   def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]:
     # invoke periodic impulse
@@ -311,7 +313,7 @@ class MatchContinuously(beam.PTransform):
     match_files = (
         impulse
         | 'GetFilePattern' >> beam.Map(lambda x: self.file_pattern)
-        | MatchAll())
+        | MatchAll(self.empty_match_treatment))
 
     # apply deduplication strategy if required
     if self.has_deduplication:
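
The new empty_match_treatment knob controls what happens when a polling
interval matches no files. A short sketch under assumed values (the
bucket path and interval are illustrative):

    import apache_beam as beam
    from apache_beam.io.fileio import EmptyMatchTreatment
    from apache_beam.io.fileio import MatchContinuously
    from apache_beam.options.pipeline_options import PipelineOptions

    # With DISALLOW, an interval that matches no files raises an error
    # instead of silently emitting nothing; WatchFilePattern (below)
    # relies on this to surface a missing initial model early.
    with beam.Pipeline(options=PipelineOptions(streaming=True)) as p:
      matches = p | MatchContinuously(
          file_pattern='gs://my-bucket/models/*.pth',
          interval=30,
          empty_match_treatment=EmptyMatchTreatment.DISALLOW)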
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 319735da236..dad18c7b9e1 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -440,15 +440,19 @@ class RunInferenceBaseTest(unittest.TestCase):
         first_ts + 22,
     ])
 
-    sample_side_input_elements = [(
-        first_ts + 8,
-        base.ModelMetadata(
-            model_id='fake_model_id_1', model_name='fake_model_id_1')),
-                                  (
-                                      first_ts + 15,
-                                      base.ModelMetadata(
-                                          model_id='fake_model_id_2',
-                                          model_name='fake_model_id_2'))]
+    sample_side_input_elements = [
+        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
+        # if model_id is empty string, we use the default model
+        # handler model URI.
+        (
+            first_ts + 8,
+            base.ModelMetadata(
+                model_id='fake_model_id_1', model_name='fake_model_id_1')),
+        (
+            first_ts + 15,
+            base.ModelMetadata(
+                model_id='fake_model_id_2', model_name='fake_model_id_2'))
+    ]
 
     model_handler = FakeModelHandlerReturnsPredictionResult()
 
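The added element illustrates the fallback semantics of the side input: a
ModelMetadata with an empty model_id tells RunInference to keep using the
model the handler was constructed with. Roughly (the non-empty id below
is a fake test value, not a real path):

    from apache_beam.ml.inference.base import ModelMetadata

    updates = [
        ModelMetadata(model_id='', model_name=''),  # use the default model
        ModelMetadata(
            model_id='fake_model_id_1',
            model_name='fake_model_id_1'),  # switch to this model
    ]
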
diff --git a/sdks/python/apache_beam/ml/inference/utils.py b/sdks/python/apache_beam/ml/inference/utils.py
index f30d8a8f648..4936ab5fe1d 100644
--- a/sdks/python/apache_beam/ml/inference/utils.py
+++ b/sdks/python/apache_beam/ml/inference/utils.py
@@ -19,13 +19,26 @@
 """
 Util/helper functions used in apache_beam.ml.inference.
 """
+import os
+from functools import partial
 from typing import Any
 from typing import Dict
 from typing import Iterable
 from typing import Optional
 from typing import Union
 
+import apache_beam as beam
+from apache_beam.io.fileio import EmptyMatchTreatment
+from apache_beam.io.fileio import MatchContinuously
+from apache_beam.ml.inference.base import ModelMetadata
 from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.transforms import trigger
+from apache_beam.transforms import window
+from apache_beam.transforms.userstate import CombiningValueStateSpec
+from apache_beam.utils.timestamp import MAX_TIMESTAMP
+from apache_beam.utils.timestamp import Timestamp
+
+_START_TIME_STAMP = Timestamp.now()
 
 
 def _convert_to_result(
@@ -46,3 +59,106 @@ def _convert_to_result(
         y in zip(batch, predictions_per_tensor)
     ]
   return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)]
+
+
+class _ConvertIterToSingleton(beam.DoFn):
+  """
+  Internal only; no backwards compatibility.
+
+  The MatchContinuously transform examines all files present in a given
+  directory, including those with timestamps older than the pipeline's
+  start time, so it can produce an Iterable rather than a Singleton. This
+  DoFn returns a file path only the first time it is encountered, after
+  which it is cached as part of the side input caching mechanism; if the
+  path is seen again, nothing is returned. This ensures that the output
+  of this transform can be wrapped with beam.pvalue.AsSingleton().
+  """
+  COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)
+
+  def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
+    counter = count_state.read()
+    if counter == 0:
+      count_state.add(1)
+      yield element[1]
+
+
+class _GetLatestFileByTimeStamp(beam.DoFn):
+  """
+  Internal only; no backwards compatibility.
+
+  This DoFn checks the timestamps of files against the time that the
+  pipeline began running. It returns the files that were modified after
+  the pipeline started; if no such files are found, it returns a default
+  (empty) model path as a fallback.
+  """
+  TIME_STATE = CombiningValueStateSpec(
+      'max', combine_fn=partial(max, default=_START_TIME_STAMP))
+
+  def process(self, element, time_state=beam.DoFn.StateParam(TIME_STATE)):
+    _, file_metadata = element
+    new_ts = file_metadata.last_updated_in_seconds
+    old_ts = time_state.read()
+    if new_ts > old_ts:
+      time_state.clear()
+      time_state.add(new_ts)
+      model_path = file_metadata.path
+    else:
+      model_path = ''
+
+    model_name = os.path.splitext(os.path.basename(model_path))[0]
+    return [
+        (model_path, ModelMetadata(model_id=model_path, model_name=model_name))
+    ]
+
+
+class WatchFilePattern(beam.PTransform):
+  def __init__(
+      self,
+      file_pattern,
+      interval=360,
+      stop_timestamp=MAX_TIMESTAMP,
+  ):
+    """
+    Watches a directory for updates to files matching a given file pattern.
+
+    Args:
+      file_pattern: The file path to read from as a local file path or a
+        GCS ``gs://`` path. The path can contain glob characters
+        (``*``, ``?``, and ``[...]`` sets).
+      interval: Interval, in seconds, at which to check for files matching
+        file_pattern.
+      stop_timestamp: Timestamp after which no more files will be checked.
+
+    **Note**:
+
+    1. Any previously used filename cannot be reused. If a file is added
+        or updated under a previously used filename, this transform will
+        ignore that update. To trigger a model update, always upload a file
+        with a unique name.
+    2. Before the pipeline startup time, WatchFilePattern expects at least
+        one file present that matches the file_pattern.
+    3. This transform is only supported in streaming pipelines, since
+        MatchContinuously produces an unbounded source. Running it in batch
+        mode can lead to undesired results or a stuck pipeline.
+    """
+    self.file_pattern = file_pattern
+    self.interval = interval
+    self.stop_timestamp = stop_timestamp
+
+  def expand(self, pcoll) -> beam.PCollection[ModelMetadata]:
+    return (
+        pcoll
+        | 'MatchContinuously' >> MatchContinuously(
+            file_pattern=self.file_pattern,
+            interval=self.interval,
+            stop_timestamp=self.stop_timestamp,
+            empty_match_treatment=EmptyMatchTreatment.DISALLOW)
+        | "AttachKey" >> beam.Map(lambda x: (x.path, x))
+        | "GetLatestFileMetaData" >> beam.ParDo(_GetLatestFileByTimeStamp())
+        | "AcceptNewSideInputOnly" >> beam.ParDo(_ConvertIterToSingleton())
+        | 'ApplyGlobalWindow' >> beam.transforms.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING))
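
The final ApplyGlobalWindow step is what makes the output usable as a
continuously refreshing side input: a global window with a repeated
processing-time trigger and discarding accumulation re-emits the latest
ModelMetadata to consumers shortly after it changes. A standalone usage
sketch, assuming a streaming pipeline and a hypothetical bucket path:

    import apache_beam as beam
    from apache_beam.ml.inference.utils import WatchFilePattern
    from apache_beam.options.pipeline_options import PipelineOptions

    # Each emitted element is a ModelMetadata; MatchContinuously is
    # unbounded, so the pipeline must run in streaming mode.
    with beam.Pipeline(options=PipelineOptions(streaming=True)) as p:
      _ = (
          p
          | WatchFilePattern(
              file_pattern='gs://my-bucket/models/*.pth', interval=60)
          | beam.Map(lambda meta: print(meta.model_id, meta.model_name)))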
diff --git a/sdks/python/apache_beam/ml/inference/utils_test.py b/sdks/python/apache_beam/ml/inference/utils_test.py
new file mode 100644
index 00000000000..66499a5a6f4
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/utils_test.py
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import unittest
+
+import apache_beam as beam
+from apache_beam.io.filesystem import FileMetadata
+from apache_beam.ml.inference import utils
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+
+class WatchFilePatternTest(unittest.TestCase):
+  def test_latest_file_by_timestamp_default_value(self):
+    # MatchContinuously returns the files in sorted timestamp order.
+    main_input_pcoll = [
+        FileMetadata(
+            'path1.py',
+            10,
+            last_updated_in_seconds=utils._START_TIME_STAMP - 20),
+        FileMetadata(
+            'path2.py',
+            10,
+            last_updated_in_seconds=utils._START_TIME_STAMP - 10)
+    ]
+    with TestPipeline() as p:
+      files_pc = (
+          p
+          | beam.Create(main_input_pcoll)
+          | beam.Map(lambda x: (x.path, x))
+          | beam.ParDo(utils._GetLatestFileByTimeStamp())
+          | beam.Map(lambda x: x[0]))
+      assert_that(files_pc, equal_to(['', '']))
+
+  def test_latest_file_with_timestamp_after_pipeline_construction_time(self):
+    main_input_pcoll = [
+        FileMetadata(
+            'path1.py',
+            10,
+            last_updated_in_seconds=utils._START_TIME_STAMP + 10)
+    ]
+    with TestPipeline() as p:
+      files_pc = (
+          p
+          | beam.Create(main_input_pcoll)
+          | beam.Map(lambda x: (x.path, x))
+          | beam.ParDo(utils._GetLatestFileByTimeStamp())
+          | beam.Map(lambda x: x[0]))
+      assert_that(files_pc, equal_to(['path1.py']))
+
+  def test_emitting_singleton_output(self):
+    # MatchContinuously returns the files in sorted timestamp order.
+    main_input_pcoll = [
+        FileMetadata(
+            'path1.py',
+            10,
+            last_updated_in_seconds=utils._START_TIME_STAMP - 20),
+        # returns default
+        FileMetadata(
+            'path2.py',
+            10,
+            last_updated_in_seconds=utils._START_TIME_STAMP - 10),
+        # returns default
+        FileMetadata(
+            'path3.py',
+            10,
+            last_updated_in_seconds=utils._START_TIME_STAMP + 10),
+        # returns path3.py
+        FileMetadata(
+            'path4.py',
+            10,
+            last_updated_in_seconds=utils._START_TIME_STAMP + 20)
+        # returns path4.py
+    ]
+
+    with TestPipeline() as p:
+      files_pc = (
+          p
+          | beam.Create(main_input_pcoll)
+          | beam.Map(lambda x: (x.path, x))
+          | beam.ParDo(utils._GetLatestFileByTimeStamp())
+          | beam.ParDo(utils._ConvertIterToSingleton())
+          | beam.Map(lambda x: x[0]))
+      assert_that(files_pc, equal_to(['', 'path3.py', 'path4.py']))
+
+
+if __name__ == '__main__':
+  unittest.main()