Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/12 12:03:14 UTC

[tvm] branch unity updated: [Unity][PyTorch] Disable gradient during dynamo subgraph capture to save RAM (#14602)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 88f5b8f590 [Unity][PyTorch] Disable gradient during dynamo subgraph capture to save RAM (#14602)
88f5b8f590 is described below

commit 88f5b8f59091e9cd4a96cf2367fa405f4ff611c8
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Apr 12 21:03:06 2023 +0900

    [Unity][PyTorch] Disable gradient during dynamo subgraph capture to save RAM (#14602)
    
    disable gradient in dynamo subgraph capture to save RAM
---
 python/tvm/relax/frontend/torch/dynamo.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index f48a2cde3c..56eb257421 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -159,7 +159,10 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) -> tvm.IRModule:
 
     dynamo.reset()
     compiled_model = torch.compile(model, backend=_capture)
-    compiled_model(*params, **kwargs)
+
+    with torch.no_grad():
+        compiled_model(*params, **kwargs)
+
     return mod
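
For context, the patch wraps the call to the torch.compile'd model in torch.no_grad(), so PyTorch does not retain intermediate activations for a backward pass while the backend captures the subgraph. Below is a minimal, hedged sketch of the same pattern outside of TVM; the toy model, the my_backend stand-in, and the tensor shapes are illustrative assumptions and are not part of this commit.

    import torch

    # A small model whose activations would otherwise be kept alive by autograd.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    )

    def my_backend(gm, example_inputs):
        # Stand-in for TVM's _capture backend: inspect the captured FX graph,
        # then return its forward so this sketch runs without TVM installed.
        return gm.forward

    compiled_model = torch.compile(model, backend=my_backend)

    # Running the capture under no_grad() means autograd does not store
    # activations for a backward pass that capture never performs, which is
    # what saves RAM in the patched dynamo_capture_subgraphs.
    with torch.no_grad():
        out = compiled_model(torch.randn(8, 1024))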