You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/04 23:12:16 UTC

[tvm] branch main updated: Run ONNX Node Tests on available targets (#8189)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e0baf80  Run ONNX Node Tests on available targets (#8189)
e0baf80 is described below

commit e0baf80d886fe00a30b0a52ff3cb207b03b3ee8e
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Jun 4 17:11:50 2021 -0600

    Run ONNX Node Tests on available targets (#8189)
---
 tests/python/frontend/onnx/test_forward.py | 33 +++++++++++++++++++++++-------
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 423f031..6ac747c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4306,8 +4306,25 @@ unsupported_onnx_tests = [
 ]
 
 
+targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()]
+
+target_skips = {
+    "cuda": [
+        "test_mod_mixed_sign_float16/",
+        "test_qlinearconv/",
+        "test_resize_upsample_sizes_nearest/",
+    ]
+}
+
+
+@pytest.mark.parametrize("target", targets)
 @pytest.mark.parametrize("test", onnx_test_folders)
-def test_onnx_nodes(test):
+def test_onnx_nodes(test, target):
+    if target in target_skips:
+        for failure in target_skips[target]:
+            if failure in test:
+                pytest.skip()
+                break
     for failure in unsupported_onnx_tests:
         if failure in test:
             pytest.skip()
@@ -4333,12 +4350,14 @@ def test_onnx_nodes(test):
                 outputs.append(numpy_helper.to_array(new_tensor))
             else:
                 raise ImportError(str(tensor) + " not labeled as an import or an output")
-        tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0))
-        if len(outputs) == 1:
-            tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)
-        else:
-            for output, val in zip(outputs, tvm_val):
-                tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol)
+
+    dev = tvm.device(target, 0)
+    tvm_val = get_tvm_output_with_vm(onnx_model, inputs, target, dev)
+    if len(outputs) == 1:
+        tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)
+    else:
+        for output, val in zip(outputs, tvm_val):
+            tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol)
 
 
 def test_wrong_input():