You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/06/29 18:58:11 UTC

[beam] branch master updated: Fix missing model_params in Pytorch docstring (#22100)

This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ffeced5f24 Fix missing model_params in Pytorch docstring (#22100)
9ffeced5f24 is described below

commit 9ffeced5f246b3d72eedf8c55aa20574ae9d07cb
Author: Andy Ye <an...@gmail.com>
AuthorDate: Wed Jun 29 11:58:00 2022 -0700

    Fix missing model_params in Pytorch docstring (#22100)
    
    * Fix docstring wording
    
    * Update sdks/python/apache_beam/ml/inference/pytorch_inference.py
    
    Co-authored-by: Anand Inguva <34...@users.noreply.github.com>
    
    * Update sdks/python/apache_beam/ml/inference/pytorch_inference.py
    
    Co-authored-by: Anand Inguva <34...@users.noreply.github.com>
    
    Co-authored-by: Anand Inguva <34...@users.noreply.github.com>
---
 sdks/python/apache_beam/ml/inference/pytorch_inference.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 6875737a56d..d32ed50a618 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -72,6 +72,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       state_dict_path: path to the saved dictionary of the model state.
       model_class: class of the Pytorch model that defines the model
         structure.
+      model_params: A dictionary of arguments required to instantiate the model
+        class.
       device: the device on which you wish to run the model. If
         ``device = GPU`` then a GPU device will be used if it is available.
         Otherwise, it will be CPU.
@@ -169,10 +171,11 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
       state_dict_path: path to the saved dictionary of the model state.
       model_class: class of the Pytorch model that defines the model
         structure.
+      model_params: A dictionary of arguments required to instantiate the model
+        class.
       device: the device on which you wish to run the model. If
         ``device = GPU`` then a GPU device will be used if it is available.
         Otherwise, it will be CPU.
-      model_params: A dictionary of parameters passed in to the model class.
     """
     self._state_dict_path = state_dict_path
     if device == 'GPU' and torch.cuda.is_available():