You are viewing a plain-text version of this content; the canonical link to the original page was not preserved in this version.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/30 23:53:06 UTC

[tvm] branch main updated: [Frontend][PyTorch][Bugfix] Ignore Cuda in PyTorch version number when comparing versions (#11511)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bc14f26aca [Frontend][PyTorch][Bugfix] Ignore Cuda in PyTorch version number when comparing versions (#11511)
bc14f26aca is described below

commit bc14f26aca1963a3ab858afa92a799729f6bd145
Author: Steven S. Lyubomirsky <ss...@cs.washington.edu>
AuthorDate: Mon May 30 16:53:00 2022 -0700

    [Frontend][PyTorch][Bugfix] Ignore Cuda in PyTorch version number when comparing versions (#11511)
    
    * Do not consider cuda in the PT version number
    
    * Add docstring
---
 python/tvm/relay/frontend/pytorch_utils.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
index 9dfda6b0b7..da4c9e039e 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -31,10 +31,20 @@ from ..dataflow_pattern import (
 
 
 def is_version_greater_than(ver):
+    """
+    Returns True if the local PyTorch version is greater
+    than the one given as an argument.
+    """
     import torch
     from distutils.version import LooseVersion
 
-    return LooseVersion(torch.__version__) > ver
+    torch_ver = torch.__version__
+    # PT version numbers can include +cu[cuda version code]
+    # and we don't want to include that in the comparison
+    if "+cu" in torch_ver:
+        torch_ver = torch_ver.split("+cu")[0]
+
+    return LooseVersion(torch_ver) > ver
 
 
 def getattr_attr_name(node):