You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/07/16 15:36:21 UTC

[tvm] branch main updated: [TIR] ThreadAllreduce warp-level primitive support with multi-warp (#15327)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e25b1ba70a [TIR] ThreadAllreduce warp-level primitive support with multi-warp (#15327)
e25b1ba70a is described below

commit e25b1ba70a27399fb4da257c521c1e2bbb178ad8
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Sun Jul 16 08:36:14 2023 -0700

    [TIR] ThreadAllreduce warp-level primitive support with multi-warp (#15327)
    
    This PR enhances the implementation of the LowerThreadAllreduce pass.
    
    Prior to this PR, for the CUDA backend we leveraged warp-level
    primitives only when
    * the reducing threads form a sub-warp (i.e., size 16, 8, 4, or 2), or
    * the number of reducing threads is less than 32 and equals the
    reduction extent.
    
    Under the requirements above, reductions with a large number of
    reducing threads (e.g., reducing over 128, 256, or more threads)
    could not use warp-level primitives, so the generated code was inefficient.
    
    This PR improves the LowerThreadAllreduce pass so that it now generates
    more efficient CUDA code in such cases, when the number of reducing
    threads is a multiple of the warp size, with the help of warp-level
    primitives.
    
    Specifically, in such cases, we first reduce 32 elements within
    each warp and store the result of each warp in shared memory.
    We then run a second round of warp-level primitive reduction
    within the first warp to obtain the final reduction result.
    
    In addition to using warp-level primitives, this approach also
    reduces the amount of shared memory required. For example, even when
    reducing over 1024 threads, we now only require shared memory of size 32,
    compared with 1024 prior to this PR.
    
    Tests are added to ensure correctness.
---
 python/tvm/tir/op.py                               |   2 +-
 src/te/operation/cross_thread_reduction.cc         |  13 +-
 src/tir/transforms/lower_thread_allreduce.cc       | 339 +++++++++++-------
 .../test_tir_transform_lower_thread_all_reduce.py  | 396 ++++++++++++++++++++-
 4 files changed, 613 insertions(+), 137 deletions(-)

diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index cdbdb4b542..378be84621 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -616,7 +616,7 @@ def tvm_storage_sync(storage_scope):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("handle", "tir.tvm_storage_sync", storage_scope)
+    return call_intrin("int32", "tir.tvm_storage_sync", storage_scope)
 
 
 def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc
index 8cbcfbc78f..52e38c7ba2 100644
--- a/src/te/operation/cross_thread_reduction.cc
+++ b/src/te/operation/cross_thread_reduction.cc
@@ -181,22 +181,23 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
     freduce_args.push_back(dummy_load);
   }
 
+  // Checks for the thread.
+  std::vector<PrimExpr> output_preds;
+  if (stage->store_predicate.defined()) {
+    output_preds.emplace_back(stage->store_predicate);
+  }
+
   for (IterVar iv : stage->leaf_iter_vars) {
     if (iv->iter_type == kCommReduce) {
       auto it = stage->iter_var_attrs.find(iv);
       if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
         IterVar tv = (*it).second->bind_thread;
         freduce_args.push_back(tv->var);
+        output_preds.push_back(tv->var == make_const(tv->var->dtype, 0));
       }
     }
   }
 
-  // Checks for the thread.
-  std::vector<PrimExpr> output_preds;
-  if (stage->store_predicate.defined()) {
-    output_preds.emplace_back(stage->store_predicate);
-  }
-
   // Apply the existing input predicate if any.
   output_preds.push_back(input_pred);
 
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 97a34a6ede..b47e837711 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -63,7 +63,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
   explicit ThreadAllreduceBuilder(const TargetNode* target)
       : target_(target),
-        warp_size_(target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()) {}
+        warp_size_(target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()),
+        max_num_threads_(target->GetAttr<Integer>("max_num_threads", -1).value().IntValue()) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
@@ -279,9 +280,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
 
     std::vector<Stmt> seq;
-    std::vector<Var> shared_buffer_vars(size);
-    std::vector<Buffer> shared_bufs(size);
-    std::vector<Buffer> local_bufs;
+    std::vector<Buffer> new_alloc_bufs;
     //
     // This is an optimization. For small reduction sizes, it may be beneficial
     // for a single warp to performance the entire reduction. No trips to shared
