You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/03/04 22:04:24 UTC

[tvm] 03/11: test cumsum on vulkan

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 1bea761d9423d1a79f05caaa9d043a106be2bfe4
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 08:18:24 2021 +0900

    test cumsum on vulkan
---
 tests/python/topi/python/test_topi_cumsum.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
index a01a496..bf962d9 100644
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -28,6 +28,7 @@ def test_cumsum(ctx, target):
             "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
             "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
             "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+            "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
         tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
@@ -40,8 +41,10 @@ def test_cumsum(ctx, target):
     check_cumsum(np.cumsum(data, dtype=np.int32), data)
     check_cumsum(np.cumsum(data), data, dtype="int64")
 
-    data = np.random.rand(10) > 0.5
-    check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
+    if str(target.kind) != "vulkan":
+        # TODO(masahi): Support bool tensor in SPIRV codegen
+        data = np.random.rand(10) > 0.5
+        check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
 
     for in_dtype in ["float32", "float64"]:
         data = np.random.randn(10, 10).astype(in_dtype)
@@ -70,3 +73,4 @@ if __name__ == "__main__":
     test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
     test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
     test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
+    test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))