Posted to commits@tvm.apache.org by ju...@apache.org on 2023/08/04 15:59:02 UTC

[tvm] branch main updated: [BugFix][TIR] ThreadSync with shared.dyn awareness (#15478)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 543838303b [BugFix][TIR] ThreadSync with shared.dyn awareness (#15478)
543838303b is described below

commit 543838303b4289bb5669688efb9f88b15ddc2ebe
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Fri Aug 4 08:58:55 2023 -0700

    [BugFix][TIR] ThreadSync with shared.dyn awareness (#15478)
    
    This PR fixes an issue in the ThreadSync pass.
    
    Prior to this PR, the pass was not aware of the `shared.dyn` scope,
    whose buffers all share the same underlying shared memory space. This
    sharing is not necessarily reflected in the IR at the time ThreadSync
    is applied: each buffer of `shared.dyn` scope still uses its own data
    Var, so ThreadSync cannot detect the conflicts between those buffers
    or insert the required sync instructions.
    
    This PR makes ThreadSync explicitly aware of the `shared.dyn` scope by
    redirecting all access vars of `shared.dyn` memory to a common var, so
    that the ThreadSync analysis can detect the conflicts and insert the
    sync instructions.
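    
    A minimal sketch of exercising the fixed pass from Python, mirroring
    the unit test added below (`func` stands for the TVMScript PrimFunc
    defined in that test):
    
        import tvm
    
        # Build a module around a PrimFunc that writes one `shared.dyn`
        # buffer and then writes a second `shared.dyn` buffer from a value
        # read back out of the first.
        mod = tvm.IRModule({"main": func})
    
        # With this fix, the pass inserts T.tvm_storage_sync("shared.dyn")
        # between the conflicting accesses, as shown in `expected` below.
        mod = tvm.tir.transform.ThreadSync("shared.dyn")(mod)
        print(mod.script())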
---
 src/tir/transforms/thread_storage_sync.cc          | 18 +++++++++-
 .../unittest/test_tir_transform_thread_sync.py     | 42 ++++++++++++++++++++++
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc
index c21afe400c..d92986e51a 100644
--- a/src/tir/transforms/thread_storage_sync.cc
+++ b/src/tir/transforms/thread_storage_sync.cc
@@ -50,11 +50,27 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
   }
   // Plan the sync
   std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
+    // Redirect all "shared.dyn" buffer access to the same buffer var
+    // so that the accesses can be planned together.
+    Var shared_dyn_buf;
+    for (StmtEntry& entry : seq) {
+      for (AccessEntry& access : entry.access) {
+        if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" &&
+            access.buffer.defined()) {
+          if (!shared_dyn_buf.defined()) {
+            shared_dyn_buf = access.buffer;
+          } else {
+            access.buffer = shared_dyn_buf;
+          }
+        }
+      }
+    }
+
     // Unsynced reads and writes
     std::vector<AccessEntry> reads;
     std::vector<AccessEntry> writes;
     // if it is a loop, rotate two times to consider effect of loop.
-    // simulation based approach to find dependenceies
+    // simulation based approach to find dependencies
     for (size_t i = 0; i < seq.size(); ++i) {
       const StmtEntry& s = seq[i];
       // check if sync before statement is needed.
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 57ea223cf9..571927dffe 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -119,7 +119,49 @@ def test_sync_read_thread_id_independent_location():
     assert "T.tvm_storage_sync" in str(mod)
 
 
+def test_sync_shared_dyn():
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        B = T.allocate([24], "float32", "shared.dyn")
+        C = T.allocate([1], "float32", "local")
+        D = T.allocate([16], "float32", "shared.dyn")
+        threadIdx_x = T.launch_thread("threadIdx.x", 16)
+        B_1 = T.Buffer((24,), data=B, scope="shared.dyn")
+        A_1 = T.Buffer((16,), data=A.data)
+        B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x]
+        C_1 = T.Buffer((1,), data=C, scope="local")
+        C_1[0] = B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4]
+        D_1 = T.Buffer((16,), data=D, scope="shared.dyn")
+        D_1[threadIdx_x] = C_1[0]
+        E_1 = T.Buffer((16,), data=E.data)
+        E_1[threadIdx_x] = D_1[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        B_1 = T.allocate([24], "float32", "shared.dyn")
+        C_1 = T.allocate([1], "float32", "local")
+        D_1 = T.allocate([16], "float32", "shared.dyn")
+        threadIdx_x = T.launch_thread("threadIdx.x", 16)
+        B_1_1 = T.Buffer((24,), data=B_1, scope="shared.dyn")
+        A_1 = T.Buffer((16,), data=A.data)
+        B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x]
+        C_1_1 = T.Buffer((1,), data=C_1, scope="local")
+        C_1_1[0] = B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4]
+        T.tvm_storage_sync("shared.dyn")
+        D_1_1 = T.Buffer((16,), data=D_1, scope="shared.dyn")
+        D_1_1[threadIdx_x] = C_1_1[0]
+        E_1 = T.Buffer((16,), data=E.data)
+        E_1[threadIdx_x] = D_1_1[threadIdx_x]
+
+    mod = tvm.IRModule({"main": func})
+    mod = tvm.tir.transform.ThreadSync("shared.dyn")(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     test_thread_storage_sync()
     test_sync_else_branch()
     test_sync_read_thread_id_independent_location()
+    test_sync_shared_dyn()
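
For reference, the new test can also be run on its own with pytest:

    pytest tests/python/unittest/test_tir_transform_thread_sync.py::test_sync_shared_dyn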