@@ -299,131 +298,87 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //
-    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
-      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
-      //
-      // This is the index to the reduction variable, one reduction
-      // variable per warp. Local scope seems easier to reason without
-      // relying on a pattern match pass to fix it later.
-      Array<PrimExpr> zero_indices = {0};
-
-      for (size_t idx = 0; idx < size; ++idx) {
-        Array<PrimExpr> shape = {1};
-
-        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
-        Var buffer_var = buffer->data;
-
-        shared_buffer_vars[idx] = buffer_var;
-        shared_bufs[idx] = buffer;
-
-        PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
-
-        // Uses a local variable to store the shuffled data.  Later
-        // on, an allocation will be built for this local variable.
-        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
-      }
-
-      // The mask for this reducer, as this reducer may sit inside
-      // a divergent control flow. Here it uses a variable to cache the current
-      // active channels.
-      //
+    // When the thread extent is multiple of warp size, we can use a two-stage
+    // warp-level reduction to optimize. This is implemented by applying the
+    // algorithm above twice.
+    //
+    // For example, suppose we want to use 512 threads to reduce 512 elements
+    // and the warp size is 32. In this case there are (512 / 32) = 16 warps.
+    // In the first stage, each of the 16 warps reduces 32 elements. So after
+    // the stage, we have 16 remaining elements to be reduced, one for each warp.
+    // We store the 16 elements in shared memory, and start the second stage.
+    // In the second stage we use the first 16 lanes of the first warp to reduce
+    // the remaining elements, and this reduction can also be optimized by
+    // shuffle_down warp-level primitives.
+    if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
+      std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
-      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
-      {
-        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
-        if (group_extent > 1) {
-          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
-                         << (reduce_extent * cast(mask_dtype, group_index)));
+      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+      if (reduce_extent <= warp_size_) {
+        if (group_extent > 1 && reduce_extent < warp_size_) {
+          mask = mask &
+                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
         }
-        seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
-        // Push the buffer description.  Later this will have an
-        // allocation built for it.
-        local_bufs.push_back(mask_buffer);
-      }
+        std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
+      } else {
+        int n_warps = reduce_extent / warp_size_;
+        std::vector<Buffer> local_bufs;
 
-      // Emit reductions within a warp.
-      int start_offset = 1;
-      while (start_offset * 2 < reduce_extent) {
-        start_offset *= 2;
-      }
-      for (int offset = start_offset; offset > 0; offset /= 2) {
-        // Load reduction values, no synchronization needed.
-        Array<PrimExpr> a, b;
+        // 1. Create the staging buffer in shared memory.
+        std::vector<Buffer> staging_shared_bufs;
+        staging_shared_bufs.reserve(size);
         for (size_t i = 0; i < size; ++i) {
-          Buffer shared_buf = shared_bufs[i];
-          BufferLoad val(shared_buf, zero_indices);
-          ICHECK_EQ(val->dtype, types[i]);
-          a.push_back(val);
-
-          // __shfl_*sync calls shall not appear in if_then_else expressions
-          // as this is causing extra divergency. E.g.
-          //
-          // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
-          //
-          // behaves differently from
-          //
-          // int t = __shfl_sync(mask, v1, 0);
-          // v1 = (v2 < v3) ? v3 : t;
-          //
-          // The former may cause dead lock as there is a divergent
-          // branch with a warp sync call inside.
-          //
-          PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
-          Buffer local_buf = local_bufs[i];
-          Stmt s = BufferStore(local_buf, other, zero_indices);
-          seq.push_back(s);
-
-          BufferLoad load = BufferLoad(local_buf, zero_indices);
-          ICHECK_EQ(load->dtype, types[i]);
-          b.push_back(load);
+          Buffer staging_shared_buf = decl_buffer(
+              /*shape=*/{make_const(reduce_index->dtype, n_warps * group_extent)},
+              /*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging", /*storage_scope=*/"shared");
+          staging_shared_bufs.push_back(staging_shared_buf);
+          new_alloc_bufs.push_back(staging_shared_buf);
         }
 
-        // Do reductions.
-        Array<PrimExpr> ret = (*combiner)(a, b);
+        // 2. First round of allreduce.
+        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, warp_size_, group_index, mask, NullOpt, &seq);
+        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
 
-        // Store the reduction result to itself.
-        std::vector<Stmt> stores(size);
+        // 3. Write allreduce results to staging buffer.
+        std::vector<Stmt> write_staging_buf;
+        write_staging_buf.reserve(size);
         for (size_t i = 0; i < size; ++i) {
-          Buffer buf = shared_bufs[i];
-          stores[i] = BufferStore(buf, ret[i], zero_indices);
+          new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
+          write_staging_buf.push_back(BufferStore(
+              /*buffer=*/staging_shared_bufs[i],
+              /*value=*/reduce_results[i],
+              /*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
         }
+        PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0);
+        seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
+        seq.push_back(SyncThread("shared"));
 
-        // During the sub-warp reduction, values from inactive threads could be read,
-        // which is an undefined behavior according to the cuda document.
-        //
-        // In practice, the return value are usually 0, which does no harm to sum reduction.
-        // However, the result can be incorrect in max or prod reduction.
-        // Therefore an additional range check has to be performed to ensure the correctness.
-        if (offset * 2 > reduce_extent) {
-          PrimExpr cond = reduce_index + offset < reduce_extent;
-          seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
-        } else {
-          seq.push_back(SeqStmt::Flatten(stores));
+        // 4. Load staging buffer.
+        //    Second round of allreduce.
+        for (size_t i = 0; i < size; ++i) {
+          values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index});
         }
+        if (n_warps < warp_size_) {
+          mask = mask & (((1 << n_warps) - 1) << group_index);
+        }
+        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, n_warps, group_index, mask,
+            /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
+            &seq);
+        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
       }
 
-      // Broadcast the reduction result from lane 0 to all other lanes.
-      // This avoids to emit predicated stores, as all threads are
-      // uniformly writing the same result.
-      //
-      for (size_t i = 0; i < size; ++i) {
-        Buffer buf = shared_bufs[i];
-        PrimExpr val = BufferLoad(buf, zero_indices);
-        ICHECK_EQ(val->dtype, types[i]);
-        PrimExpr splat =
-            WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
-        seq.push_back(BufferStore(buf, splat, zero_indices));
-      }
-
-      // Update existing allocations.
+      // Write back allreduce results and update existing allocations.
       for (size_t i = 0; i < size; ++i) {
         ICHECK(!load_remap_.count(buffers[i]->data.get()));
         PrimExpr pred = const_true(types[i].lanes());
-        Buffer buf = shared_bufs[i];
-        PrimExpr val = BufferLoad(buf, zero_indices);
-        ICHECK_EQ(val->dtype, types[i]);
-        load_remap_[buffers[i]->data.get()] = val;
+        Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
+        ICHECK_EQ(reduce_results[i]->dtype, types[i]);
+        load_remap_[buffers[i]->data.get()] = reduce_results[i];
+
         Array<PrimExpr> extents{PrimExpr(1)};
         auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
         alloc_remap_[buffers[i]->data.get()] = node;
@@ -432,6 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         warp_allocs_.insert(node.get());
       }
     } else {
+      std::vector<Buffer> shared_bufs(size);
       if (reduce_extent == 1) {
         // special case, no reduction is needed.
         std::vector<Stmt> stores;
@@ -444,12 +400,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
       for (size_t idx = 0; idx < size; ++idx) {
-        Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
-
-        shared_bufs[idx] = buffer;
-        shared_buffer_vars[idx] = buffer->data;
-
-        PrimExpr pred = const_true(types[idx].lanes());
+        shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
         seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                      {BufIndex(reduce_index, group_index, reduce_extent)}));
       }
@@ -473,14 +424,146 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     // Fix all local allocations as all statements are built.
     Stmt body = SeqStmt::Flatten(seq);
-    for (Buffer buf : local_bufs) {
+    for (Buffer buf : new_alloc_bufs) {
       body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
-      new_storage_scopes_[buf->data.get()] = "local";
+      if (buf.scope() != "shared") {
+        new_storage_scopes_[buf->data.get()] = "local";
+      }
     }
 
     return body;
   }
 
+  std::pair<std::vector<PrimExpr>, std::vector<Buffer>> MakeWarpAllreduce(
+      std::vector<PrimExpr> src_values,             //
+      std::vector<DataType> dtypes,                 //
+      const CommReducerNode* combiner,              //
+      PrimExpr reduce_index, int reduce_extent,     //
+      PrimExpr group_index,                         //
+      PrimExpr mask, Optional<PrimExpr> predicate,  //
+      std::vector<Stmt>* seq) {
+    int n_buffers = src_values.size();
+
+    std::vector<Buffer> shared_bufs;
+    std::vector<Buffer> local_bufs;
+    shared_bufs.reserve(n_buffers);
+
+    // This is the index to the reduction variable, one reduction
+    // variable per warp. Local scope seems easier to reason without
+    // relying on a pattern match pass to fix it later.
+    Array<PrimExpr> zero_indices = {0};
+    Array<PrimExpr> shape = {1};
+
+    std::vector<Stmt> load_values;
+    load_values.reserve(n_buffers);
+    for (int idx = 0; idx < n_buffers; ++idx) {
+      shared_bufs.push_back(decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx)));
+      load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
+
+      // Uses a local variable to store the shuffled data.  Later
+      // on, an allocation will be built for this local variable.
+      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx)));
+    }
+
+    if (predicate.defined()) {
+      seq->push_back(IfThenElse(predicate.value(), SeqStmt::Flatten(load_values)));
+    } else {
+      seq->insert(seq->end(), load_values.begin(), load_values.end());
+    }
+
+    // The mask for this reducer, as this reducer may sit inside
+    // a divergent control flow. Here it uses a variable to cache the current
+    // active channels.
+    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask");
+    {
+      seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+      // Push the buffer description.  Later this will have an
+      // allocation built for it.
+      local_bufs.push_back(mask_buffer);
+    }
+
+    // Emit reductions within a warp.
+    int start_offset = 1;
+    while (start_offset * 2 < reduce_extent) {
+      start_offset *= 2;
+    }
+    for (int offset = start_offset; offset > 0; offset /= 2) {
+      // Load reduction values, no synchronization needed.
+      Array<PrimExpr> a, b;
+      for (int i = 0; i < n_buffers; ++i) {
+        Buffer shared_buf = shared_bufs[i];
+        BufferLoad val(shared_buf, zero_indices);
+        ICHECK_EQ(val->dtype, dtypes[i]);
+        a.push_back(val);
+
+        // __shfl_*sync calls shall not appear in if_then_else expressions
+        // as this is causing extra divergency. E.g.
+        //
+        // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
+        //
+        // behaves differently from
+        //
+        // int t = __shfl_sync(mask, v1, 0);
+        // v1 = (v2 < v3) ? v3 : t;
+        //
+        // The former may cause dead lock as there is a divergent
+        // branch with a warp sync call inside.
+        PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
+        Buffer local_buf = local_bufs[i];
+        Stmt s = BufferStore(local_buf, other, zero_indices);
+        seq->push_back(s);
+
+        BufferLoad load = BufferLoad(local_buf, zero_indices);
+        ICHECK_EQ(load->dtype, dtypes[i]);
+        b.push_back(load);
+      }
+
+      // Do reductions.
+      Array<PrimExpr> ret = (*combiner)(a, b);
+
+      // Store the reduction result to itself.
+      std::vector<Stmt> stores;
+      stores.reserve(n_buffers);
+      for (int i = 0; i < n_buffers; ++i) {
+        Buffer buf = shared_bufs[i];
+        stores.push_back(BufferStore(buf, ret[i], zero_indices));
+      }
+
+      // During the sub-warp reduction, values from inactive threads could be read,
+      // which is an undefined behavior according to the cuda document.
+      //
+      // In practice, the return value are usually 0, which does no harm to sum reduction.
+      // However, the result can be incorrect in max or prod reduction.
+      // Therefore an additional range check has to be performed to ensure the correctness.
+      if (offset * 2 > reduce_extent) {
+        PrimExpr cond = reduce_index + offset < reduce_extent;
+        seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
+      } else {
+        seq->push_back(SeqStmt::Flatten(stores));
+      }
+    }
+
+    // Broadcast the reduction result from lane 0 to all other lanes.
+    // This avoids to emit predicated stores, as all threads are
+    // uniformly writing the same result.
+    for (int i = 0; i < n_buffers; ++i) {
+      Buffer buf = shared_bufs[i];
+      PrimExpr val = BufferLoad(buf, zero_indices);
+      ICHECK_EQ(val->dtype, dtypes[i]);
+      PrimExpr splat =
+          WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
+      seq->push_back(BufferStore(buf, splat, zero_indices));
+    }
+
+    std::vector<PrimExpr> reduce_results;
+    reduce_results.reserve(n_buffers);
+    for (int i = 0; i < n_buffers; ++i) {
+      reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices));
+    }
+
+    return {reduce_results, local_bufs};
+  }
+
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
                         const Array<Buffer>& shared_bufs, PrimExpr reduce_index,
@@ -637,8 +720,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   //
   // Note: The ROCm backend will only have warp reductions for now.
   // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
-  bool is_warp_reduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
-                         int contiguous_reduce_extent) const {
+  bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
+                       int contiguous_reduce_extent) const {
     // Only cuda target supports warp reductions.
     if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
 
@@ -676,8 +759,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
-        if (warp_size_ % reduce_extent == 0) {
-          return true;  // warp size is multiple of reduce extent
+        bool is_subwarp_reduction = warp_size_ % reduce_extent == 0;
+        bool is_multiwarp_reduction = max_num_threads_ != -1 &&
+                                      max_num_threads_ <= warp_size_ * warp_size_ &&
+                                      reduce_extent % warp_size_ == 0;
+        if (is_subwarp_reduction || is_multiwarp_reduction) {
+          return true;
         } else {
           return group_extent == 1 && reduce_extent <= warp_size_;
         }
@@ -690,6 +777,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
   // The warp size of the device.
   int warp_size_{1};
+  // The maximum number of threads of the device. "-1" denotes unknown.
+  int max_num_threads_{-1};
 
   // surrounding scope of thread extent.
   std::vector<const AttrStmtNode*> thread_extents_;
diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
index c9e6136ca8..f354dfe9ca 100644
--- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -235,7 +235,7 @@ class TestReduceSummation(BaseCompare):
                 B[i] = reduce[0]
 
 
-class TestMultiGroupMask(BaseCompare):
+class TestMultiGroupReduction(BaseCompare):
     @T.prim_func
     def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
@@ -278,10 +278,7 @@ class TestMultiGroupMask(BaseCompare):
             red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x]
 
             mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
-            mask_1[0] = T.bitwise_and(
-                T.tvm_warp_activemask(),
-                T.shift_left(T.uint32(4294967295), T.uint32(32) * T.Cast("uint32", threadIdx_y)),
-            )
+            mask_1[0] = T.tvm_warp_activemask()
 
             t0_1 = T.Buffer((1,), data=t0, scope="local")
             t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
@@ -300,5 +297,394 @@ class TestMultiGroupMask(BaseCompare):
             B_1[threadIdx_y] = red_buf0_1[0]
 
 
+class TestMultiGroupMask1(BaseCompare):  # 32 groups (threadIdx.y) of 8-lane sum reductions
+    @T.prim_func  # input: one tvm_thread_allreduce over threadIdx.x per group
+    def before(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 8)  # reduction axis (extent 8)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((256,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 8 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func  # expected: shuffle-only lowering; every scratch buffer is "local"
+    def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        red_buf0 = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 8)
+        red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            A_1 = T.Buffer((256,), data=A.data)
+            red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x]
+            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_1[0] = T.bitwise_and(  # active lanes restricted to this group's 8-lane slice
+                T.tvm_warp_activemask(),  # NOTE(review): the shift below is 8 * threadIdx_y,
+                T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", threadIdx_y)),
+            )  # i.e. >= 32 once threadIdx_y >= 4 (UB for a u32 shift); was threadIdx_y % 4 meant?
+            t0_1 = T.Buffer((1,), data=t0, scope="local")  # shuffle-down offsets 4, 2, 1
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]  # each group's first lane now holds its sum
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 8 * threadIdx_y, 32, 32)
+        if threadIdx_x == 0:  # one writer per group
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = red_buf0_1[0]
+
+
+class TestMultiWarpReduce1(BaseCompare):  # 128 reducing threads = 4 full warps
+    @T.prim_func  # input: per-row sum of A over threadIdx.x
+    def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 128)
+            cross_thread_B = T.allocate([1], "float32", "local")
+            cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+            with T.attr(
+                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                A_1 = T.Buffer((16384,), data=A.data)
+                T.tvm_thread_allreduce(
+                    T.uint32(1),
+                    A_1[i * 128 + threadIdx_x],
+                    T.bool(True),
+                    cross_thread_B_1[0],
+                    threadIdx_x,
+                )
+            if threadIdx_x == 0:
+                B_1 = T.Buffer((128,), data=B.data)
+                B_1[i] = cross_thread_B_1[0]
+
+    @T.prim_func  # expected: warp-shuffle stage, shared staging, then a 4-lane shuffle stage
+    def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 128)
+            red_buf0 = T.allocate([1], "float32", "local")
+            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+            with T.attr(
+                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                mask = T.allocate([1], "uint32", "local")
+                t0 = T.allocate([1], "float32", "local")
+                red_buf0_1 = T.allocate([1], "float32", "local")
+                mask_1 = T.allocate([1], "uint32", "local")
+                t0_1 = T.allocate([1], "float32", "local")
+                red_buf_staging = T.allocate([4], "float32", "shared")  # one slot per warp
+                red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+                A_1 = T.Buffer((16384,), data=A.data)
+                red_buf0_2[0] = A_1[i * 128 + threadIdx_x]
+                mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
+                mask_2[0] = T.tvm_warp_activemask()  # stage 1: reduce within each warp
+                t0_2 = T.Buffer((1,), data=t0_1, scope="local")
+                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
+                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
+                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
+                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
+                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
+                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+                red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32)
+                red_buf_staging_1 = T.Buffer((4,), data=red_buf_staging, scope="shared")
+                if threadIdx_x % 32 == 0:  # lane 0 of each warp stages its partial sum
+                    red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
+                T.tvm_storage_sync("shared")  # partials visible before stage 2 reads them
+                if threadIdx_x < 4:  # stage 2: lanes 0-3 load the 4 partials
+                    red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+                mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")  # lanes 0-3
+                mask_3[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15))
+                t0_3 = T.Buffer((1,), data=t0, scope="local")
+                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
+                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]  # NOTE(review): lanes 4-31 shuffle too,
+                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
+                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]  # yet mask_3 omits them (UB per CUDA docs)
+                red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32)
+            if threadIdx_x == 0:  # single writer of the row sum
+                B_1 = T.Buffer((128,), data=B.data)
+                B_1[i] = red_buf0_3[0]
+
+
+class TestMultiWarpReduce2(BaseCompare):  # 1024 reducing threads = 32 full warps
+    @T.prim_func  # input: sum of all 1024 elements of A over threadIdx.x
+    def before(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_x = T.launch_thread("threadIdx.x", 1024)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((1024,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1), A_1[threadIdx_x], T.bool(True), cross_thread_B_1[0], threadIdx_x
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((1,), data=B.data)
+            B_1[0] = cross_thread_B_1[0]
+
+    @T.prim_func  # expected: warp-shuffle stage, shared staging, then a full-warp stage
+    def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_x = T.launch_thread("threadIdx.x", 1024)
+        red_buf0 = T.allocate([1], "float32", "local")
+        red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1 = T.allocate([1], "float32", "local")
+            mask_1 = T.allocate([1], "uint32", "local")
+            t0_1 = T.allocate([1], "float32", "local")
+            red_buf_staging = T.allocate([32], "float32", "shared")  # one slot per warp
+            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            A_1 = T.Buffer((1024,), data=A.data)
+            red_buf0_2[0] = A_1[threadIdx_x]
+            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
+            mask_2[0] = T.tvm_warp_activemask()  # stage 1: reduce within each warp
+            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32)
+            red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
+            if threadIdx_x % 32 == 0:  # lane 0 of each warp stages its partial sum
+                red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
+            T.tvm_storage_sync("shared")  # partials visible before stage 2 reads them
+            if threadIdx_x < 32:  # stage 2: the first warp loads all 32 partials
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+            mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_3[0] = T.tvm_warp_activemask()  # 32 partials fill a warp: no lane masking
+            t0_3 = T.Buffer((1,), data=t0, scope="local")
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 16, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32)
+        if threadIdx_x == 0:  # single writer of the final sum
+            B_1 = T.Buffer((1,), data=B.data)
+            B_1[0] = red_buf0_3[0]
+
+
+class TestMultiGroupMultiWarpReduction(BaseCompare):  # 4 groups x 4 warps each
+    @T.prim_func  # input: per-row sum of A over 128 threadIdx.x lanes
+    def before(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 4)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((512,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 128 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((4,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func  # expected: two-stage lowering with per-group shared staging
+    def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 4)
+        red_buf0 = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1 = T.allocate([1], "float32", "local")
+            mask_1 = T.allocate([1], "uint32", "local")
+            t0_1 = T.allocate([1], "float32", "local")
+            red_buf_staging = T.allocate([16], "float32", "shared")  # 4 groups x 4 warps
+            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            A_1 = T.Buffer((512,), data=A.data)
+            red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x]
+            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
+            mask_2[0] = T.tvm_warp_activemask()  # stage 1: full-warp reduction
+            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32)
+            red_buf_staging_1 = T.Buffer((16,), data=red_buf_staging, scope="shared")
+            if threadIdx_x % 32 == 0:  # slot = threadIdx_y * 4 + warp index
+                red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
+            T.tvm_storage_sync("shared")  # partials visible before stage 2 reads them
+            if threadIdx_x < 16:  # lanes 0-15 load all 16 partials
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+            mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_3[0] = T.bitwise_and(  # NOTE(review): group y reduces lanes 4*y..4*y+3
+                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y))
+            )  # per the 4 * threadIdx_y broadcast; 15 << y looks wrong, 15 << (4 * y) intended?
+            t0_3 = T.Buffer((1,), data=t0, scope="local")
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]  # NOTE(review): lanes outside mask_3 also
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]  # run these shuffles (UB per CUDA docs)
+            red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 4 * threadIdx_y, 32, 32)
+        if threadIdx_x == 0:  # one writer per group
+            B_1 = T.Buffer((4,), data=B.data)
+            B_1[threadIdx_y] = red_buf0_3[0]
+
+
+class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):  # 2 groups x 16 warps
+    @T.prim_func  # input: 512 threads per row; only threadIdx.x < 70 reads A
+    def before(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        in_thread_B = T.allocate([1], "float32", "local")
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 512)
+        in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
+        in_thread_B_1[0] = T.float32(0)
+        if threadIdx_x < 70:  # lanes >= 70 keep the initial 0
+            A_1 = T.Buffer((140,), data=A.data)
+            in_thread_B_1[0] = in_thread_B_1[0] + A_1[threadIdx_y * 70 + threadIdx_x]
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            T.tvm_thread_allreduce(
+                T.uint32(1), in_thread_B_1[0], T.bool(True), cross_thread_B_1[0], threadIdx_x
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func  # expected: predicate kept as-is; two-stage warp reduction
+    def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        in_thread_B = T.allocate([1], "float32", "local")
+        red_buf0 = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 512)
+        in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
+        in_thread_B_1[0] = T.float32(0)
+        if threadIdx_x < 70:
+            A_1 = T.Buffer((140,), data=A.data)
+            in_thread_B_1[0] = in_thread_B_1[0] + A_1[threadIdx_y * 70 + threadIdx_x]
+        red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1 = T.allocate([1], "float32", "local")
+            mask_1 = T.allocate([1], "uint32", "local")
+            t0_1 = T.allocate([1], "float32", "local")
+            red_buf_staging = T.allocate([32], "float32", "shared")  # 2 groups x 16 warps
+            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            red_buf0_2[0] = in_thread_B_1[0]
+            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
+            mask_2[0] = T.tvm_warp_activemask()  # stage 1: full-warp reduction
+            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32)
+            red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
+            if threadIdx_x % 32 == 0:  # slot = threadIdx_y * 16 + warp index
+                red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
+            T.tvm_storage_sync("shared")  # partials visible before stage 2 reads them
+            if threadIdx_x < 32:  # all 32 lanes of the first warp load one partial each
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+            mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_3[0] = T.bitwise_and(  # NOTE(review): group y reduces lanes 16*y..16*y+15
+                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y))
+            )  # per the 16 * threadIdx_y broadcast; 65535 << y looks wrong, << (16 * y) intended?
+            t0_3 = T.Buffer((1,), data=t0, scope="local")
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]  # NOTE(review): lanes outside mask_3 also
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]  # run these shuffles (UB per CUDA docs)
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 16 * threadIdx_y, 32, 32)
+        if threadIdx_x == 0:  # one writer per group
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = red_buf0_3[0]
+
+
 if __name__ == "__main__":
     tvm.testing.main()