You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/09/08 13:50:42 UTC

[tvm] branch main updated: [Test] Add tvm.testing.requires_libtorch (#12737)

This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ed630122c2 [Test] Add tvm.testing.requires_libtorch (#12737)
ed630122c2 is described below

commit ed630122c281f47493e2941a7dc471e201904587
Author: Leandro Nunes <le...@arm.com>
AuthorDate: Thu Sep 8 14:50:36 2022 +0100

    [Test] Add tvm.testing.requires_libtorch (#12737)
    
    Create a specific test dependency that maps to USE_LIBTORCH, which
    is disabled by default and is independent of whether torch is
    installed on the underlying machine. Without it, the test fails on
    machines that have torch installed but where TVM was built with
    USE_LIBTORCH OFF.
    
    Mark tests.python.contrib.test_libtorch_ops.test_backend with
    this new decorator.
---
 python/tvm/testing/utils.py               | 3 +++
 tests/python/contrib/test_libtorch_ops.py | 2 ++
 2 files changed, 5 insertions(+)

diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 5b70eb0691..37a27a4213 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -945,6 +945,9 @@ requires_rpc = Feature("rpc", "RPC", cmake_flag="USE_RPC")
 # Mark a test as requiring Arm(R) Ethos(TM)-N to run
 requires_ethosn = Feature("ethosn", "Arm(R) Ethos(TM)-N", cmake_flag="USE_ETHOSN")
 
+# Mark a test as requiring libtorch to run
+requires_libtorch = Feature("libtorch", "LibTorch", cmake_flag="USE_LIBTORCH")
+
 # Mark a test as requiring Hexagon to run
 requires_hexagon = Feature(
     "hexagon",
diff --git a/tests/python/contrib/test_libtorch_ops.py b/tests/python/contrib/test_libtorch_ops.py
index 28ae39c329..2bfb78b407 100644
--- a/tests/python/contrib/test_libtorch_ops.py
+++ b/tests/python/contrib/test_libtorch_ops.py
@@ -19,6 +19,7 @@ import pytest
 
 import tvm.relay
 from tvm.relay.op.contrib import torchop
+from tvm.testing import requires_libtorch
 
 import_torch_error = None
 
@@ -30,6 +31,7 @@ except ImportError as e:
 
 
 @pytest.mark.skipif(torch is None, reason=f"PyTorch is not available: {import_torch_error}")
+@requires_libtorch
 def test_backend():
     @torch.jit.script
     def script_fn(x, y):