Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/06/14 14:42:49 UTC

[GitHub] [beam] ryanthompson591 commented on a diff in pull request #21810: Split PytorchModelHandler into PytorchModelHandlerTensor and PytorchModelHandlerKeyedTensor

ryanthompson591 commented on code in PR #21810:
URL: https://github.com/apache/beam/pull/21810#discussion_r896870386


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -67,67 +84,114 @@ def __init__(
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
-    model = self._model_class(**self._model_params)
-    model.to(self._device)
-    file = FileSystems.open(self._state_dict_path, 'rb')
-    model.load_state_dict(torch.load(file))
-    model.eval()
-    return model
-
-  def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
+    return _load_model(
+        self._model_class,
+        self._state_dict_path,
+        self._device,
+        **self._model_params)

Review Comment:
   What do you think about passing the model parameters as a single named dictionary argument instead of unpacking them with **self._model_params?
   
   The advantage is that users can specify exactly what their parameters should be.
   
   They would specify the parameters like this:
   
   model_parameters = {
     'key_1': 'parameter_1' 
   }
   
   Then, if optional parameters are added to _load_model in the future, they won't collide with the model's own constructor parameters.
   
   
   Feel free to make that change in another PR if you think it's a good idea.
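The collision the comment describes can be sketched as follows. This is a minimal illustration, assuming a hypothetical model whose constructor takes its own `device` parameter; `ToyModel`, `load_model_unpacked`, and `load_model_with_dict` are illustrative names, not Beam APIs. torch and the actual state-dict loading are omitted so the example runs on its own.

```python
from typing import Any, Callable, Dict, Optional


class ToyModel:
  """Hypothetical stand-in for a torch.nn.Module subclass whose constructor
  happens to take a parameter named `device`."""
  def __init__(self, hidden_size: int, device: str = 'cpu'):
    self.hidden_size = hidden_size
    self.device = device


def load_model_unpacked(
    model_class: Callable[..., Any],
    state_dict_path: str,
    device: str,
    **model_params: Any) -> Any:
  # Hypothetical loader mirroring the current style: model constructor
  # kwargs share one namespace with the loader's own arguments, so a model
  # parameter named `device` cannot be passed through.
  # (State-dict loading from state_dict_path is omitted in this sketch.)
  return model_class(**model_params)


def load_model_with_dict(
    model_class: Callable[..., Any],
    state_dict_path: str,
    device: str,
    model_params: Optional[Dict[str, Any]] = None) -> Any:
  # Hypothetical loader in the suggested style: model constructor kwargs
  # live in their own dictionary, so optional loader arguments added later
  # cannot collide with them.
  return model_class(**(model_params or {}))


# The unpacked style fails with
# "TypeError: ... got multiple values for argument 'device'".
try:
  load_model_unpacked(ToyModel, 'model.pth', 'cpu', hidden_size=8, device='cuda')
except TypeError as e:
  print(e)

# The dictionary style passes the model's own `device` through cleanly.
model = load_model_with_dict(
    ToyModel, 'model.pth', 'cpu',
    model_params={'hidden_size': 8, 'device': 'cuda'})
print(model.device)
```

The design choice is about namespaces: with a dedicated dictionary argument, the loader's signature and the model's constructor signature can evolve independently.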



-- 
This is an automated message from the Apache Git Service.