You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "MasterJH5574 (via GitHub)" <gi...@apache.org> on 2023/07/02 01:29:39 UTC

[GitHub] [tvm] MasterJH5574 opened a new pull request, #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

MasterJH5574 opened a new pull request, #15192:
URL: https://github.com/apache/tvm/pull/15192

   This PR enhances the LowerCrossThreadReduction pass with the thread-broadcasting block rewrite.
   
   Specifically, previously, whenever a TIR block had thread-broadcast behavior (i.e., there exists some thread var which is free for the block), we never inserted a predicate for the block, and therefore the final generated code has a race condition, which sometimes leads to wrong computation results.
   
   This PR enhances the pass by collecting thread var information during the transformation and rewriting each thread-broadcast TIR block with additional predicate clauses which bound the thread vars and effectively state "only execute the block when `thread_var == 0`". Therefore, the race condition in such blocks is resolved.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1250425042


##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -645,13 +707,66 @@ class CrossThreadReductionTransformer : public StmtMutator {
       it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
       new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
     }
-    // Step 5. Transform.
+    // Step 4. Transform.
     loop2new_stmt_[reduction_loops[0]] =
         TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                 reducer, combiner_rhs, reduction_loops);
-    // Step 6. Return an empty statement, because the transformation result will be inserted when
-    // returning to the first reduction-related loop.
-    return Stmt{nullptr};
+  }
+
+  Stmt MakeCrossThreadBroadcast(
+      const BlockRealizeNode* realize,
+      const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+    // Step 1. Generate loop var for each unbound thread.
+    // Update the block predicate with clauses of `thread_var == min`.
+    PrimExpr predicate = realize->predicate;
+    Array<Var> loop_vars;
+    loop_vars.reserve(unbound_thread2range.size());
+    for (auto [scope, range] : unbound_thread2range) {
+      std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+      Var loop_var("t" + dim_index, range->min->dtype);
+      loop_vars.push_back(loop_var);
+      predicate = (loop_var == range->min) && predicate;
+    }
+
+    // Step 2. Update the BlockRealize with the new predicate.
+    ObjectPtr<BlockRealizeNode> p_realize = make_object<BlockRealizeNode>(*realize);
+    p_realize->predicate = std::move(predicate);
+
+    // Step 3. Wrap the updated BlockRealize with the new loops.
+    Stmt body(p_realize);
+    for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+      std::string dim_index(1, static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+      body = For(
+          /*loop_var=*/loop_vars[i],                          //
+          /*min=*/unbound_thread2range[i].second->min,        //
+          /*extent=*/unbound_thread2range[i].second->extent,  //
+          /*kind=*/ForKind::kThreadBinding,                   //
+          /*body=*/body,                                      //
+          /*thread_binding=*/
+          IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
+                  "threadIdx." + dim_index));
+    }
+    return body;
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    // Part 1. Check if the block needs cross-thread reduction rewrite.
+    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
+    if (!reduction_loops.empty()) {
+      // Return an empty statement, because the transformation result will
+      // be inserted when returning to the first reduction-related loop.
+      MakeCrossThreadReduction(realize, reduction_loops);
+      return Stmt{nullptr};
+    }
+

Review Comment:
   Yep updated, as this pattern basically appears only when cross-thread reduction exists.



##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -478,6 +497,30 @@ class CrossThreadReductionTransformer : public StmtMutator {
     return need ? reduction_loops : std::vector<const ForNode*>{};
   }
 
+  // Check if the input block needs thread broadcast rewrite.
+  // One block needs broadcast rewrite when there exists one or more thread
+  // vars which are free variables to this block.
+  std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+      const BlockRealizeNode* realize) {
+    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range =
+        thread2range_;
+    for (const PrimExpr& iter_value : realize->iter_values) {

Review Comment:
   Thanks for the good catch!! Updated.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] junrushao merged pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "junrushao (via GitHub)" <gi...@apache.org>.
junrushao merged PR #15192:
URL: https://github.com/apache/tvm/pull/15192


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] tvm-bot commented on pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #15192:
URL: https://github.com/apache/tvm/pull/15192#issuecomment-1616257062

   <!---bot-comment-->
   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   <!--bot-comment-ccs-start-->
    * cc @Hzfengsy, @junrushao, @quic-sanirudh, @shingjan <sub>See [#10317](https://github.com/apache/tvm/issues/10317) for details</sub><!--bot-comment-ccs-end-->
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] tqchen commented on a diff in pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1249603596


