Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/07/22 06:24:07 UTC

[GitHub] [beam] tvalentyn commented on a diff in pull request #22164: Modify RunInference to return PipelineResult for the benchmark tests

tvalentyn commented on code in PR #22164:
URL: https://github.com/apache/beam/pull/22164#discussion_r927319071


##########
sdks/python/apache_beam/examples/inference/pytorch_image_classification.py:
##########
@@ -95,21 +98,28 @@ def parse_known_args(argv):
   return parser.parse_known_args(argv)
 
 
-def run(argv=None, model_class=None, model_params=None, save_main_session=True):
+def run(
+    argv=None,
+    model_class=None,
+    model_params=None,
+    save_main_session=True,
+    test_pipeline=None) -> PipelineResult:
   """
   Args:
     argv: Command line arguments defined for this example.
     model_class: Reference to the class definition of the model.
     model_params: Parameters passed to the constructor of the model_class.
                   These will be used to instantiate the model object in the
                   RunInference API.
+    test_pipeline: used for internal testing. No backwards-compatibility,

Review Comment:
   Remove `No backwards-compatibility,`
   This is an example, not an API. Backwards compatibility is not relevant.



##########
sdks/python/apache_beam/examples/inference/pytorch_image_classification.py:
##########
@@ -95,21 +98,28 @@ def parse_known_args(argv):
   return parser.parse_known_args(argv)
 
 
-def run(argv=None, model_class=None, model_params=None, save_main_session=True):
+def run(
+    argv=None,
+    model_class=None,
+    model_params=None,
+    save_main_session=True,
+    test_pipeline=None) -> PipelineResult:
   """
   Args:
     argv: Command line arguments defined for this example.
     model_class: Reference to the class definition of the model.
     model_params: Parameters passed to the constructor of the model_class.
                   These will be used to instantiate the model object in the
                   RunInference API.
+    test_pipeline: used for internal testing. No backwards-compatibility,

Review Comment:
   Also note in the docstring that `save_main_session` is used for testing only.



##########
sdks/python/apache_beam/examples/inference/pytorch_image_classification.py:
##########
@@ -72,6 +74,7 @@ def parse_known_args(argv):
   """Parses args for the workflow."""
   parser = argparse.ArgumentParser()
   parser.add_argument(
+      '--input_file',

Review Comment:
   Why do we need two flags? Should we just use `--input` throughout all the examples and docs?
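A single-flag version of `parse_known_args` might look like the sketch below. It is not the PR's code. The `dest` names `input`, `images_dir` and `output` come from the `known_args.*` uses in the diff. The `--images_dir` and `--output` spellings, the help strings and `required=True` are assumptions.

```python
import argparse


def parse_known_args(argv):
  """Parses args for the workflow, using one `--input` flag everywhere."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,  # Assumption: the example cannot run without it.
      help='Path to the text file listing the image names.')
  parser.add_argument(
      '--images_dir',  # Assumed flag spelling, matching known_args.images_dir.
      default=None,
      help='Directory that contains the images.')
  parser.add_argument(
      '--output',
      dest='output',
      help='Path to write the predictions to.')
  # Flags not declared here (e.g. runner options) come back in the second
  # element so they can be handed to PipelineOptions.
  return parser.parse_known_args(argv)
```

If an old spelling ever had to keep working, argparse accepts several option strings for one argument, e.g. `parser.add_argument('--input', '--input_file', dest='input')`. One name is still simpler to document.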



##########
sdks/python/apache_beam/examples/inference/pytorch_image_classification.py:
##########
@@ -120,27 +130,35 @@ def run(argv=None, model_class=None, model_params=None, save_main_session=True):
           model_class=model_class,
           model_params=model_params))
 
-  with beam.Pipeline(options=pipeline_options) as p:
-    filename_value_pair = (
-        p
-        | 'ReadImageNames' >> beam.io.ReadFromText(
-            known_args.input, skip_header_lines=1)
-        | 'ReadImageData' >> beam.Map(
-            lambda image_name: read_image(
-                image_file_name=image_name, path_to_dir=known_args.images_dir))
-        | 'PreprocessImages' >> beam.MapTuple(
-            lambda file_name, data: (file_name, preprocess_image(data))))
-    predictions = (
-        filename_value_pair
-        | 'PyTorchRunInference' >> RunInference(model_handler)
-        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
-
-    if known_args.output:
-      predictions | "WriteOutputToGCS" >> beam.io.WriteToText( # pylint: disable=expression-not-assigned
-        known_args.output,
-        shard_name_template='',
-        append_trailing_newlines=True)
+  if not test_pipeline:

Review Comment:
   You could rename `test_pipeline` to `pipeline` in the arguments and delete lines 135-136.



##########
sdks/python/apache_beam/examples/inference/pytorch_image_classification.py:
##########
@@ -95,21 +98,28 @@ def parse_known_args(argv):
   return parser.parse_known_args(argv)
 
 
-def run(argv=None, model_class=None, model_params=None, save_main_session=True):
+def run(
+    argv=None,
+    model_class=None,
+    model_params=None,
+    save_main_session=True,
+    test_pipeline=None) -> PipelineResult:
   """
   Args:
     argv: Command line arguments defined for this example.
     model_class: Reference to the class definition of the model.
     model_params: Parameters passed to the constructor of the model_class.
                   These will be used to instantiate the model object in the
                   RunInference API.
+    test_pipeline: used for internal testing. No backwards-compatibility,

Review Comment:
   Use consistent capitalization in the docstring, e.g. `Used` instead of `used`.
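Taken together, the docstring comments in this review (drop "No backwards-compatibility", mark `save_main_session` as test-only, capitalize consistently) could give a docstring like the one sketched below. The exact wording of the two test-only entries is an assumption, because the original `test_pipeline` line is cut off in the diff. The body is elided.

```python
def run(
    argv=None,
    model_class=None,
    model_params=None,
    save_main_session=True,
    test_pipeline=None) -> 'PipelineResult':
  """
  Args:
    argv: Command line arguments defined for this example.
    model_class: Reference to the class definition of the model.
    model_params: Parameters passed to the constructor of the model_class.
                  These will be used to instantiate the model object in the
                  RunInference API.
    save_main_session: Used for internal testing only.
    test_pipeline: Used for internal testing only.
  """
  # 'PipelineResult' is a string annotation so this sketch runs without
  # importing Beam; pipeline construction is elided on purpose.
  raise NotImplementedError
```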



-- 
This is an automated message from the Apache Git Service.