You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2023/02/13 15:23:25 UTC
[beam] branch master updated: Add WatchFilePattern (#25393)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new bb5e200df70 Add WatchFilePattern (#25393)
bb5e200df70 is described below
commit bb5e200df70a875379ff5a6dfe72325172c3eb77
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Mon Feb 13 10:23:16 2023 -0500
Add WatchFilePattern (#25393)
* Add WatchFilePattern transform
* Remove defaults and update instructions
* Add batch args
* Refactor example based on comments
* Changes based on PR comments
* Fix typo
* Fix up lint
* Fix doc precommit
* Fix up pydocs
* Fixup lint
* Fix up test
* Update docstring as per comments
* Add unittest.main
* Update sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py
Co-authored-by: Danny McCormick <da...@google.com>
---------
Co-authored-by: Danny McCormick <da...@google.com>
---
CHANGES.md | 1 +
...ytorch_image_classification_with_side_inputs.py | 218 +++++++++++++++++++++
sdks/python/apache_beam/io/fileio.py | 6 +-
sdks/python/apache_beam/ml/inference/base_test.py | 22 ++-
sdks/python/apache_beam/ml/inference/utils.py | 116 +++++++++++
sdks/python/apache_beam/ml/inference/utils_test.py | 103 ++++++++++
6 files changed, 455 insertions(+), 11 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 10310c6cbef..55a106f3513 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
* Add UDF metrics support for Samza portable mode.
* Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)).
This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`.
+* Add `WatchFilePattern` transform, which can be used as a side input to the RunInference PTransform to watch for model updates using a file pattern. ([#24042](https://github.com/apache/beam/issues/24042))
* Add support for loading TorchScript models with `PytorchModelHandler`. The TorchScript model path can be
passed to PytorchModelHandler using `torch_script_model_path=<path_to_model>`. ([#25321](https://github.com/apache/beam/pull/25321))
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py
new file mode 100644
index 00000000000..2a4e6e9a9bc
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py
@@ -0,0 +1,218 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A pipeline that uses RunInference PTransform to perform image classification
+and uses WatchFilePattern as side input to the RunInference PTransform.
+WatchFilePattern is used to watch for file updates matching the file_pattern
+based on timestamps and emits the latest model metadata, which is used in
+the RunInference API for dynamic model updates without the need for stopping
+the Beam pipeline.
+
+This pipeline follows the pattern from
+https://beam.apache.org/documentation/patterns/side-inputs/
+
+To use PubSub as the pipeline's source, publish the paths of the images to
+classify to a PubSub topic (the pipeline classifies them with resnet152 from
+torchvision.models.resnet152 by default). Then pass that topic via the
+command line arg --topic. Each published path (str) should be
+UTF-8 encoded.
+
+To run the example on DataflowRunner,
+
+python apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py # pylint: disable=line-too-long
+ --project=<your-project>
+ --region=<your-region>
+ --temp_location=<your-tmp-location>
+ --staging_location=<your-staging-location>
+ --runner=DataflowRunner
+ --streaming
+ --interval=10
+ --num_workers=5
+ --requirements_file=apache_beam/ml/inference/torch_tests_requirements.txt
+ --topic=<pubsub_topic>
+ --file_pattern=<glob_pattern>
+
+file_pattern is a path (can contain glob characters), which will be passed to
+the WatchFilePattern transform for model updates. WatchFilePattern watches the
+file_pattern and emits the latest file path, sorted by timestamp. Files that
+were already read and are later updated under the same name are ignored.
+
+The pipeline expects there is at least one file present to match the
+file_pattern before pipeline startup. Presumably, this would be the
+`initial_model_path`. If there is no file matching before pipeline
+startup time, the pipeline would fail.
+"""
+
+import argparse
+import io
+import logging
+import os
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+import torch
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+from apache_beam.ml.inference.utils import WatchFilePattern
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+from PIL import Image
+from torchvision import models
+from torchvision import transforms
+
+
+def read_image(image_file_name: str,
+ path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
+ if path_to_dir is not None:
+ image_file_name = os.path.join(path_to_dir, image_file_name)
+ with FileSystems().open(image_file_name, 'r') as file:
+ data = Image.open(io.BytesIO(file.read())).convert('RGB')
+ return image_file_name, data
+
+
+def preprocess_image(data: Image.Image) -> torch.Tensor:
+ image_size = (224, 224)
+ # Pre-trained PyTorch models expect input images normalized with the
+ # below values (see: https://pytorch.org/vision/stable/models.html)
+ normalize = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ transform = transforms.Compose([
+ transforms.Resize(image_size),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ return transform(data)
+
+
+def filter_empty_lines(text: str) -> Iterator[str]:
+ if len(text.strip()) > 0:
+ yield text
+
+
+class PostProcessor(beam.DoFn):
+ """
+ Return filename, prediction and the model id used to perform the
+ prediction
+ """
+ def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
+ filename, prediction_result = element
+ prediction = torch.argmax(prediction_result.inference, dim=0)
+ yield filename, prediction, prediction_result.model_id
+
+
+def parse_known_args(argv):
+ """Parses args for the workflow."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--topic',
+ dest='topic',
+ help='PubSub topic emitting absolute path to the images.'
+ 'Path must be accessible by the pipeline.')
+ parser.add_argument(
+ '--model_path',
+ '--initial_model_path',
+ dest='model_path',
+ default='gs://apache-beam-samples/run_inference/resnet152.pth',
+ help="Path to the initial model's state_dict. "
+ "This will be used until the first model update occurs.")
+ parser.add_argument(
+ '--file_pattern', help='Glob pattern to watch for an update.')
+ parser.add_argument(
+ '--interval',
+ default=10,
+ type=int,
+ help='Interval used to check for file updates.')
+
+ return parser.parse_known_args(argv)
+
+
+def run(
+ argv=None,
+ model_class=None,
+ model_params=None,
+ save_main_session=True,
+ device='CPU',
+ test_pipeline=None) -> PipelineResult:
+ """
+ Args:
+ argv: Command line arguments defined for this example.
+ model_class: Reference to the class definition of the model.
+ model_params: Parameters passed to the constructor of the model_class.
+ These will be used to instantiate the model object in the
+ RunInference PTransform.
+ save_main_session: Used for internal testing.
+ device: Device to be used on the Runner. Choices are (CPU, GPU).
+ test_pipeline: Used for internal testing.
+ """
+ known_args, pipeline_args = parse_known_args(argv)
+ pipeline_options = PipelineOptions(pipeline_args)
+ pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+ if not model_class:
+ model_class = models.resnet152
+ model_params = {'num_classes': 1000}
+
+ # In this example we pass keyed inputs to RunInference transform.
+ # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
+ model_handler = KeyedModelHandler(
+ PytorchModelHandlerTensor(
+ state_dict_path=known_args.model_path,
+ model_class=model_class,
+ model_params=model_params,
+ device=device,
+ min_batch_size=10,
+ max_batch_size=100))
+
+ pipeline = test_pipeline
+ if not test_pipeline:
+ pipeline = beam.Pipeline(options=pipeline_options)
+
+ side_input = pipeline | WatchFilePattern(
+ interval=known_args.interval, file_pattern=known_args.file_pattern)
+
+ filename_value_pair = (
+ pipeline
+ | 'ReadImageNamesFromPubSub' >> beam.io.ReadFromPubSub(known_args.topic)
+ | 'DecodeBytes' >> beam.Map(lambda x: x.decode('utf-8'))
+ | 'ReadImageData' >>
+ beam.Map(lambda image_name: read_image(image_file_name=image_name))
+ | 'PreprocessImages' >> beam.MapTuple(
+ lambda file_name, data: (file_name, preprocess_image(data))))
+ predictions = (
+ filename_value_pair
+ | 'PyTorchRunInference' >> RunInference(
+ model_handler, model_metadata_pcoll=side_input)
+ | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
+
+ _ = predictions | beam.Map(logging.info)
+
+ result = pipeline.run()
+ result.wait_until_finish()
+ return result
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ run()
diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
index da6e86e2cf3..2be5d06a026 100644
--- a/sdks/python/apache_beam/io/fileio.py
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -277,7 +277,8 @@ class MatchContinuously(beam.PTransform):
start_timestamp=Timestamp.now(),
stop_timestamp=MAX_TIMESTAMP,
match_updated_files=False,
- apply_windowing=False):
+ apply_windowing=False,
+ empty_match_treatment=EmptyMatchTreatment.ALLOW):
"""Initializes a MatchContinuously transform.
Args:
@@ -299,6 +300,7 @@ class MatchContinuously(beam.PTransform):
self.stop_ts = stop_timestamp
self.match_upd = match_updated_files
self.apply_windowing = apply_windowing
+ self.empty_match_treatment = empty_match_treatment
def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]:
# invoke periodic impulse
@@ -311,7 +313,7 @@ class MatchContinuously(beam.PTransform):
match_files = (
impulse
| 'GetFilePattern' >> beam.Map(lambda x: self.file_pattern)
- | MatchAll())
+ | MatchAll(self.empty_match_treatment))
# apply deduplication strategy if required
if self.has_deduplication:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 319735da236..dad18c7b9e1 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -440,15 +440,19 @@ class RunInferenceBaseTest(unittest.TestCase):
first_ts + 22,
])
- sample_side_input_elements = [(
- first_ts + 8,
- base.ModelMetadata(
- model_id='fake_model_id_1', model_name='fake_model_id_1')),
- (
- first_ts + 15,
- base.ModelMetadata(
- model_id='fake_model_id_2',
- model_name='fake_model_id_2'))]
+ sample_side_input_elements = [
+ (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
+ # if model_id is empty string, we use the default model
+ # handler model URI.
+ (
+ first_ts + 8,
+ base.ModelMetadata(
+ model_id='fake_model_id_1', model_name='fake_model_id_1')),
+ (
+ first_ts + 15,
+ base.ModelMetadata(
+ model_id='fake_model_id_2', model_name='fake_model_id_2'))
+ ]
model_handler = FakeModelHandlerReturnsPredictionResult()
diff --git a/sdks/python/apache_beam/ml/inference/utils.py b/sdks/python/apache_beam/ml/inference/utils.py
index f30d8a8f648..4936ab5fe1d 100644
--- a/sdks/python/apache_beam/ml/inference/utils.py
+++ b/sdks/python/apache_beam/ml/inference/utils.py
@@ -19,13 +19,26 @@
"""
Util/helper functions used in apache_beam.ml.inference.
"""
+import os
+from functools import partial
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Union
+import apache_beam as beam
+from apache_beam.io.fileio import EmptyMatchTreatment
+from apache_beam.io.fileio import MatchContinuously
+from apache_beam.ml.inference.base import ModelMetadata
from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.transforms import trigger
+from apache_beam.transforms import window
+from apache_beam.transforms.userstate import CombiningValueStateSpec
+from apache_beam.utils.timestamp import MAX_TIMESTAMP
+from apache_beam.utils.timestamp import Timestamp
+
+_START_TIME_STAMP = Timestamp.now()
def _convert_to_result(
@@ -46,3 +59,106 @@ def _convert_to_result(
y in zip(batch, predictions_per_tensor)
]
return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)]
+
+
+class _ConvertIterToSingleton(beam.DoFn):
+ """
+ Internal only; No backwards compatibility.
+
+ The MatchContinuously transform examines all files present in a given
+ directory and returns those that have timestamps older than the
+ pipeline's start time. This can produce an Iterable rather than a
+ Singleton. This class only returns the file path when it is first
+ encountered, and it is cached as part of the side input caching mechanism.
+ If the path is seen again, it will not return anything.
+ By doing this, we can ensure that the output of this transform can be wrapped
+ with beam.pvalue.AsSingleton().
+ """
+ COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)
+
+ def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
+ counter = count_state.read()
+ if counter == 0:
+ count_state.add(1)
+ yield element[1]
+
+
+class _GetLatestFileByTimeStamp(beam.DoFn):
+ """
+ Internal only; No backwards compatibility.
+
+ This DoFn checks the timestamps of files against the time that the pipeline
+ began running. It returns the files that were modified after the pipeline
+ started. If no such files are found, it returns a default file as fallback.
+ """
+ TIME_STATE = CombiningValueStateSpec(
+ 'max', combine_fn=partial(max, default=_START_TIME_STAMP))
+
+ def process(self, element, time_state=beam.DoFn.StateParam(TIME_STATE)):
+ _, file_metadata = element
+ new_ts = file_metadata.last_updated_in_seconds
+ old_ts = time_state.read()
+ if new_ts > old_ts:
+ time_state.clear()
+ time_state.add(new_ts)
+ model_path = file_metadata.path
+ else:
+ model_path = ''
+
+ model_name = os.path.splitext(os.path.basename(model_path))[0]
+ return [
+ (model_path, ModelMetadata(model_id=model_path, model_name=model_name))
+ ]
+
+
+class WatchFilePattern(beam.PTransform):
+ def __init__(
+ self,
+ file_pattern,
+ interval=360,
+ stop_timestamp=MAX_TIMESTAMP,
+ ):
+ """
+ Watches a directory for updates to files matching a given file pattern.
+
+ Args:
+ file_pattern: The file path to read from as a local file path or a
+ GCS ``gs://`` path. The path can contain glob characters
+ (``*``, ``?``, and ``[...]`` sets).
+ interval: Interval at which to check for files matching file_pattern
+ in seconds.
+ stop_timestamp: Timestamp after which no more files will be checked.
+
+ **Note**:
+
+ 1. Any previously used filenames cannot be reused. If a file is added
+ or updated to a previously used filename, this transform will ignore
+ that update. To trigger a model update, always upload a file with a
+ unique name.
+ 2. Initially, before the pipeline startup time, WatchFilePattern expects
+ at least one file present that matches the file_pattern.
+ 3. This transform is supported in streaming mode since
+ MatchContinuously produces an unbounded source. Running in batch
+ mode can lead to undesired results or result in the pipeline being stuck.
+
+
+ """
+ self.file_pattern = file_pattern
+ self.interval = interval
+ self.stop_timestamp = stop_timestamp
+
+ def expand(self, pcoll) -> beam.PCollection[ModelMetadata]:
+ return (
+ pcoll
+ | 'MatchContinuously' >> MatchContinuously(
+ file_pattern=self.file_pattern,
+ interval=self.interval,
+ stop_timestamp=self.stop_timestamp,
+ empty_match_treatment=EmptyMatchTreatment.DISALLOW)
+ | "AttachKey" >> beam.Map(lambda x: (x.path, x))
+ | "GetLatestFileMetaData" >> beam.ParDo(_GetLatestFileByTimeStamp())
+ | "AcceptNewSideInputOnly" >> beam.ParDo(_ConvertIterToSingleton())
+ | 'ApplyGlobalWindow' >> beam.transforms.WindowInto(
+ window.GlobalWindows(),
+ trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING))
diff --git a/sdks/python/apache_beam/ml/inference/utils_test.py b/sdks/python/apache_beam/ml/inference/utils_test.py
new file mode 100644
index 00000000000..66499a5a6f4
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/utils_test.py
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import unittest
+
+import apache_beam as beam
+from apache_beam.io.filesystem import FileMetadata
+from apache_beam.ml.inference import utils
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+
+class WatchFilePatternTest(unittest.TestCase):
+ def test_latest_file_by_timestamp_default_value(self):
+ # match continuously returns the files in sorted timestamp order.
+ main_input_pcoll = [
+ FileMetadata(
+ 'path1.py',
+ 10,
+ last_updated_in_seconds=utils._START_TIME_STAMP - 20),
+ FileMetadata(
+ 'path2.py',
+ 10,
+ last_updated_in_seconds=utils._START_TIME_STAMP - 10)
+ ]
+ with TestPipeline() as p:
+ files_pc = (
+ p
+ | beam.Create(main_input_pcoll)
+ | beam.Map(lambda x: (x.path, x))
+ | beam.ParDo(utils._GetLatestFileByTimeStamp())
+ | beam.Map(lambda x: x[0]))
+ assert_that(files_pc, equal_to(['', '']))
+
+ def test_latest_file_with_timestamp_after_pipeline_construction_time(self):
+ main_input_pcoll = [
+ FileMetadata(
+ 'path1.py',
+ 10,
+ last_updated_in_seconds=utils._START_TIME_STAMP + 10)
+ ]
+ with TestPipeline() as p:
+ files_pc = (
+ p
+ | beam.Create(main_input_pcoll)
+ | beam.Map(lambda x: (x.path, x))
+ | beam.ParDo(utils._GetLatestFileByTimeStamp())
+ | beam.Map(lambda x: x[0]))
+ assert_that(files_pc, equal_to(['path1.py']))
+
+ def test_emitting_singleton_output(self):
+ # match continuously returns the files in sorted timestamp order.
+ main_input_pcoll = [
+ FileMetadata(
+ 'path1.py',
+ 10,
+ last_updated_in_seconds=utils._START_TIME_STAMP - 20),
+ # returns default
+ FileMetadata(
+ 'path2.py',
+ 10,
+ last_updated_in_seconds=utils._START_TIME_STAMP - 10),
+ # returns default
+ FileMetadata(
+ 'path3.py',
+ 10,
+ last_updated_in_seconds=utils._START_TIME_STAMP + 10),
+ FileMetadata(
+ 'path4.py',
+ 10,
+ last_updated_in_seconds=utils._START_TIME_STAMP + 20)
+ ]
+ # returns path3.py
+
+ with TestPipeline() as p:
+ files_pc = (
+ p
+ | beam.Create(main_input_pcoll)
+ | beam.Map(lambda x: (x.path, x))
+ | beam.ParDo(utils._GetLatestFileByTimeStamp())
+ | beam.ParDo(utils._ConvertIterToSingleton())
+ | beam.Map(lambda x: x[0]))
+ assert_that(files_pc, equal_to(['', 'path3.py', 'path4.py']))
+
+
+if __name__ == '__main__':
+ unittest.main()