You are viewing a plain-text version of this content. (The hyperlink to the canonical version did not survive conversion to plain text.)
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/03/04 22:04:24 UTC
[tvm] 03/11: test cumsum on vulkan
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 1bea761d9423d1a79f05caaa9d043a106be2bfe4
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 08:18:24 2021 +0900
test cumsum on vulkan
---
tests/python/topi/python/test_topi_cumsum.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
index a01a496..bf962d9 100644
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -28,6 +28,7 @@ def test_cumsum(ctx, target):
"generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
"cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+ "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
}
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
@@ -40,8 +41,10 @@ def test_cumsum(ctx, target):
check_cumsum(np.cumsum(data, dtype=np.int32), data)
check_cumsum(np.cumsum(data), data, dtype="int64")
- data = np.random.rand(10) > 0.5
- check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
+ if str(target.kind) != "vulkan":
+ # TODO(masahi): Support bool tensor in SPIRV codegen
+ data = np.random.rand(10) > 0.5
+ check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
for in_dtype in ["float32", "float64"]:
data = np.random.randn(10, 10).astype(in_dtype)
@@ -70,3 +73,4 @@ if __name__ == "__main__":
test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
+ test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))