You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/13 10:04:17 UTC

[tvm] branch main updated: [UnitTests] Parametrized test_topi_argwhere.py (#11651)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eb611482e3 [UnitTests] Parametrized test_topi_argwhere.py (#11651)
eb611482e3 is described below

commit eb611482e3fb6da463d7458424518792d03fb89e
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Jun 13 05:04:11 2022 -0500

    [UnitTests] Parametrized test_topi_argwhere.py (#11651)
    
    Refactored while debugging test breakage in
    https://github.com/apache/tvm/pull/11646.  Submitted as a separate
    PR because it is neither necessary for nor related to the primary
    changes in that PR.
---
 tests/python/topi/python/test_topi_argwhere.py | 72 ++++++++++++--------------
 1 file changed, 34 insertions(+), 38 deletions(-)

diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
index 8592f57b74..bc43dbb2b0 100644
--- a/tests/python/topi/python/test_topi_argwhere.py
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -16,8 +16,10 @@
 # under the License.
 """Test for argwhere operator"""
 import numpy as np
+import pytest
 
 import tvm
+import tvm.testing
 from tvm import te
 from tvm import topi
 import tvm.topi.testing
@@ -29,56 +31,50 @@ _argwhere_schedule = {
 
 _argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}
 
+data_shape = tvm.testing.parameter(
+    (1,),
+    (100,),
+    (1, 1),
+    (5, 3),
+    (32, 64),
+    (128, 65),
+    (200, 500),
+    (6, 5, 3),
+    (1, 1, 1),
+    (1, 1, 1, 1),
+    (6, 4, 5, 3),
+    (1, 1, 1, 1, 1),
+    (6, 4, 5, 3, 7),
+)
 
-def verify_argwhere(data_shape):
+
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_argwhere(target, dev, data_shape):
     dtype = "int32"
     np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype)
     np_out = np.argwhere(np_data)
     out_shape = np_out.shape[0]
+
     np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)
 
     out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
     condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)
 
-    def check_device(target):
-        dev = tvm.device(target, 0)
-        if not dev.exist or target not in _argwhere_compute:
-            return
-
-        with tvm.target.Target(target):
-            out = _argwhere_compute[target](out_shape, condition)
-            s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
-            sch = s_func(out)
-
-        func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")
-
-        args = [tvm.nd.array(np_shape, dev)]
-        args.append(tvm.nd.array(np_data, dev))
-        args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
-        func(*args)
-        np.set_printoptions(threshold=np.inf)
-        tvm.testing.assert_allclose(args[-1].numpy(), np.array(np_out))
-
-    for target, _ in tvm.testing.enabled_targets():
-        check_device(target)
+    with tvm.target.Target(target):
+        out = _argwhere_compute[target](out_shape, condition)
+        s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
+        sch = s_func(out)
 
+    func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")
 
-@tvm.testing.uses_gpu
-def test_argwhere():
-    verify_argwhere((1,))
-    verify_argwhere((100,))
-    verify_argwhere((1, 1))
-    verify_argwhere((5, 3))
-    verify_argwhere((32, 64))
-    verify_argwhere((128, 65))
-    verify_argwhere((200, 500))
-    verify_argwhere((6, 5, 3))
-    verify_argwhere((1, 1, 1))
-    verify_argwhere((1, 1, 1, 1))
-    verify_argwhere((6, 4, 5, 3))
-    verify_argwhere((1, 1, 1, 1, 1))
-    verify_argwhere((6, 4, 5, 3, 7))
+    args = [tvm.nd.array(np_shape, dev)]
+    args.append(tvm.nd.array(np_data, dev))
+    args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
+    func(*args)
+    np.set_printoptions(threshold=np.inf)
+    tvm_out = args[-1].numpy()
+    tvm.testing.assert_allclose(tvm_out, np_out)
 
 
 if __name__ == "__main__":
-    test_argwhere()
+    tvm.testing.main()