Posted to github@beam.apache.org by "AnandInguva (via GitHub)" <gi...@apache.org> on 2023/02/06 20:26:57 UTC

[GitHub] [beam] AnandInguva commented on a diff in pull request #25321: Add support for loading torchscript models

AnandInguva commented on code in PR #25321:
URL: https://github.com/apache/beam/pull/25321#discussion_r1097882174


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -174,6 +186,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`.
+        `model_class` and `model_params` arguments will be disregarded.

Review Comment:
   I will update the docstring to say `torch.jit.load(state_dict_path)`.



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -174,6 +186,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`.
+        `model_class` and `model_params` arguments will be disregarded.

Review Comment:
   I think `use_torch_script_format` is the wrong name. I could rename it to something like `load_as_torchscript_model`. Any suggestions?
   
   When the user enables this, we call `torch.jit.load()`, which accepts `.pt` and `.pth` files as well as the new zip format that PyTorch is going to make the default soon: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference
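A minimal sketch of that behavior, assuming only the public `torch.jit.script`, `torch.jit.save` and `torch.jit.load` calls. `TinyLinear` and `load_torchscript_model` are hypothetical names for illustration, not part of Beam. It saves the same scripted module under a `.pt` and a `.pth` name and loads both without ever referencing the model class, which is why `model_class` and `model_params` can be disregarded in this mode:

```python
import os
import tempfile

import torch


class TinyLinear(torch.nn.Module):
  """Hypothetical stand-in for the PytorchLinearRegression test model."""
  def __init__(self, input_dim: int, output_dim: int):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.linear(x)


def load_torchscript_model(path: str, device: str = "cpu"):
  # Hypothetical helper: torch.jit.load rebuilds the module from the archive
  # itself, so no model class or constructor params are required.
  # map_location places the loaded tensors on the requested device.
  model = torch.jit.load(path, map_location=torch.device(device))
  model.eval()
  return model


x = torch.tensor([[1.0, 2.0]])
outputs = {}
with tempfile.TemporaryDirectory() as tmpdir:
  scripted = torch.jit.script(TinyLinear(2, 1))
  expected = scripted(x)
  for name in ("model.pt", "model.pth"):
    path = os.path.join(tmpdir, name)
    torch.jit.save(scripted, path)
    # The extension is only a naming convention; both files load the same way.
    outputs[name] = load_torchscript_model(path)(x)
```

Each loaded module produces the same output as the original scripted module, regardless of the file extension.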



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -71,18 +73,26 @@ def _load_model(
   try:
     logging.info(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
-    state_dict = torch.load(file, map_location=device)
+    if not use_torch_script_format:

Review Comment:
   >> My inclination after thinking more is that we should not do this since (a) it gives the users less safety from accidental errors, and (b) it is potentially constraining if we accept loading new model types that don't use state_dict_path or model_class in the future.
   
   Is this about the suggestion to make `model_class` None and infer the loading mode from it, or about adding the `use_torch_script_format` parameter?
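To make the trade-off concrete, here is a hypothetical sketch of the explicit-flag option. The names `load_model`, `load_torchscript` and `load_state_dict_model` are illustrative stand-ins, not Beam's actual `_load_model` signature, and the loaders are injected so the dispatch can be exercised without PyTorch. With an explicit flag, forgetting `model_class` raises an error instead of silently switching to TorchScript loading, which is the accidental-error concern in point (a):

```python
from typing import Any, Callable, Dict, Optional


def load_model(
    state_dict_path: str,
    *,
    model_class: Optional[Callable[..., Any]] = None,
    model_params: Optional[Dict[str, Any]] = None,
    use_torch_script_format: bool = False,
    load_torchscript: Callable[[str], Any],
    load_state_dict_model: Callable[[str, Callable[..., Any], Dict[str, Any]], Any],
) -> Any:
  """Hypothetical dispatch on an explicit flag, not on `model_class is None`."""
  if use_torch_script_format:
    # Explicit opt-in: model_class and model_params are ignored by design.
    return load_torchscript(state_dict_path)
  if model_class is None:
    # Fail loudly rather than guessing that the user meant TorchScript.
    raise ValueError(
        "model_class is required unless use_torch_script_format=True")
  return load_state_dict_model(
      state_dict_path, model_class, model_params or {})


# Fake loaders and model class used only to exercise the dispatch logic.
def fake_torchscript(path):
  return ("torchscript", path)


def fake_state_dict(path, model_class, model_params):
  return ("state_dict", path, model_class.__name__, model_params)


class Dummy:
  pass
```

Inferring the mode from `model_class is None` would collapse the last two branches into one, so a missing argument would quietly change how the file is loaded.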



##########
sdks/python/apache_beam/ml/inference/pytorch_inference_test.py:
##########
@@ -609,6 +609,57 @@ def test_gpu_auto_convert_to_cpu(self):
           "are not available. Switching to CPU.",
           log.output)
 
+  def test_load_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)

Review Comment:
   These are pretty lightweight models, and the test takes ~3 seconds to run.
   
   Also, the disadvantage of storing the models in a GCS bucket for unit tests is that we may need to update them manually if a future PyTorch version becomes incompatible with them.
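A sketch of the on-the-fly approach, assuming a plain `unittest` test case (the class and method names are hypothetical; the real test lives in `pytorch_inference_test.py`). The TorchScript fixture is generated in a temporary directory on every run, so it is always written by the installed PyTorch version rather than fetched from a GCS bucket:

```python
import os
import shutil
import tempfile
import unittest

import torch


class TorchScriptFixtureTest(unittest.TestCase):
  def setUp(self):
    # A fresh local directory per test; nothing is read from GCS.
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def test_round_trip(self):
    torch.manual_seed(0)
    # A tiny model is cheap to script and save during the test itself.
    model = torch.jit.script(torch.nn.Linear(2, 1))
    path = os.path.join(self.tmpdir, 'torch_script_model.pt')
    torch.jit.save(model, path)
    loaded = torch.jit.load(path, map_location='cpu')
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    self.assertTrue(torch.allclose(loaded(x), model(x)))
```

Because the file is produced by whichever PyTorch version runs the test, a format change in a future release cannot leave a stale fixture behind.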



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org