Posted to github@beam.apache.org by "AnandInguva (via GitHub)" <gi...@apache.org> on 2023/04/24 14:53:09 UTC

[GitHub] [beam] AnandInguva commented on a diff in pull request #26309: Add support for pre/post processing in RunInference

AnandInguva commented on code in PR #26309:
URL: https://github.com/apache/beam/pull/26309#discussion_r1175360831


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -58,7 +59,9 @@
 
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
+PreT = TypeVar('PreT')

Review Comment:
   
   ```suggestion
   PreProcessT = TypeVar('PreProcessT')
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -382,14 +572,38 @@ def expand(
             **resource_hints)
 
     if self._with_exception_handling:
-      run_inference_pardo = run_inference_pardo.with_exception_handling(
+      results, bad_inference = (
+          batched_elements_pcoll
+          | 'BeamML_RunInference' >>
+          run_inference_pardo.with_exception_handling(
           exc_class=self._exc_class,
           use_subprocess=self._use_subprocess,
-          threshold=self._threshold)
+          threshold=self._threshold))
+    else:
+      results = (
+          batched_elements_pcoll
+          | 'BeamML_RunInference' >> run_inference_pardo)
+
+    for idx in range(len(postprocess_fns)):
+      fn = postprocess_fns[idx]
+      if self._with_exception_handling:
+        results, bad = (results
+        | f"BeamML_RunInference_Postprocess-{idx}" >> beam.Map(
+          fn).with_exception_handling(
+          exc_class=self._exc_class,
+          use_subprocess=self._use_subprocess,
+          threshold=self._threshold))

Review Comment:
   
   Just a suggestion: can we move this `for` loop into a method, since preprocessing and postprocessing use the same logic? The method would accept a PCollection, a list of fns, and a str (the namespace).



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -404,13 +618,31 @@ def with_exception_handling(
 
     For example, one would write::
 
-        good, bad = RunInference(
+        main, other = RunInference(
           maybe_error_raising_model_handler
         ).with_exception_handling()
 
-    and `good` will be a PCollection of PredictionResults and `bad` will
-    contain a tuple of all batches that raised exceptions, along with their
-    corresponding exception.
+    and `good` will be a PCollection of PredictionResults and `other` will
+    contain a `RunInferenceDLQ` object with PCollections containing failed
+    records for each failed inference, preprocess operation, or postprocess
+    operation. To access each collection of failed records, one would write:
+
+        failed_inferences = other.failed_inferences
+        failed_preprocessing = other.failed_preprocessing
+        failed_postprocessing = other.failed_postprocessing
+
+    failed_inferences is in the form
+    PCollection[Tuple[failed batch, exception]].
+
+    failed_preprocessing is in the form
+    PCollectionList[Tuple[failed record, exception]]], where each element of
+    the list corresponds to a preprocess function. These PCollections are
+    in the same order that the preprocess functions are applied.
+
+    failed_postprocessing is in the form
+    PCollectionList[Tuple[failed record, exception]]], where each element of

Review Comment:
   Python doesn't have a `PCollectionList`, so this should be described as a list of PCollections.



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -404,13 +618,31 @@ def with_exception_handling(
 
     For example, one would write::
 
-        good, bad = RunInference(
+        main, other = RunInference(
           maybe_error_raising_model_handler
         ).with_exception_handling()
 
-    and `good` will be a PCollection of PredictionResults and `bad` will
-    contain a tuple of all batches that raised exceptions, along with their
-    corresponding exception.
+    and `good` will be a PCollection of PredictionResults and `other` will
+    contain a `RunInferenceDLQ` object with PCollections containing failed
+    records for each failed inference, preprocess operation, or postprocess
+    operation. To access each collection of failed records, one would write:
+
+        failed_inferences = other.failed_inferences
+        failed_preprocessing = other.failed_preprocessing
+        failed_postprocessing = other.failed_postprocessing
+
+    failed_inferences is in the form
+    PCollection[Tuple[failed batch, exception]].
+
+    failed_preprocessing is in the form
+    PCollectionList[Tuple[failed record, exception]]], where each element of
+    the list corresponds to a preprocess function. These PCollections are
+    in the same order that the preprocess functions are applied.
+
+    failed_postprocessing is in the form
+    PCollectionList[Tuple[failed record, exception]]], where each element of

Review Comment:
   ```suggestion
       List[PCollection[Tuple[failed record, exception]]], where each element of
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -404,13 +618,31 @@ def with_exception_handling(
 
     For example, one would write::
 
-        good, bad = RunInference(
+        main, other = RunInference(
           maybe_error_raising_model_handler
         ).with_exception_handling()
 
-    and `good` will be a PCollection of PredictionResults and `bad` will
-    contain a tuple of all batches that raised exceptions, along with their
-    corresponding exception.
+    and `good` will be a PCollection of PredictionResults and `other` will

Review Comment:
   ```suggestion
       and `main` will be a PCollection of PredictionResults and `other` will
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -300,6 +359,116 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   def update_model_path(self, model_path: Optional[str] = None):
     return self._unkeyed.update_model_path(model_path=model_path)
 
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_preprocess_fns()
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_postprocess_fns()
+
+
+class _PreProcessingModelHandler(Generic[ExampleT, PredictionT, ModelT, PreT],
+                                 ModelHandler[PreT, PredictionT, ModelT]):
+  def __init__(
+      self,
+      base: ModelHandler[ExampleT, PredictionT, ModelT],
+      preprocess_fn: Callable[[PreT], ExampleT]):
+    """A ModelHandler that has a preprocessing function associated with it.
+
+    Args:
+      base: An implementation of the underlying model handler.
+      preprocess_fn: the preprocessing function to use.
+    """
+    self._base = base
+    self._preprocess_fn = preprocess_fn
+
+  def load_model(self) -> ModelT:
+    return self._base.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+    return self._base.run_inference(batch, model, inference_args)
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+    return self._base.get_num_bytes(batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._base.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._base.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._base.batch_elements_kwargs()
+
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    return self._base.validate_inference_args(inference_args)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    return self._base.update_model_path(model_path=model_path)
+
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:

Review Comment:
   Why do we need this class? Can we put this method in the `ModelHandler` itself, append the `processing_fn` to a list when the user calls `ModelHandler().with_preprocess_fn()`, and return self?
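
The alternative described here could be sketched roughly as follows. `SketchModelHandler` is a hypothetical stand-in used only for illustration, not Beam's `ModelHandler`, and only the bookkeeping for the fns is shown.

```python
from typing import Any, Callable, List


class SketchModelHandler:
  """Hypothetical stand-in for ModelHandler; shows only the fn bookkeeping."""
  def __init__(self):
    self._preprocess_fns: List[Callable[[Any], Any]] = []

  def with_preprocess_fn(self, fn: Callable[[Any], Any]):
    # Record the fn on this handler and return self so calls can be chained.
    self._preprocess_fns.append(fn)
    return self

  def get_preprocess_fns(self) -> List[Callable[[Any], Any]]:
    # Return a copy so callers cannot mutate the handler's internal list.
    return list(self._preprocess_fns)


# Chained usage: both fns end up on the same handler instance.
handler = SketchModelHandler().with_preprocess_fn(str.strip).with_preprocess_fn(int)
```

One visible trade-off: the wrapper class in the diff changes the handler's input type parameter from `ExampleT` to `PreT`, which a handler that mutates and returns itself cannot express in its type hints.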



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -404,13 +618,31 @@ def with_exception_handling(
 
     For example, one would write::
 
-        good, bad = RunInference(
+        main, other = RunInference(
           maybe_error_raising_model_handler
         ).with_exception_handling()
 
-    and `good` will be a PCollection of PredictionResults and `bad` will
-    contain a tuple of all batches that raised exceptions, along with their
-    corresponding exception.
+    and `good` will be a PCollection of PredictionResults and `other` will
+    contain a `RunInferenceDLQ` object with PCollections containing failed
+    records for each failed inference, preprocess operation, or postprocess
+    operation. To access each collection of failed records, one would write:
+
+        failed_inferences = other.failed_inferences
+        failed_preprocessing = other.failed_preprocessing
+        failed_postprocessing = other.failed_postprocessing
+
+    failed_inferences is in the form
+    PCollection[Tuple[failed batch, exception]].
+
+    failed_preprocessing is in the form
+    PCollectionList[Tuple[failed record, exception]]], where each element of
+    the list corresponds to a preprocess function. These PCollections are
+    in the same order that the preprocess functions are applied.
+
+    failed_postprocessing is in the form
+    PCollectionList[Tuple[failed record, exception]]], where each element of

Review Comment:
   ```suggestion
       list of PCollection[Tuple[failed record, exception]], where each element of
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
