You are viewing a plain-text version of this content. The canonical link to the original message was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/13 10:04:17 UTC
[tvm] branch main updated: [UnitTests] Parametrized test_topi_argwhere.py (#11651)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eb611482e3 [UnitTests] Parametrized test_topi_argwhere.py (#11651)
eb611482e3 is described below
commit eb611482e3fb6da463d7458424518792d03fb89e
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Jun 13 05:04:11 2022 -0500
[UnitTests] Parametrized test_topi_argwhere.py (#11651)
Refactored while debugging test failures in
https://github.com/apache/tvm/pull/11646. Submitted as a separate
PR because it is neither necessary for nor related to the primary
changes in that PR.
---
tests/python/topi/python/test_topi_argwhere.py | 72 ++++++++++++--------------
1 file changed, 34 insertions(+), 38 deletions(-)
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
index 8592f57b74..bc43dbb2b0 100644
--- a/tests/python/topi/python/test_topi_argwhere.py
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -16,8 +16,10 @@
# under the License.
"""Test for argwhere operator"""
import numpy as np
+import pytest
import tvm
+import tvm.testing
from tvm import te
from tvm import topi
import tvm.topi.testing
@@ -29,56 +31,50 @@ _argwhere_schedule = {
_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}
+data_shape = tvm.testing.parameter(
+ (1,),
+ (100,),
+ (1, 1),
+ (5, 3),
+ (32, 64),
+ (128, 65),
+ (200, 500),
+ (6, 5, 3),
+ (1, 1, 1),
+ (1, 1, 1, 1),
+ (6, 4, 5, 3),
+ (1, 1, 1, 1, 1),
+ (6, 4, 5, 3, 7),
+)
-def verify_argwhere(data_shape):
+
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_argwhere(target, dev, data_shape):
dtype = "int32"
np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype)
np_out = np.argwhere(np_data)
out_shape = np_out.shape[0]
+
np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)
out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)
- def check_device(target):
- dev = tvm.device(target, 0)
- if not dev.exist or target not in _argwhere_compute:
- return
-
- with tvm.target.Target(target):
- out = _argwhere_compute[target](out_shape, condition)
- s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
- sch = s_func(out)
-
- func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")
-
- args = [tvm.nd.array(np_shape, dev)]
- args.append(tvm.nd.array(np_data, dev))
- args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
- func(*args)
- np.set_printoptions(threshold=np.inf)
- tvm.testing.assert_allclose(args[-1].numpy(), np.array(np_out))
-
- for target, _ in tvm.testing.enabled_targets():
- check_device(target)
+ with tvm.target.Target(target):
+ out = _argwhere_compute[target](out_shape, condition)
+ s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
+ sch = s_func(out)
+ func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")
-@tvm.testing.uses_gpu
-def test_argwhere():
- verify_argwhere((1,))
- verify_argwhere((100,))
- verify_argwhere((1, 1))
- verify_argwhere((5, 3))
- verify_argwhere((32, 64))
- verify_argwhere((128, 65))
- verify_argwhere((200, 500))
- verify_argwhere((6, 5, 3))
- verify_argwhere((1, 1, 1))
- verify_argwhere((1, 1, 1, 1))
- verify_argwhere((6, 4, 5, 3))
- verify_argwhere((1, 1, 1, 1, 1))
- verify_argwhere((6, 4, 5, 3, 7))
+ args = [tvm.nd.array(np_shape, dev)]
+ args.append(tvm.nd.array(np_data, dev))
+ args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
+ func(*args)
+ np.set_printoptions(threshold=np.inf)
+ tvm_out = args[-1].numpy()
+ tvm.testing.assert_allclose(tvm_out, np_out)
if __name__ == "__main__":
- test_argwhere()
+ tvm.testing.main()