You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/30 23:53:06 UTC
[tvm] branch main updated: [Frontend][PyTorch][Bugfix] Ignore Cuda in PyTorch version number when comparing versions (#11511)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bc14f26aca [Frontend][PyTorch][Bugfix] Ignore Cuda in PyTorch version number when comparing versions (#11511)
bc14f26aca is described below
commit bc14f26aca1963a3ab858afa92a799729f6bd145
Author: Steven S. Lyubomirsky <ss...@cs.washington.edu>
AuthorDate: Mon May 30 16:53:00 2022 -0700
[Frontend][PyTorch][Bugfix] Ignore Cuda in PyTorch version number when comparing versions (#11511)
* Do not consider cuda in the PT version number
* Add docstring
---
python/tvm/relay/frontend/pytorch_utils.py | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
index 9dfda6b0b7..da4c9e039e 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -31,10 +31,20 @@ from ..dataflow_pattern import (
def is_version_greater_than(ver):
+ """
+ Returns True if the local PyTorch version is greater
+ than the one given as an argument.
+ """
import torch
from distutils.version import LooseVersion
- return LooseVersion(torch.__version__) > ver
+ torch_ver = torch.__version__
+ # PT version numbers can include +cu[cuda version code]
+ # and we don't want to include that in the comparison
+ if "+cu" in torch_ver:
+ torch_ver = torch_ver.split("+cu")[0]
+
+ return LooseVersion(torch_ver) > ver
def getattr_attr_name(node):