You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this copy.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/06/29 18:58:11 UTC
[beam] branch master updated: Fix missing model_params in Pytorch docstring (#22100)
This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9ffeced5f24 Fix missing model_params in Pytorch docstring (#22100)
9ffeced5f24 is described below
commit 9ffeced5f246b3d72eedf8c55aa20574ae9d07cb
Author: Andy Ye <an...@gmail.com>
AuthorDate: Wed Jun 29 11:58:00 2022 -0700
Fix missing model_params in Pytorch docstring (#22100)
* Fix docstring wording
* Update sdks/python/apache_beam/ml/inference/pytorch_inference.py
Co-authored-by: Anand Inguva <34...@users.noreply.github.com>
* Update sdks/python/apache_beam/ml/inference/pytorch_inference.py
Co-authored-by: Anand Inguva <34...@users.noreply.github.com>
Co-authored-by: Anand Inguva <34...@users.noreply.github.com>
---
sdks/python/apache_beam/ml/inference/pytorch_inference.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 6875737a56d..d32ed50a618 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -72,6 +72,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
state_dict_path: path to the saved dictionary of the model state.
model_class: class of the Pytorch model that defines the model
structure.
+ model_params: A dictionary of arguments required to instantiate the model
+ class.
device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.
@@ -169,10 +171,11 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
state_dict_path: path to the saved dictionary of the model state.
model_class: class of the Pytorch model that defines the model
structure.
+ model_params: A dictionary of arguments required to instantiate the model
+ class.
device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.
- model_params: A dictionary of parameters passed in to the model class.
"""
self._state_dict_path = state_dict_path
if device == 'GPU' and torch.cuda.is_available():