Posted to commits@beam.apache.org by bh...@apache.org on 2022/04/25 23:09:30 UTC

[beam] branch master updated: Change return type for PytorchInferenceRunner (#17460)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new eedff4b996c Change return type for PytorchInferenceRunner (#17460)
eedff4b996c is described below

commit eedff4b996c042d4e5a61e42a2da5f776ff42afd
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Mon Apr 25 23:09:22 2022 +0000

    Change return type for PytorchInferenceRunner (#17460)
---
 sdks/python/apache_beam/ml/inference/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch.py b/sdks/python/apache_beam/ml/inference/pytorch.py
index 62741a899e8..438582d73dc 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch.py
@@ -37,7 +37,7 @@ class PytorchInferenceRunner(InferenceRunner):
     self._device = device
 
   def run_inference(self, batch: List[torch.Tensor],
-                    model: torch.nn.Module) -> Iterable[torch.Tensor]:
+                    model: torch.nn.Module) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of Tensors and returns an Iterable of
     Tensor Predictions.
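For illustration only, the sketch below shows what changing the return type from Iterable[torch.Tensor] to Iterable[PredictionResult] means for callers. The diff does not show how PredictionResult is defined. This sketch assumes it is a small record that pairs each input `example` with its `inference`; those are the field names later Beam releases use. To avoid a torch dependency, PredictionResult, run_inference and double_model are defined locally as hypothetical stand-ins, and plain floats replace tensors. This is not the Beam implementation.

```python
from typing import Any, Callable, Iterable, List, NamedTuple


class PredictionResult(NamedTuple):
  """Hypothetical stand-in: pairs one input example with its inference."""
  example: Any
  inference: Any


def run_inference(batch: List[Any],
                  model: Callable[[List[Any]], List[Any]]
                  ) -> Iterable[PredictionResult]:
  """Runs the model on the whole batch, then pairs each input with its output.

  Returning PredictionResult values instead of bare predictions keeps the
  link between an example and its prediction explicit for downstream steps.
  """
  predictions = model(batch)
  if len(predictions) != len(batch):
    raise ValueError('model must return one prediction per example')
  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


def double_model(batch: List[float]) -> List[float]:
  """Hypothetical model that doubles every input value."""
  return [2 * x for x in batch]


results = list(run_inference([1.0, 2.5], double_model))
for r in results:
  print(r.example, '->', r.inference)
```

Because each result carries its own input, a downstream step can key or filter predictions by example without relying on the batch order being preserved.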