##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -645,13 +707,66 @@ class CrossThreadReductionTransformer : public StmtMutator {
       it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
       new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
     }
-    // Step 5. Transform.
+    // Step 4. Transform.
     loop2new_stmt_[reduction_loops[0]] =
         TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                 reducer, combiner_rhs, reduction_loops);
-    // Step 6. Return an empty statement, because the transformation result will be inserted when
-    // returning to the first reduction-related loop.
-    return Stmt{nullptr};
+  }
+
+  Stmt MakeCrossThreadBroadcast(
+      const BlockRealizeNode* realize,
+      const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+    // Step 1. Generate loop var for each unbound thread.
+    // Update the block predicate with clauses of `thread_var == min`.
+    PrimExpr predicate = realize->predicate;
+    Array<Var> loop_vars;
+    loop_vars.reserve(unbound_thread2range.size());
+    for (auto [scope, range] : unbound_thread2range) {
+      std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+      Var loop_var("t" + dim_index, range->min->dtype);
+      loop_vars.push_back(loop_var);
+      predicate = (loop_var == range->min) && predicate;
+    }
+
+    // Step 2. Update the BlockRealize with the new predicate.
+    ObjectPtr<BlockRealizeNode> p_realize = make_object<BlockRealizeNode>(*realize);
+    p_realize->predicate = std::move(predicate);
+
+    // Step 3. Wrap the updated BlockRealize with the new loops.
+    Stmt body(p_realize);
+    for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+      std::string dim_index(1, static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+      body = For(
+          /*loop_var=*/loop_vars[i],                          //
+          /*min=*/unbound_thread2range[i].second->min,        //
+          /*extent=*/unbound_thread2range[i].second->extent,  //
+          /*kind=*/ForKind::kThreadBinding,                   //
+          /*body=*/body,                                      //
+          /*thread_binding=*/
+          IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
+                  "threadIdx." + dim_index));
+    }
+    return body;
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    // Part 1. Check if the block needs cross-thread reduction rewrite.
+    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
+    if (!reduction_loops.empty()) {
+      // Return an empty statement, because the transformation result will
+      // be inserted when returning to the first reduction-related loop.
+      MakeCrossThreadReduction(realize, reduction_loops);
+      return Stmt{nullptr};
+    }
+

Review Comment:
   Only run this check if we already have a cross-thread reduction; this will reduce the number of checks needed for other BlockRealize nodes.



##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -578,9 +621,31 @@ class CrossThreadReductionTransformer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* loop) final {
     loop_stack_.push_back(loop);
     loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+
