You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/07/15 16:19:35 UTC

[tvm] branch main updated: [Fix][TIR] LowerThreadAllreduce with correct thread mask (#15323)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9af8efcd2d [Fix][TIR] LowerThreadAllreduce with correct thread mask (#15323)
9af8efcd2d is described below

commit 9af8efcd2d8174acf97f4921339c9327efda2af8
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Sat Jul 15 09:19:29 2023 -0700

    [Fix][TIR] LowerThreadAllreduce with correct thread mask (#15323)
    
    This PR fixes a bug in the LowerThreadAllreduce pass.
    
    Prior to this PR, the thread mask was not set correctly in multi-group
    settings. The per-group mask `(1 << reduce_extent) - 1` was built from a
    plain `int` literal, so when the reduction extent is 32 the shift
    overflows and the thread mask in practice evaluates to 0. This bug went
    unnoticed because the CUDA program still produces the correct result
    even when the mask is 0. In any case, a zero mask is dangerous and
    should be fixed.
---
 src/tir/transforms/lower_thread_allreduce.cc       | 10 ++--
 .../test_tir_transform_lower_thread_all_reduce.py  | 65 ++++++++++++++++++++++
 2 files changed, 70 insertions(+), 5 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index c1566936c5..97a34a6ede 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -333,8 +333,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       {
         PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
         if (group_extent > 1) {
-          mask = mask &
-                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
+          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
+                         << (reduce_extent * cast(mask_dtype, group_index)));
         }
         seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
         // Push the buffer description.  Later this will have an
@@ -392,7 +392,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         // During the sub-warp reduction, values from inactive threads could be read,
         // which is an undefined behavior according to the cuda document.
         //
-        // In practise, the return value are usually 0, which does no harm to sum reduction.
+        // In practice, the return value are usually 0, which does no harm to sum reduction.
         // However, the result can be incorrect in max or prod reduction.
         // Therefore an additional range check has to be performed to ensure the correctness.
         if (offset * 2 > reduce_extent) {
@@ -405,7 +405,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
       // Broadcast the reduction result from lane 0 to all other lanes.
       // This avoids to emit predicated stores, as all threads are
-      // uniformly writting the same result.
+      // uniformly writing the same result.
       //
       for (size_t i = 0; i < size; ++i) {
         Buffer buf = shared_bufs[i];
@@ -669,7 +669,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return false;
     }
 
-    // whether reduce_extent and group_extent are vaild for warp reduction.
+    // whether reduce_extent and group_extent are valid for warp reduction.
     if (target_->kind->name == "rocm") {
       return reduce_extent == warp_size_;
     } else {  // target_->kind->name == "cuda"
diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
index f20d11ffb4..c9e6136ca8 100644
--- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -235,5 +235,70 @@ class TestReduceSummation(BaseCompare):
                 B[i] = reduce[0]
 
 
+class TestMultiGroupMask(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((1024,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 32 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func
+    def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        red_buf0 = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            A_1 = T.Buffer((1024,), data=A.data)
+            red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x]
+
+            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_1[0] = T.bitwise_and(
+                T.tvm_warp_activemask(),
+                T.shift_left(T.uint32(4294967295), T.uint32(32) * T.Cast("uint32", threadIdx_y)),
+            )
+
+            t0_1 = T.Buffer((1,), data=t0, scope="local")
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 32 * threadIdx_y, 32, 32)
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = red_buf0_1[0]
+
+
 if __name__ == "__main__":
     tvm.testing.main()