Posted to commits@tvm.apache.org by "MasterJH5574 (via GitHub)" <gi...@apache.org> on 2023/07/15 19:06:05 UTC

[GitHub] [tvm] MasterJH5574 opened a new pull request, #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

MasterJH5574 opened a new pull request, #15327:
URL: https://github.com/apache/tvm/pull/15327

   This PR enhances the implementation of the LowerThreadAllreduce pass.
   
   Prior to this PR, for the CUDA backend we leveraged warp-level primitives only when
   * the reducing threads form a sub-warp (i.e., size 16, 8, 4, 2), or
   * the number of reducing threads is less than 32 and equals the reduction extent.
   
   Under these requirements, reductions with a large number of reducing threads (e.g., reducing over 128, 256, or more threads) cannot use warp-level primitives, and the generated code is inefficient.
   
   This PR improves the LowerThreadAllreduce pass so that, when the number of reducing threads is a multiple of the warp size, we now generate more efficient CUDA code with the help of warp-level primitives.
   
   Specifically, in such cases, we first reduce 32 elements within each warp and store the result of each warp in shared memory. We then trigger a second round of warp-level primitive reduction within the first warp to get the final reduction result.
   
   In addition to using warp-level primitives, this approach also reduces the amount of shared memory required. For example, even when reducing over 1024 threads, we now only require shared memory of size 32, compared with 1024 prior to this PR.
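
The two-stage scheme described above can be sketched on the host as follows. This is only an illustrative simulation, not the code the pass generates: `WarpReduceSum`, `StagingSlots` and `MultiWarpReduceSum` are hypothetical names, the warp size is assumed to be 32, a plain vector stands in for shared memory, and padding the unused staging slots with the identity element is a simplification made for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed warp size (32 on CUDA GPUs).
constexpr std::size_t kWarpSize = 32;

// Host-side stand-in for a shuffle-down tree reduction within one warp.
// Lanes are visited in ascending order, so every read of lanes[lane + offset]
// sees the value from the previous round, as a real warp shuffle would.
int WarpReduceSum(std::vector<int> lanes) {
  assert(lanes.size() == kWarpSize);
  for (std::size_t offset = kWarpSize / 2; offset > 0; offset /= 2) {
    for (std::size_t lane = 0; lane + offset < kWarpSize; ++lane) {
      lanes[lane] += lanes[lane + offset];
    }
  }
  return lanes[0];  // lane 0 now holds the sum of all 32 lanes
}

// Number of staging slots (the "shared memory" size) for n reducing threads.
std::size_t StagingSlots(std::size_t n) { return n / kWarpSize; }

// Two-stage sum over n values, one value per simulated thread.
int MultiWarpReduceSum(const std::vector<int>& values) {
  const std::size_t n = values.size();
  assert(n > 0 && n % kWarpSize == 0 && n <= kWarpSize * kWarpSize);

  // Stage 1: each warp reduces its own 32 values into one staging slot.
  // Unused slots keep the identity element (0 for a sum).
  std::vector<int> staging(kWarpSize, 0);
  for (std::size_t warp = 0; warp < StagingSlots(n); ++warp) {
    const auto first =
        values.begin() + static_cast<std::ptrdiff_t>(warp * kWarpSize);
    const auto last = first + static_cast<std::ptrdiff_t>(kWarpSize);
    staging[warp] = WarpReduceSum(std::vector<int>(first, last));
  }

  // Stage 2: the first warp reduces the per-warp partial results.
  return WarpReduceSum(staging);
}
```

With 1024 simulated threads, `StagingSlots(1024)` is 32, which matches the shared memory size quoted above; the upper bound of 32 * 32 threads in the assertion is what lets a single warp finish the second stage.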
   
   Tests are added to ensure correctness.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #15327:
URL: https://github.com/apache/tvm/pull/15327#discussion_r1264621628


##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
-        if (warp_size_ % reduce_extent == 0) {
-          return true;  // warp size is multiple of reduce extent
+        if (warp_size_ % reduce_extent == 0 ||
+            (max_num_threads_ != -1 && max_num_threads_ <= warp_size_ * warp_size_ &&

Review Comment:
   To keep the behavior for `-1` consistent with the behavior prior to this PR, I think it’s fine not to throw.
   
   If an unknown number of threads is not allowed, it must be reported somewhere else.





[GitHub] [tvm] tvm-bot commented on pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #15327:
URL: https://github.com/apache/tvm/pull/15327#issuecomment-1636858984

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
    * cc @Hzfengsy, @junrushao, @quic-sanirudh, @shingjan (see [#10317](https://github.com/apache/tvm/issues/10317) for details)
   
   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)




[GitHub] [tvm] tqchen commented on pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on PR #15327:
URL: https://github.com/apache/tvm/pull/15327#issuecomment-1636871360

   cc @masahi 




[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #15327:
URL: https://github.com/apache/tvm/pull/15327#discussion_r1264621347


##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
-        if (warp_size_ % reduce_extent == 0) {
-          return true;  // warp size is multiple of reduce extent
+        if (warp_size_ % reduce_extent == 0 ||

Review Comment:
   Great suggestion. Added.





[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #15327:
URL: https://github.com/apache/tvm/pull/15327#discussion_r1264621385


##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //

Review Comment:
   Done.





[GitHub] [tvm] MrJungle1 commented on pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "MrJungle1 (via GitHub)" <gi...@apache.org>.
MrJungle1 commented on PR #15327:
URL: https://github.com/apache/tvm/pull/15327#issuecomment-1639235072

   @MasterJH5574 LGTM! I also encountered the same problem when I searched for reduce_sum on Ansor. Does your work also apply to Ansor?




[GitHub] [tvm] tqchen merged pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen merged PR #15327:
URL: https://github.com/apache/tvm/pull/15327




[GitHub] [tvm] yzh119 commented on a diff in pull request #15327: [TIR] ThreadAllreduce warp-level primitive support with multi-warp

Posted by "yzh119 (via GitHub)" <gi...@apache.org>.
yzh119 commented on code in PR #15327:
URL: https://github.com/apache/tvm/pull/15327#discussion_r1264589363


##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //
-    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
-      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
-      //
-      // This is the index to the reduction variable, one reduction
-      // variable per warp. Local scope seems easier to reason without
-      // relying on a pattern match pass to fix it later.
-      Array<PrimExpr> zero_indices = {0};
-
-      for (size_t idx = 0; idx < size; ++idx) {
-        Array<PrimExpr> shape = {1};
-
-        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
-        Var buffer_var = buffer->data;
-
-        shared_buffer_vars[idx] = buffer_var;
-        shared_bufs[idx] = buffer;
-
-        PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
-
-        // Uses a local variable to store the shuffled data.  Later
-        // on, an allocation will be built for this local variable.
-        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
-      }
-
-      // The mask for this reducer, as this reducer may sit inside
-      // a divergent control flow. Here it uses a variable to cache the current
-      // active channels.
-      //
+    if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
+      std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
-      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
-      {
-        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
-        if (group_extent > 1) {
-          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
-                         << (reduce_extent * cast(mask_dtype, group_index)));
+      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+      if (reduce_extent <= warp_size_) {

Review Comment:
   Enter single/sub warp reduction branch when `reduce_extent` is less than or equal to `warp_size_`



##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
-        if (warp_size_ % reduce_extent == 0) {
-          return true;  // warp size is multiple of reduce extent
+        if (warp_size_ % reduce_extent == 0 ||

Review Comment:
   Create some bool variables to make this logic clearer, e.g.:
   
   `is_subwarp_reduction = warp_size_ % reduce_extent == 0`
   `is_multiwarp_reduction = max_num_threads_ != -1 && max_num_threads_ <= warp_size_ * warp_size_ && reduce_extent % warp_size_ == 0`
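
Spelled out as a standalone function, the suggestion might look like the sketch below. It assumes a hypothetical free function `IsWarpReductionExtent` that covers only the extent check visible in this hunk; the real `IsWarpReduction` member takes other arguments (types, group extent, contiguous reduce extent) that are not modeled here.

```cpp
#include <cassert>

// Hypothetical standalone version of the extent check from the diff.
// max_num_threads == -1 means the thread limit is unknown.
bool IsWarpReductionExtent(int reduce_extent, int warp_size, int max_num_threads) {
  assert(reduce_extent > 0 && warp_size > 0);
  if (reduce_extent == 1) {
    return false;  // no need to warp reduce
  }
  // The reducing threads evenly divide a warp (e.g., 16, 8, 4, 2, or a full warp).
  const bool is_subwarp_reduction = warp_size % reduce_extent == 0;
  // The reducing threads span whole warps, and the per-warp partial results
  // fit in a single warp for the second stage.
  const bool is_multiwarp_reduction = max_num_threads != -1 &&
                                      max_num_threads <= warp_size * warp_size &&
                                      reduce_extent % warp_size == 0;
  return is_subwarp_reduction || is_multiwarp_reduction;
}
```

For example, with a warp size of 32 and a limit of 1024 threads, an extent of 256 qualifies as a multi-warp reduction, while an extent of 48 satisfies neither condition.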



##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
-        if (warp_size_ % reduce_extent == 0) {
-          return true;  // warp size is multiple of reduce extent
+        if (warp_size_ % reduce_extent == 0 ||
+            (max_num_threads_ != -1 && max_num_threads_ <= warp_size_ * warp_size_ &&

Review Comment:
   Shall we throw an error if `max_num_threads_ == -1`?



##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //
-    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
-      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
-      //
-      // This is the index to the reduction variable, one reduction
-      // variable per warp. Local scope seems easier to reason without
-      // relying on a pattern match pass to fix it later.
-      Array<PrimExpr> zero_indices = {0};
-
-      for (size_t idx = 0; idx < size; ++idx) {
-        Array<PrimExpr> shape = {1};
-
-        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
-        Var buffer_var = buffer->data;
-
-        shared_buffer_vars[idx] = buffer_var;
-        shared_bufs[idx] = buffer;
-
-        PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
-
-        // Uses a local variable to store the shuffled data.  Later
-        // on, an allocation will be built for this local variable.
-        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
-      }
-
-      // The mask for this reducer, as this reducer may sit inside
-      // a divergent control flow. Here it uses a variable to cache the current
-      // active channels.
-      //
+    if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
+      std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
-      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
-      {
-        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
-        if (group_extent > 1) {
-          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
-                         << (reduce_extent * cast(mask_dtype, group_index)));
+      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+      if (reduce_extent <= warp_size_) {
+        if (group_extent > 1 && reduce_extent < warp_size_) {

Review Comment:
   If `reduce_extent` equals warp size we will skip mask here (0xFFF...)
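
As a rough illustration of the mask logic under discussion, the sketch below computes the per-group mask on the host. `SubWarpGroupMask` is a hypothetical helper, the warp size is assumed to be 32, and each group is assumed to fit inside one warp. One likely reason for skipping the full-warp case is that shifting a 32-bit value by 32 bits is undefined behavior in C++, while the active mask already covers every lane.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: restrict the active mask to the lanes of one
// reduction group of `reduce_extent` contiguous lanes.
std::uint32_t SubWarpGroupMask(std::uint32_t active_mask, int reduce_extent,
                               int group_extent, int group_index) {
  constexpr int kWarpSize = 32;  // assumed warp size
  assert(reduce_extent >= 1 && reduce_extent <= kWarpSize);
  assert(group_index >= 0 && group_index < group_extent);
  if (group_extent > 1 && reduce_extent < kWarpSize) {
    const int shift = reduce_extent * group_index;
    assert(shift + reduce_extent <= kWarpSize);  // group must fit in one warp
    // reduce_extent < 32 here, so the shift below is well defined.
    const std::uint32_t group_bits = (std::uint32_t{1} << reduce_extent) - 1;
    return active_mask & (group_bits << shift);
  }
  // Single group or full-warp reduction: keep the active mask unchanged.
  return active_mask;
}
```

For two groups of 16 lanes, group 1 gets the upper half of the warp; with a full-warp extent the active mask is returned as is.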



##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -676,8 +747,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
-        if (warp_size_ % reduce_extent == 0) {
-          return true;  // warp size is multiple of reduce extent
+        if (warp_size_ % reduce_extent == 0 ||
+            (max_num_threads_ != -1 && max_num_threads_ <= warp_size_ * warp_size_ &&
+             reduce_extent % warp_size_ == 0)) {
+          return true;  // warp size is multiple or factor of reduce extent

Review Comment:
   Please update the comment here.



##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //
-    if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
-      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
-      //
-      // This is the index to the reduction variable, one reduction
-      // variable per warp. Local scope seems easier to reason without
-      // relying on a pattern match pass to fix it later.
-      Array<PrimExpr> zero_indices = {0};
-
-      for (size_t idx = 0; idx < size; ++idx) {
-        Array<PrimExpr> shape = {1};
-
-        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
-        Var buffer_var = buffer->data;
-
-        shared_buffer_vars[idx] = buffer_var;
-        shared_bufs[idx] = buffer;
-
-        PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
-
-        // Uses a local variable to store the shuffled data.  Later
-        // on, an allocation will be built for this local variable.
-        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
-      }
-
-      // The mask for this reducer, as this reducer may sit inside
-      // a divergent control flow. Here it uses a variable to cache the current
-      // active channels.
-      //
+    if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
+      std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
-      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
-      {
-        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
-        if (group_extent > 1) {
-          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
-                         << (reduce_extent * cast(mask_dtype, group_index)));
+      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+      if (reduce_extent <= warp_size_) {
+        if (group_extent > 1 && reduce_extent < warp_size_) {
+          mask = mask &
+                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
         }
-        seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
-        // Push the buffer description.  Later this will have an
-        // allocation built for it.
-        local_bufs.push_back(mask_buffer);
-      }
+        std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
+      } else {

Review Comment:
   Enter multiwarp reduction branch otherwise.



##########
src/tir/transforms/lower_thread_allreduce.cc:
##########
@@ -299,131 +298,75 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // broadcast results from lane 0 to all other lanes and store
     // the final reduction result to the proper location.
     //

Review Comment:
   Please explain the two-stage logic in the comments as we did previously.


