Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/19 03:12:03 UTC

[incubator-tvm] branch master updated: [TIR] Fix lower_warp_memory when there are >1 warp buffers (#5368)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new a2d6fe6  [TIR] Fix lower_warp_memory when there are >1 warp buffers (#5368)
a2d6fe6 is described below

commit a2d6fe65950e927ee291dd75bf658ce1d1429e41
Author: Tang, Shizhi <rd...@gmail.com>
AuthorDate: Sun Apr 19 11:11:51 2020 +0800

    [TIR] Fix lower_warp_memory when there are >1 warp buffers (#5368)
    
    * fix recursion in lower_warp_memory
    
    * post-order mutation
---
 src/tir/transforms/lower_warp_memory.cc            |  7 ++--
 .../test_tir_transform_lower_warp_memory.py        | 49 ++++++++++++++++++++++
 2 files changed, 53 insertions(+), 3 deletions(-)

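For context, the heart of the fix: the old visitor returned as soon as it
rewrote a warp-scope allocation, so when one warp buffer's allocation was
nested inside another's body (as with the two cache_read stages in the new
test), the nested allocation was never visited. The fixed version calls
StmtMutator::VisitStmt_ first and only then applies WarpAccessRewriter,
i.e. it mutates in post-order. The standalone C++ sketch below is a minimal
illustration with made-up names, not TVM code (TVM's StmtMutator also
returns a new Stmt rather than mutating in place); it only reproduces the
difference between the two traversal orders:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Toy stand-in for an AllocateNode: `is_warp` plays the role of
    // warp_buffer_.count(op->buffer_var.get()), and `body` holds the
    // allocations nested inside this one.
    struct Alloc {
      std::string name;
      bool is_warp;
      std::vector<std::shared_ptr<Alloc>> body;
      bool rewritten = false;
    };

    // Buggy pre-order pattern: returning right after rewriting the
    // matched node skips its body, so a warp allocation nested inside
    // is never reached.
    void VisitPreOrder(const std::shared_ptr<Alloc>& op) {
      if (op->is_warp) {
        op->rewritten = true;  // stands in for rewriter.Rewrite(op)
        return;                // children are never visited
      }
      for (auto& child : op->body) VisitPreOrder(child);
    }

    // Fixed post-order pattern: mutate the children first, then rewrite
    // the current node, so every warp allocation is reached.
    void VisitPostOrder(const std::shared_ptr<Alloc>& op) {
      for (auto& child : op->body) VisitPostOrder(child);
      if (op->is_warp) op->rewritten = true;
    }

    int main() {
      auto inner = std::make_shared<Alloc>(Alloc{"B.warp", true, {}});
      auto outer = std::make_shared<Alloc>(Alloc{"A.warp", true, {inner}});

      VisitPreOrder(outer);
      std::cout << "pre-order,  inner rewritten: " << inner->rewritten << "\n";  // 0

      inner->rewritten = outer->rewritten = false;
      VisitPostOrder(outer);
      std::cout << "post-order, inner rewritten: " << inner->rewritten << "\n";  // 1
      return 0;
    }

Running the sketch prints 0 for the pre-order variant and 1 for the
post-order one, which is the behavioral change the patch below makes:
recurse via StmtMutator::VisitStmt_ first, then rewrite the result.
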
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 71e7cfa..0aee3c2 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -377,12 +377,13 @@ class WarpMemoryRewriter : private StmtMutator {
 
  private:
   Stmt VisitStmt_(const AllocateNode* op) {
+    auto ret = StmtMutator::VisitStmt_(op);
+    op = ret.as<AllocateNode>();
     if (warp_buffer_.count(op->buffer_var.get())) {
       WarpAccessRewriter rewriter(warp_size_, &analyzer_);
-      return rewriter.Rewrite(op);
-    } else {
-      return StmtMutator::VisitStmt_(op);
+      ret = rewriter.Rewrite(op);
     }
+    return ret;
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) {
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index a761cf1..51be480 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -132,7 +132,56 @@ def test_lower_warp_memory_cuda_half_a_warp():
     check_cuda("float32")
     check_cuda("float16")
 
+def test_lower_warp_memory_cuda_2_buffers():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 32
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.placeholder((m,), name='B', dtype=dtype)
+        C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name='C')
+
+        cuda_target = tvm.target.create("cuda")
+        assert m <= cuda_target.thread_warp_size
+        with cuda_target:
+            s = te.create_schedule(C.op)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+
+            AA = s.cache_read(A, "warp", [C])
+            BB = s.cache_read(B, "warp", [C])
+            xo, xi = s[C].split(C.op.axis[0], nparts=1)
+            s[C].bind(xi, tx)
+            s[C].bind(xo, bx)
+            s[AA].compute_at(s[C], xo)
+            s[BB].compute_at(s[C], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
+            s[AA].bind(xo, bx)
+            s[AA].bind(xi, tx)
+            xo, xi = s[BB].split(s[BB].op.axis[0], nparts=1)
+            s[BB].bind(xo, bx)
+            s[BB].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B, C], "cuda")
+            AB_np = np.array(list(range(m)), dtype=dtype)
+            C_np = np.array(list(range(1, m)) + [0], dtype=dtype) * 2
+            A_nd = tvm.nd.array(AB_np, ctx)
+            B_nd = tvm.nd.array(AB_np, ctx)
+            C_nd = tvm.nd.array(np.zeros(C_np.shape, dtype=C_np.dtype), ctx)
+            func(A_nd, B_nd, C_nd)
+            tvm.testing.assert_allclose(C_nd.asnumpy(), C_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
+    test_lower_warp_memory_cuda_2_buffers()
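
For reference, with m = 32 and A = B = [0, 1, ..., 31], the kernel computes
C[i] = A[(i + 1) % 32] + B[(i + 1) % 32] = 2 * ((i + 1) % 32), which is
exactly the C_np array the test checks against. The scenario matters
because AA and BB are two separate warp-scope buffers in a single kernel:
under the old pre-order rewrite, only the outermost warp allocation was
rewritten, while the one nested inside its body was skipped.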