+    // Collect loop-thread information:
+    // - when encountering a threadIdx loop, we keep note of its domain and
+    // the "loop var -> thread scope" relation, in order to collect all existing
+    // threads within a thread block.
+    // - we are careful about thread block boundary for safety.
+    int old_thread_block_depth = thread_block_depth_;
+    if (loop->kind == ForKind::kThreadBinding) {
+      ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+      if (scope.rank == 0 || !thread_block_depth_) {

Review Comment:
   Do clear immediately after VisitLoop, since that helps to remove prior states



##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -478,6 +497,30 @@ class CrossThreadReductionTransformer : public StmtMutator {
     return need ? reduction_loops : std::vector<const ForNode*>{};
   }
 
+  // Check if the input block needs thread broadcast rewrite.
+  // One block needs broadcast rewrite when there exists one or more thread
+  // vars which are free variables to this block.
+  std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+      const BlockRealizeNode* realize) {
+    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range =
+        thread2range_;
+    for (const PrimExpr& iter_value : realize->iter_values) {

Review Comment:
   Checking `iter_values` has a few disadvantages; for example, a unit loop could be unbound. Instead, check the surrounding `loop_stack_`, which is much more robust.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1250424365


##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -578,9 +621,31 @@ class CrossThreadReductionTransformer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* loop) final {
     loop_stack_.push_back(loop);
     loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+
+    // Collect loop-thread information:
+    // - when encountering a threadIdx loop, we keep note of its domain and
+    // the "loop var -> thread scope" relation, in order to collect all existing
+    // threads within a thread block.
+    // - we are careful about thread block boundary for safety.
+    int old_thread_block_depth = thread_block_depth_;
+    if (loop->kind == ForKind::kThreadBinding) {
+      ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+      if (scope.rank == 0 || !thread_block_depth_) {

Review Comment:
   Now changed the logic to
   
   > When exiting a loop,
   > * if it is a `blockIdx` loop, clear the map,
   > * if it is a `threadIdx` loop, clear the map when both `threadIdx` depth and `blockIdx` depth are 0.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] tqchen commented on pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on PR #15192:
URL: https://github.com/apache/tvm/pull/15192#issuecomment-1616657358

   @tvm-bot rerun


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1250367349


##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -645,13 +707,66 @@ class CrossThreadReductionTransformer : public StmtMutator {
       it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
       new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
     }
-    // Step 5. Transform.
+    // Step 4. Transform.
     loop2new_stmt_[reduction_loops[0]] =
         TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                 reducer, combiner_rhs, reduction_loops);
-    // Step 6. Return an empty statement, because the transformation result will be inserted when
-    // returning to the first reduction-related loop.
-    return Stmt{nullptr};
+  }
+
+  Stmt MakeCrossThreadBroadcast(
+      const BlockRealizeNode* realize,
+      const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+    // Step 1. Generate loop var for each unbound thread.
+    // Update the block predicate with clauses of `thread_var == min`.
+    PrimExpr predicate = realize->predicate;
+    Array<Var> loop_vars;
+    loop_vars.reserve(unbound_thread2range.size());
+    for (auto [scope, range] : unbound_thread2range) {
+      std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+      Var loop_var("t" + dim_index, range->min->dtype);
+      loop_vars.push_back(loop_var);
+      predicate = (loop_var == range->min) && predicate;
+    }
+
+    // Step 2. Update the BlockRealize with the new predicate.
+    ObjectPtr<BlockRealizeNode> p_realize = make_object<BlockRealizeNode>(*realize);
+    p_realize->predicate = std::move(predicate);
+
+    // Step 3. Wrap the updated BlockRealize with the new loops.
+    Stmt body(p_realize);
+    for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+      std::string dim_index(1, static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+      body = For(
+          /*loop_var=*/loop_vars[i],                          //

Review Comment:
   It forces a line break for the formatter, so that the formatter doesn't collapse the call into
   ```c++
   For(/*loop_var=*/loop_vars[i], /*min=*/...)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] yzh119 commented on a diff in pull request #15192: [TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite

Posted by "yzh119 (via GitHub)" <gi...@apache.org>.
yzh119 commented on code in PR #15192:
URL: https://github.com/apache/tvm/pull/15192#discussion_r1249538453


##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -645,13 +707,66 @@ class CrossThreadReductionTransformer : public StmtMutator {
       it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
       new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
     }
-    // Step 5. Transform.
+    // Step 4. Transform.
     loop2new_stmt_[reduction_loops[0]] =
         TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                 reducer, combiner_rhs, reduction_loops);
-    // Step 6. Return an empty statement, because the transformation result will be inserted when
-    // returning to the first reduction-related loop.
-    return Stmt{nullptr};
+  }
+
+  Stmt MakeCrossThreadBroadcast(
+      const BlockRealizeNode* realize,
+      const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+    // Step 1. Generate loop var for each unbound thread.
+    // Update the block predicate with clauses of `thread_var == min`.
+    PrimExpr predicate = realize->predicate;
+    Array<Var> loop_vars;
+    loop_vars.reserve(unbound_thread2range.size());
+    for (auto [scope, range] : unbound_thread2range) {
+      std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));
+      Var loop_var("t" + dim_index, range->min->dtype);
+      loop_vars.push_back(loop_var);
+      predicate = (loop_var == range->min) && predicate;
+    }
+
+    // Step 2. Update the BlockRealize with the new predicate.
+    ObjectPtr<BlockRealizeNode> p_realize = make_object<BlockRealizeNode>(*realize);
+    p_realize->predicate = std::move(predicate);
+
+    // Step 3. Wrap the updated BlockRealize with the new loops.
+    Stmt body(p_realize);
+    for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+      std::string dim_index(1, static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+      body = For(
+          /*loop_var=*/loop_vars[i],                          //

Review Comment:
   What's the purpose of the trailing `//`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org