You are viewing a plain text version of this content. The canonical link for it is here.
Posted to by GitBox <> on 2021/11/14 09:12:58 UTC

[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

MasterJH5574 commented on a change in pull request #9360:

File path: src/tir/transforms/
@@ -156,145 +181,333 @@ class BufferAccessReplacer : public StmtExprMutator {
  * \brief Substitute a given source block with a given target block, or remove the source block
  * branch from the AST if the target block is undefined
-class ReductionBlockReplacer : public StmtMutator {
+class InThreadReducerMaker : private StmtMutator {
-  explicit ReductionBlockReplacer(const BlockRealizeNode* src_block, BlockRealize tgt_block)
-      : src_block_(src_block), tgt_block_(std::move(tgt_block)) {}
+  static Optional<Stmt> Make(const BlockRealizeNode* src_realize,
+                             Optional<BlockRealize> tgt_realize, Stmt stmt) {
+    return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt));
+  }
-  Stmt VisitStmt_(const BlockRealizeNode* block_realize) final {
-    return block_realize == src_block_ ? tgt_block_ : GetRef<BlockRealize>(block_realize);
+  explicit InThreadReducerMaker(const BlockRealizeNode* src_realize,
+                                Optional<BlockRealize> tgt_realize)
+      : src_realize_(src_realize), tgt_realize_(tgt_realize) {}
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    if (realize == src_realize_) {
+      return tgt_realize_.defined()  //
+                 ? tgt_realize_.value()
+                 : Stmt{nullptr};
+    }
+    return GetRef<BlockRealize>(realize);
   Stmt VisitStmt_(const ForNode* loop) final {
-    For res = Downcast<For>(StmtMutator::VisitStmt_(loop));
-    return !res.defined() ? Stmt{nullptr} : (res->thread_binding.defined() ? res->body : res);
+    if (Optional<For> opt_res = Downcast<Optional<For>>(StmtMutator::VisitStmt_(loop))) {
+      For res = opt_res.value();
+      if (res->thread_binding.defined()) {
+        return res->body;
+      } else {
+        return res;
+      }
+    } else {
+      return Stmt{nullptr};
+    }
   Stmt VisitStmt_(const SeqStmtNode* seq) final {
-    Array<Stmt> results;
-    results.reserve(seq->size());
+    Array<Stmt> stmts;
+    stmts.reserve(seq->size());
     for (const Stmt& stmt : seq->seq) {
-      Stmt res = StmtMutator::VisitStmt(stmt);
-      if (res.defined()) {
-        results.push_back(res);
+      if (Optional<Stmt> opt_res = VisitStmt(stmt)) {
+        stmts.push_back(opt_res.value());
-    return results.empty() ? Stmt{nullptr} : SeqStmt(results);
+    return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts);
-  const BlockRealizeNode* src_block_;
-  BlockRealize tgt_block_;
+  const BlockRealizeNode* src_realize_;
+  Optional<BlockRealize> tgt_realize_;
+ * \brief Create the lowered allreduce block transformed from the input reduction block
+ * \param reduction_block The input reduction block
+ * \param it_buffer The buffer to store in-thread reduction results
+ * \param ct_buffer The buffer to store cross-thread reduction results
+ * \param reducer The reduction function
+ * \param combiner_rhs The RHS of the combiner
+ * \param reduction_loops The reduction loops
+ */
+Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional<Buffer>& it_buffer,
+                             const Buffer& ct_buffer, const CommReducer& reducer,
+                             const PrimExpr& combiner_rhs,
+                             const std::vector<const ForNode*>& reduction_loops) {
+  const BlockNode* block = realize->block.get();
+  Buffer wb_buffer = block->writes[0]->buffer;
+  Array<Range> wb_region = block->writes[0]->region;
+  BufferRegion ct_buffer_region(ct_buffer, {Range::FromMinExtent(0, 1)});
+  Optional<BufferRegion> it_buffer_region = NullOpt;
+  if (it_buffer.defined()) {
+    it_buffer_region = BufferRegion(it_buffer.value(), {Range::FromMinExtent(0, 1)});
+  }
+  // In total, the block is transformed into at most 4 statements
+  // - Stmt 1: initialize the buffer for in-thread reduction
+  // - Stmt 2: do in-thread reduction
+  // - Stmt 3: do cross-thread reduction
+  // - Stmt 4: write cross-thread reduction result to the original buffer
+  Array<Stmt> stmts;
+  stmts.reserve(4);
+  // Stmt 1: initialize the buffer for in-thread reduction
+  if (it_buffer.defined()) {
+    BufferStore init = Downcast<BufferStore>(block->init);
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/{},
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{},
+              /*reads=*/{},
+              /*writes=*/{it_buffer_region.value()},
+              /*name_hint=*/block->name_hint + "_in_thread_init",
+              /*body=*/
+              BufferStore(/*buffer=*/it_buffer.value(),
+                          /*value=*/init->value,
+                          /*indices=*/{Integer(0)}))));
+  }
+  // Stmt 2: do in-thread reduction
+  {
+    Optional<BlockRealize> new_realize = NullOpt;
+    // If need to generate in-thread reduction,
+    // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize
+    // otherwise, directly remove given BlockRealize
+    if (it_buffer.defined()) {
+      ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
+      new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer);
+      new_block->reads.push_back(it_buffer_region.value());
+      new_block->writes = {it_buffer_region.value()};
+      new_block->name_hint = new_block->name_hint + "_in_thread";
+      new_block->body =
+          BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
+      new_block->init = NullOpt;
+      ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
+      n->block = Block(new_block);
+      new_realize = BlockRealize(n);
+    }
+    For loop = GetRef<For>(reduction_loops[0]);
+    if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
+      stmts.push_back(stmt.value());
+    }
+  }
+  // Stmt 3: do cross-thread reduction
+  {
+    // Step 3.1. Create the parameters to the intrinsic
+    Array<PrimExpr> parameters;
+    parameters.reserve(reduction_loops.size() + 4);
+    // 1-st argument: size
+    parameters.push_back(make_const(DataType::UInt(32), 1));
+    // 2-nd argument: source
+    if (it_buffer.defined()) {
+      parameters.push_back(BufferLoad(it_buffer.value(), {Integer(0)}));
+    } else {
+      parameters.push_back(combiner_rhs);
+    }
+    // 3-rd argument: predicate
+    parameters.push_back(const_true());
+    // 4-th argument: destination
+    parameters.push_back(ct_buffer->data);
+    // next arguments: all the reduction threads
+    for (const ForNode* reduction_loop : reduction_loops) {
+      if (reduction_loop->thread_binding.defined()) {
+        parameters.push_back(reduction_loop->loop_var);
+      }
+    }
+    // Step 3.2. Create the block and the block-realize.
+    Array<IterVar> iter_vars{nullptr};
+    Array<PrimExpr> bindings{nullptr};
+    Array<BufferRegion> reads{nullptr};
+    if (it_buffer.defined()) {
+      iter_vars = Array<IterVar>{};
+      bindings = Array<PrimExpr>{};
+      reads = {it_buffer_region.value()};
+    } else {
+      iter_vars = block->iter_vars;
+      bindings = realize->iter_values;
+      reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)};
+    }
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/std::move(iter_vars),
+              /*reads=*/std::move(reads),
+              /*writes=*/{ct_buffer_region},
+              /*name_hint=*/block->name_hint + "_cross_thread",
+              /*body=*/
+              AttrStmt(/*node=*/reducer,
+                       /*attr_key=*/tir::attr::reduce_scope,
+                       /*value=*/make_zero(DataType::Handle()),
+                       /*body=*/
+                       Evaluate(Call(/*dtype=*/DataType::Handle(),
+                                     /*op=*/tir::builtin::tvm_thread_allreduce(),
+                                     /*args=*/std::move(parameters)))))));
+  }
+  // Stmt 4: write cross-thread reduction result to the original buffer
+  {
+    ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+    int n_iter = static_cast<int>(block->iter_vars.size());
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> bindings;
+    Map<Var, PrimExpr> var_map;
+    iter_vars.reserve(n_iter);
+    bindings.reserve(n_iter);
+    for (int i = 0; i < n_iter; ++i) {
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != kCommReduce) {
+        IterVar new_iter_var{nullptr};
+        {
+          ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
+          ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
+          n->var = Var(v);
+          new_iter_var = IterVar(n);
+        }
+        iter_vars.push_back(new_iter_var);
+        bindings.push_back(binding);
+        var_map.Set(iter_var->var, new_iter_var->var);
+      }
+    }
+    BufferStore update = Downcast<BufferStore>(block->body);
+    update = Downcast<BufferStore>(Substitute(std::move(update), var_map));
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(
+            /*iter_vars=*/std::move(iter_vars),
+            /*reads=*/{std::move(ct_buffer_region)},
+            /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))},
+            /*name_hint=*/block->name_hint + "_write_back",
+            /*body=*/
+            BufferStore(/*buffer=*/wb_buffer,
+                        /*value=*/BufferLoad(ct_buffer, {Integer(0)}),
+                        /*indices=*/update->indices))));
+  }
+  // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx
+  Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
+  for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) {
+    const ForNode* loop = *rit;
+    if (loop->thread_binding.defined()) {
+      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
+      n->body = std::move(new_stmt);
+      new_stmt = For(n);
+    }
+  }
+  return new_stmt;
  * \brief Detect cross-thread reduction pattern and then transform
 class CrossThreadReductionTransformer : public StmtMutator {
   // Check if the input block needs cross-thread reduction.
-  bool NeedCrossThreadReduction(const BlockRealizeNode* block_realize) {
+  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* realize) {
     // Step 0. If the block is the root block, just return.
     if (block_stack_.empty()) {
-      return false;
+      return {};
     // Step 1. If the block is not a reduction block, cross-thread reduction is not needed.
-    if (!IsReductionBlock(GetRef<BlockRealize>(block_realize), loop_range_map_,
+    if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
                           GetRef<Block>(block_stack_.back()), &analyzer_)) {
-      return false;
+      return {};
     // Step 2. Collect all the vars that appear in the bindings of reduction block iters.
     std::unordered_set<const VarNode*> reduction_vars;
-    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(block_realize), nullptr, &reduction_vars);
+    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, &reduction_vars);
     // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters.
     // We call these loops "reduction-related".
     // Step 4. See whether at least one reduction-related loop is bound to thread axis in GPU - if
     // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to
     // thread axis, cross-thread reduction is not needed for the input block.
     bool need = false;
-    reduction_loops_.clear();
+    std::vector<const ForNode*> reduction_loops;
     for (const ForNode* loop : loop_stack_) {
       if (reduction_vars.count(loop->loop_var.get())) {
         // Step 3. Collect the loop.
-        reduction_loops_.push_back(loop);
+        reduction_loops.push_back(loop);
         // Step 4. See whether the loop is bound to some thread axis.
         if (loop->thread_binding.defined()) {
           need = true;
-    return need;
+    return need ? reduction_loops : std::vector<const ForNode*>{};
   // Given that the input block needs cross-thread reduction, check if cross-thread reduction can
   // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread
   // reduction).
-  void CheckCanApplyCrossThreadReduction(const BlockNode* block) {
-    const String& block_name = block->name_hint;
+  std::tuple<int, CommReducer, PrimExpr> CheckCanApplyCrossThreadReduction(
+      const BlockNode* block, const std::vector<const ForNode*>& reduction_loops) const {
     // Condition 1. The block being applied cross-thread reduction should write to single buffer.
-    int n_write_buffer = static_cast<int>(block->writes.size());
-    CHECK_EQ(n_write_buffer, 1) << "ValueError: Cross-thread reduction requires the block to only "
-                                   "write to single buffer. However, the block "
-                                << block_name << " writes to " << n_write_buffer << " buffer(s).";
+    CHECK_EQ(block->writes.size(), 1)
+        << "ValueError: Cross-thread reduction requires the block to only "
+           "write to single buffer. However, the block "
+        << block->name_hint << " writes to " << block->writes.size() << " buffer(s).";
     // Condition 2. All the reduction-related loops should be the deepest among all statements
     // outside the block (ignoring SeqStmt here).
     int n_deepest_reduction_loops = 0;
     for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) {
-      if ((*rit)->IsInstance<SeqStmtNode>()) {
-        // Skip SeqStmt.
-        continue;
-      }
-      if (std::find(reduction_loops_.begin(), reduction_loops_.end(),
-                    reinterpret_cast<const ForNode*>(*rit)) == reduction_loops_.end()) {
-        break;
+      const StmtNode* stmt = *rit;
+      if (stmt->IsInstance<ForNode>()) {
+        const ForNode* loop = static_cast<const ForNode*>(stmt);
+        if (std::find(reduction_loops.begin(), reduction_loops.end(), loop) ==
+            reduction_loops.end()) {
+          break;
+        }
+        ++n_deepest_reduction_loops;

Review comment:
       @junrushao1994 IMO this check should be applied to all kinds of Stmt except for SeqStmt. My point is that the reduction loops should be the deepest *Stmts*, not only the deepest loops. If we only apply the check to loops, then it might be possible that there is a block or an IfThenElse among the loops - which is not allowed. 
   It's a corner case, and I believe that most of the time we won't encounter the case.
   It seems that I forgot to add such cases to the tests 😅.

This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail:

For queries about this service, please contact Infrastructure at: