Posted to github@beam.apache.org by "damccorm (via GitHub)" <gi...@apache.org> on 2023/02/06 20:00:55 UTC

[GitHub] [beam] damccorm commented on a diff in pull request #25321: Add support for loading torchscript models

damccorm commented on code in PR #25321:
URL: https://github.com/apache/beam/pull/25321#discussion_r1097855765


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -429,3 +467,12 @@ def get_metrics_namespace(self) -> str:
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
+
+  def _validate_func_args(self):

Review Comment:
   Can we share this validation function across model handlers?
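   A rough sketch of what a shared, module-level helper might look like (the helper name and argument layout here are illustrative, not taken from the PR):

       def _validate_constructor_args(
           state_dict_path, model_class, use_torch_script_format):
         """Checks that the loading arguments are mutually consistent."""
         if use_torch_script_format:
           if model_class is not None:
             raise RuntimeError(
                 'model_class is ignored when use_torch_script_format=True; '
                 'pass only the path to the TorchScript archive.')
         elif model_class is None:
           raise RuntimeError(
               'model_class is required to load a model from the state_dict '
               'at %s.' % state_dict_path)

   Both the tensor and keyed-tensor handlers could then delegate to the same helper from their _validate_func_args methods instead of duplicating the checks.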



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -71,18 +73,26 @@ def _load_model(
   try:
     logging.info(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
-    state_dict = torch.load(file, map_location=device)
+    if not use_torch_script_format:

Review Comment:
   Rather than parameterizing this, should we just accept None as a value for model_class and make our decision based on that?
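   For illustration, the load path could key off model_class being None instead of a separate flag (a minimal sketch under that assumption, not the PR's actual code):

       import torch
       from apache_beam.io.filesystems import FileSystems

       def _load_model(state_dict_path, model_class, model_params, device):
         file = FileSystems.open(state_dict_path)
         if model_class is None:
           # No class to instantiate: treat the file as a TorchScript archive.
           model = torch.jit.load(file, map_location=device)
         else:
           model = model_class(**model_params)
           model.load_state_dict(torch.load(file, map_location=device))
         model.to(device)
         model.eval()
         return model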



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -174,6 +186,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`.
+        `model_class` and `model_params` arguments will be disregarded.

Review Comment:
   `state_dict_path` can now hold more than just a state_dict, right? (It can point to a TorchScript archive.) If so, could you update the example usage and that parameter's description?
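   For example, the parameter description could be broadened along these lines (wording is only a suggestion):

       state_dict_path: path to the saved model. This is either a state_dict
         produced by torch.save(model.state_dict(), path) or, when
         use_torch_script_format=True, a TorchScript archive produced by
         torch.jit.save(model, path).

   and the example usage could show the TorchScript case too (path below is a placeholder):

       model_handler = PytorchModelHandlerTensor(
           state_dict_path='gs://my-bucket/torch_script_model.pt',
           use_torch_script_format=True)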



##########
sdks/python/apache_beam/ml/inference/pytorch_inference_test.py:
##########
@@ -609,6 +609,57 @@ def test_gpu_auto_convert_to_cpu(self):
           "are not available. Switching to CPU.",
           log.output)
 
+  def test_load_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)

Review Comment:
   Does this work for remote models? If so, can we just use a remote model in GCS rather than saving a model for each test?
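   For reference, _load_model reads the path through Beam's FileSystems before handing it to torch, so a quick check of the remote case could look like this (bucket and object name are placeholders):

       import torch
       from apache_beam.io.filesystems import FileSystems

       with FileSystems.open('gs://example-bucket/torch_script_model.pt') as f:
         model = torch.jit.load(f, map_location='cpu')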



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -71,18 +73,26 @@ def _load_model(
   try:
     logging.info(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
-    state_dict = torch.load(file, map_location=device)
+    if not use_torch_script_format:

Review Comment:
   My inclination after thinking more is that we should not do this since (a) it gives users less safety against accidental errors, and (b) it could be constraining if we later add support for model types that don't use state_dict_path or model_class.
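   To make the safety point concrete, under the None-based design a simple omission on the user side (handler name and path are placeholders):

       handler = PytorchModelHandlerTensor(
           state_dict_path='gs://example-bucket/weights.pt')  # forgot model_class

   would only surface as an opaque torch.jit.load failure when the worker tries to read the state_dict as a TorchScript archive, whereas an explicit flag lets validation raise a clear "model_class is required" error at construction time.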



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org