Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/25 01:35:10 UTC

[GitHub] [tvm] MasterJH5574 opened a new pull request #9360: [TensorIR] Cross-Thread Reduction

MasterJH5574 opened a new pull request #9360:
URL: https://github.com/apache/tvm/pull/9360


   Hi community! This PR adds cross-thread reduction support for TensorIR. After this PR, cross-thread reduction patterns in TIR can be successfully lowered.
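   Background for readers new to the pattern: a cross-thread reduction is a reduction whose loop is (at least partly) bound to a GPU thread axis such as threadIdx.x, so each thread produces a partial result that must then be combined across threads. Going by `TransformReductionBlock` in the diff quoted later in this thread, the lowering emits up to four statements: initialize an in-thread buffer, reduce serially into it, combine across threads with `tvm_thread_allreduce` into a cross-thread buffer, and write the result back. In that diff the in-thread buffer is optional; when it is absent, the combiner's right-hand side is fed to the allreduce directly. The sketch below is a CPU-only simulation of that data flow for a sum, assuming the reduction axis splits evenly into `n_threads` x `n_inner`. The function name and the tree-shaped combine are illustrative stand-ins, not TVM code.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Illustrative only: simulates on the CPU the data flow of a sum whose reduction axis
// of length a.size() is split into n_threads (pretend threadIdx.x) x n_inner (serial).
float LowerSumCrossThread(const std::vector<float>& a, std::size_t n_threads) {
  assert(n_threads > 0 && a.size() % n_threads == 0);
  const std::size_t n_inner = a.size() / n_threads;

  // One "local" scalar per thread, playing the role of it_buffer.
  std::vector<float> it_buffer(n_threads);
  for (std::size_t tx = 0; tx < n_threads; ++tx) {
    it_buffer[tx] = 0.0f;  // Stmt 1: in-thread init (the block's `init` value)
    for (std::size_t k = 0; k < n_inner; ++k) {
      it_buffer[tx] += a[tx * n_inner + k];  // Stmt 2: serial in-thread reduction
    }
  }

  // Stmt 3: stand-in for tvm_thread_allreduce, a pairwise tree combine over the
  // per-thread partials. Works for any n_threads, not only powers of two.
  std::vector<float> vals = it_buffer;
  for (std::size_t stride = 1; stride < n_threads; stride *= 2) {
    for (std::size_t tx = 0; tx + stride < n_threads; tx += 2 * stride) {
      vals[tx] += vals[tx + stride];
    }
  }
  const float ct_buffer = vals[0];

  // Stmt 4: write the cross-thread result back to the original output.
  return ct_buffer;
}
```

   For example, summing 1..8 gives 36 regardless of whether 1, 2, 4, or 8 "threads" are used; only the split between serial in-thread work and cross-thread combining changes.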
   
   cc @Hzfengsy @vinx13 @comaniac @junrushao1994 @jcf94 @jinhongyii @spectrometerHBH  @tqchen 
   
   Co-authored-by: Wuwei Lin <wu...@apache.org>
   Co-authored-by: Junru Shao <ju...@gmail.com>
   Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
   Co-authored-by: Hongyi Jin <32...@qq.com>
   Co-authored-by: Bohan Hou <32...@users.noreply.github.com>


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 merged pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 merged pull request #9360:
URL: https://github.com/apache/tvm/pull/9360


   





[GitHub] [tvm] Hzfengsy commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#discussion_r735241475



##########
File path: src/tir/schedule/analysis.h
##########
@@ -323,6 +327,53 @@ struct ProducerConsumerSplit {
  */
 Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);
 
+/******** Reduction Block Related ********/
+
+/*!
+ * \brief Convert the `init` and `body` of the input block to BufferStores
+ * \tparam in_schedule Whether the function is called by schedule primitives
+ * \param self The schedule state
+ * \param block The block to be analyzed
+ * \return The BufferStores of the `init` and `body` of the input block
+ * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same
+ * buffer
+ */
+template <bool in_schedule>
+std::pair<BufferStore, BufferStore> GetBufferStoreNodes(const ScheduleState& self,

Review comment:
       Prefer name `GetBufferStoreFromReductionBlock`

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -552,6 +520,9 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
     } else {
       has_block_vars_of_other_types = true;
     }
+    if (set == nullptr) {

Review comment:
       Please add a regression test if it's a bug
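   Context for this request: `NeedCrossThreadReduction`, quoted later in this thread, calls `GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, &reduction_vars)`, passing nullptr for one of the two output sets, which is presumably the case the new `if (set == nullptr)` guard protects. A regression test would exercise exactly that call shape. The sketch below models the pattern with hypothetical toy types (`IterKind`, `ToyBlockIter`, `CollectVarsTouchedByIters`); it does not reproduce the actual TVM signature or its exact return convention.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical toy types; not TVM's IterVar/BlockRealize.
enum class IterKind { kDataPar, kCommReduce, kOther };

struct ToyBlockIter {
  IterKind kind;
  std::vector<std::string> touched_vars;  // loop vars appearing in this iter's binding
};

// Collects loop vars touched by data-parallel iters into *data_par_vars and by
// reduction iters into *reduce_vars. Either output may be nullptr when the caller
// does not need it. Returns false if some iter is neither kind.
bool CollectVarsTouchedByIters(const std::vector<ToyBlockIter>& iters,
                               std::unordered_set<std::string>* data_par_vars,
                               std::unordered_set<std::string>* reduce_vars) {
  bool has_other_types = false;
  for (const ToyBlockIter& iter : iters) {
    std::unordered_set<std::string>* set = nullptr;
    if (iter.kind == IterKind::kDataPar) {
      set = data_par_vars;
    } else if (iter.kind == IterKind::kCommReduce) {
      set = reduce_vars;
    } else {
      has_other_types = true;
    }
    // The guard under review: skip when the caller passed nullptr for this category
    // (or the iter is of another type) instead of dereferencing a null pointer.
    if (set == nullptr) {
      continue;
    }
    set->insert(iter.touched_vars.begin(), iter.touched_vars.end());
  }
  return !has_other_types;
}
```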

##########
File path: src/tir/transforms/lower_cross_thread_reduction.cc
##########
@@ -0,0 +1,590 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_cross_thread_reduction.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/analysis.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check the dominant property of a block:
+ * the block is the only writer of its output, dominating the reader of its output buffers
+ * \param scope_block The scope block of the block to be checked
+ * \param block The block whose dominant property is to be checked
+ * \return A boolean indicating if the block is a dominant block
+ */
+bool IsDominantBlock(const Block& scope_block, const Block& block) {
+  // Step 1. Count the number of writers for each buffer written by the scope block.
+  std::unordered_map<const BufferNode*, int> buffer_writer_cnt;
+  PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) {
+    if (const auto* block = obj.as<BlockNode>()) {
+      for (const BufferRegion& buffer_region : block->writes) {
+        ++buffer_writer_cnt[buffer_region->buffer.get()];
+      }
+      return false;
+    }
+    return true;
+  });
+  // Step 2. Check whether `block` is the only writer of its outputs.
+  for (const BufferRegion& buffer_region : block->writes) {
+    ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get()));
+    if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Check whether the input block is a reduction block.
+ * \param block_realize The block to be checked
+ * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
+ * \param scope_block The scope block of the input block
+ * \param analyzer The analyzer
+ * \return A boolean indicating whether the input block is a reduction block.
+ * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
+ * check again.
+ */
+bool IsReductionBlock(const BlockRealize& block_realize, const Map<Var, Range>& loop_range_map,
+                      const Block& scope_block, arith::Analyzer* analyzer) {
+  const auto* block = block_realize->block.as<BlockNode>();
+  // Cond 1. The block has the `init` statement.
+  if (!block->init.defined()) {
+    return false;
+  }
+  // Cond 2. All the block bindings are quasi-affine expressions.
+  if (!IsAffineBinding(block_realize, loop_range_map, analyzer)) {
+    return false;
+  }
+  // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile,
+  // we collect all the reduction block vars.
+  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+    return false;
+  }
+  // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
+  // output buffers.
+  if (!IsDominantBlock(scope_block, GetRef<Block>(block))) {
+    return false;
+  }
+  // Cond 5. The reduction block vars are not used to index the output buffers.
+  return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block));
+}
+
+/*!
+ * \brief Create an intermediate buffer with specified name and data type
+ * \param name The specified name
+ * \param dtype The specified data type
+ * \return The created buffer
+ */
+Buffer CreateReductionBuffer(String name, const DataType& dtype) {
+  Var var(name, PointerType(PrimType(dtype), "local"));
+  return Buffer(var, dtype, {1}, {1}, PrimExpr(), std::move(name), 0, 0, kDefault);
+}
+
+/*!
+ * \brief Remove the BufferRegions whose buffer is the input buffer
+ * \param buffer_regions The array of BufferRegions to be filtered
+ * \param buffer_to_remove The specified buffer
+ * \return The mutated array of BufferRegions, no longer containing BufferRegion of the input buffer
+ */
+Array<BufferRegion> RemoveBufferFromBufferRegions(const Array<BufferRegion>& buffer_regions,
+                                                  const Buffer& buffer_to_remove) {
+  Array<BufferRegion> res;
+  res.reserve(buffer_regions.size());
+  for (const BufferRegion& buffer_region : buffer_regions) {
+    if (!buffer_region->buffer.same_as(buffer_to_remove)) {
+      res.push_back(buffer_region);
+    }
+  }
+  return res;
+}
+
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in statements or expressions
+ */
+class BufferAccessReplacer : public StmtExprMutator {
+ public:
+  explicit BufferAccessReplacer(Buffer src_buffer, Buffer tgt_buffer)
+      : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {}
+
+ private:
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0})
+                                             : GetRef<BufferLoad>(load);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    if (store->buffer.same_as(src_buffer_)) {
+      PrimExpr value = StmtExprMutator::VisitExpr(store->value);
+      return BufferStore(tgt_buffer_, value, {0});
+    } else {
+      return StmtMutator::VisitStmt_(store);
+    }
+  }
+
+  Buffer src_buffer_;
+  Buffer tgt_buffer_;
+};
+
+/*!
+ * \brief Substitute a given source block with a given target block, or remove the source block
+ * branch from the AST if the target block is undefined
+ */
+class ReductionBlockReplacer : public StmtMutator {
+ public:
+  explicit ReductionBlockReplacer(const BlockRealizeNode* src_block, BlockRealize tgt_block)
+      : src_block_(src_block), tgt_block_(std::move(tgt_block)) {}
+
+ private:
+  Stmt VisitStmt_(const BlockRealizeNode* block_realize) final {
+    return block_realize == src_block_ ? tgt_block_ : GetRef<BlockRealize>(block_realize);
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    For res = Downcast<For>(StmtMutator::VisitStmt_(loop));
+    return !res.defined() ? Stmt{nullptr} : (res->thread_binding.defined() ? res->body : res);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> results;
+    results.reserve(seq->size());
+    for (Stmt stmt : seq->seq) {

Review comment:
       ```suggestion
       for (const Stmt &stmt : seq->seq) {
   ```
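   The dominance check in `IsDominantBlock` above works in two steps: count, across the scope, how many blocks write each buffer, then require that every buffer written by the candidate block has exactly one writer. Because the visitor returns `false` on reaching a block, nested blocks are not descended into, so only blocks directly under the scope are counted. A minimal sketch with hypothetical toy types (`ToyBlock`, `IsDominantToy`), assuming the scope's child blocks have already been flattened into a list:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical flattened view of a scope: each child block lists the buffers it writes.
struct ToyBlock {
  std::string name;
  std::vector<std::string> writes;
};

// A block is "dominant" if it is the only writer, within its scope, of every buffer it writes.
bool IsDominantToy(const std::vector<ToyBlock>& scope_children, const ToyBlock& block) {
  // Step 1: count the writers of each buffer across the scope.
  std::unordered_map<std::string, int> writer_cnt;
  for (const ToyBlock& b : scope_children) {
    for (const std::string& buf : b.writes) {
      ++writer_cnt[buf];
    }
  }
  // Step 2: every buffer written by `block` must have exactly one writer.
  for (const std::string& buf : block.writes) {
    auto it = writer_cnt.find(buf);
    assert(it != writer_cnt.end());  // `block` is expected to belong to the scope
    if (it->second != 1) {
      return false;
    }
  }
  return true;
}
```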







[GitHub] [tvm] junrushao1994 commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#discussion_r748664114



##########
File path: src/tir/transforms/lower_cross_thread_reduction.cc
##########
@@ -0,0 +1,590 @@
+bool IsReductionBlock(const BlockRealize& block_realize, const Map<Var, Range>& loop_range_map,

Review comment:
       We cannot, because the lowering pass doesn't have access to the cached flags in BlockScope, so they need to be recalculated.
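   To make the recalculation concrete: the `IsReductionBlock` hunk quoted earlier in this thread recomputes five conditions from the IR alone (an `init` is defined, bindings are quasi-affine, iters are only data-parallel or reduction, the block is dominant, and no reduction iter indexes the output buffer). The toy sketch below shows how they compose. The types are hypothetical, Cond 2 and Cond 4 are modeled as precomputed booleans for brevity, and Cond 3 holds by construction because each toy iter is either data-parallel or reduction.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical toy model of a block; not TVM's Block/BlockRealize.
struct ToyIterVar {
  std::string var;
  bool is_reduce;  // true: reduction iter; false: data-parallel iter
};

struct ToyReductionCandidate {
  bool has_init;                               // Cond 1: `init` is defined
  bool affine_bindings;                        // Cond 2: assumed precomputed
  std::vector<ToyIterVar> iters;               // Cond 3: holds by construction here
  bool dominant;                               // Cond 4: assumed precomputed
  std::vector<std::string> output_index_vars;  // vars used in the output buffer indices
};

bool IsReductionBlockToy(const ToyReductionCandidate& b) {
  if (!b.has_init || !b.affine_bindings || !b.dominant) {
    return false;
  }
  // Cond 5: no reduction iter var may be used to index the output buffer.
  std::unordered_set<std::string> reduce_vars;
  for (const ToyIterVar& iv : b.iters) {
    if (iv.is_reduce) reduce_vars.insert(iv.var);
  }
  for (const std::string& v : b.output_index_vars) {
    if (reduce_vars.count(v)) return false;
  }
  return true;
}
```

   For instance, a row sum `C[i] += A[i, k]` with `k` as the reduction iter passes, while a block that also writes `C[i, k]` fails Cond 5.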







[GitHub] [tvm] MasterJH5574 commented on pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#issuecomment-968257996


   @Hzfengsy Could you take another look? Junru's polishing looks very good, but I myself as the author cannot approve this PR 😅.





[GitHub] [tvm] junrushao1994 commented on pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#issuecomment-966877328


   Love the PR and very comprehensively tested implementation ❤️





[GitHub] [tvm] junrushao1994 commented on pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#issuecomment-966789162


   Finally got some time for a detailed code review! Will take over this PR and try to get it merged!





[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#discussion_r748825044



##########
File path: src/tir/transforms/lower_cross_thread_reduction.cc
##########
@@ -156,145 +181,333 @@ class BufferAccessReplacer : public StmtExprMutator {
  * \brief Substitute a given source block with a given target block, or remove the source block
  * branch from the AST if the target block is undefined
  */
-class ReductionBlockReplacer : public StmtMutator {
+class InThreadReducerMaker : private StmtMutator {
  public:
-  explicit ReductionBlockReplacer(const BlockRealizeNode* src_block, BlockRealize tgt_block)
-      : src_block_(src_block), tgt_block_(std::move(tgt_block)) {}
+  static Optional<Stmt> Make(const BlockRealizeNode* src_realize,
+                             Optional<BlockRealize> tgt_realize, Stmt stmt) {
+    return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt));
+  }
 
  private:
-  Stmt VisitStmt_(const BlockRealizeNode* block_realize) final {
-    return block_realize == src_block_ ? tgt_block_ : GetRef<BlockRealize>(block_realize);
+  explicit InThreadReducerMaker(const BlockRealizeNode* src_realize,
+                                Optional<BlockRealize> tgt_realize)
+      : src_realize_(src_realize), tgt_realize_(tgt_realize) {}
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    if (realize == src_realize_) {
+      return tgt_realize_.defined()  //
+                 ? tgt_realize_.value()
+                 : Stmt{nullptr};
+    }
+    return GetRef<BlockRealize>(realize);
   }
 
   Stmt VisitStmt_(const ForNode* loop) final {
-    For res = Downcast<For>(StmtMutator::VisitStmt_(loop));
-    return !res.defined() ? Stmt{nullptr} : (res->thread_binding.defined() ? res->body : res);
+    if (Optional<For> opt_res = Downcast<Optional<For>>(StmtMutator::VisitStmt_(loop))) {
+      For res = opt_res.value();
+      if (res->thread_binding.defined()) {
+        return res->body;
+      } else {
+        return res;
+      }
+    } else {
+      return Stmt{nullptr};
+    }
   }
 
   Stmt VisitStmt_(const SeqStmtNode* seq) final {
-    Array<Stmt> results;
-    results.reserve(seq->size());
+    Array<Stmt> stmts;
+    stmts.reserve(seq->size());
     for (const Stmt& stmt : seq->seq) {
-      Stmt res = StmtMutator::VisitStmt(stmt);
-      if (res.defined()) {
-        results.push_back(res);
+      if (Optional<Stmt> opt_res = VisitStmt(stmt)) {
+        stmts.push_back(opt_res.value());
       }
     }
-    return results.empty() ? Stmt{nullptr} : SeqStmt(results);
+    return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts);
   }
 
-  const BlockRealizeNode* src_block_;
-  BlockRealize tgt_block_;
+  const BlockRealizeNode* src_realize_;
+  Optional<BlockRealize> tgt_realize_;
 };
 
+/*!
+ * \brief Create the lowered allreduce block transformed from the input reduction block
+ * \param realize The BlockRealize of the input reduction block
+ * \param it_buffer The buffer to store in-thread reduction results
+ * \param ct_buffer The buffer to store cross-thread reduction results
+ * \param reducer The reduction function
+ * \param combiner_rhs The RHS of the combiner
+ * \param reduction_loops The reduction loops
+ */
+Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional<Buffer>& it_buffer,
+                             const Buffer& ct_buffer, const CommReducer& reducer,
+                             const PrimExpr& combiner_rhs,
+                             const std::vector<const ForNode*>& reduction_loops) {
+  const BlockNode* block = realize->block.get();
+  Buffer wb_buffer = block->writes[0]->buffer;
+  Array<Range> wb_region = block->writes[0]->region;
+
+  BufferRegion ct_buffer_region(ct_buffer, {Range::FromMinExtent(0, 1)});
+  Optional<BufferRegion> it_buffer_region = NullOpt;
+  if (it_buffer.defined()) {
+    it_buffer_region = BufferRegion(it_buffer.value(), {Range::FromMinExtent(0, 1)});
+  }
+  // In total, the block is transformed into at most 4 statements
+  // - Stmt 1: initialize the buffer for in-thread reduction
+  // - Stmt 2: do in-thread reduction
+  // - Stmt 3: do cross-thread reduction
+  // - Stmt 4: write cross-thread reduction result to the original buffer
+  Array<Stmt> stmts;
+  stmts.reserve(4);
+  // Stmt 1: initialize the buffer for in-thread reduction
+  if (it_buffer.defined()) {
+    BufferStore init = Downcast<BufferStore>(block->init);
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/{},
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{},
+              /*reads=*/{},
+              /*writes=*/{it_buffer_region.value()},
+              /*name_hint=*/block->name_hint + "_in_thread_init",
+              /*body=*/
+              BufferStore(/*buffer=*/it_buffer.value(),
+                          /*value=*/init->value,
+                          /*indices=*/{Integer(0)}))));
+  }
+  // Stmt 2: do in-thread reduction
+  {
+    Optional<BlockRealize> new_realize = NullOpt;
+    // If in-thread reduction needs to be generated,
+    // then replace `wb_buffer` with `it_buffer` in the given BlockRealize;
+    // otherwise, directly remove the given BlockRealize
+    if (it_buffer.defined()) {
+      ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
+      new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer);
+      new_block->reads.push_back(it_buffer_region.value());
+      new_block->writes = {it_buffer_region.value()};
+      new_block->name_hint = new_block->name_hint + "_in_thread";
+      new_block->body =
+          BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
+      new_block->init = NullOpt;
+      ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
+      n->block = Block(new_block);
+      new_realize = BlockRealize(n);
+    }
+    For loop = GetRef<For>(reduction_loops[0]);
+    if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
+      stmts.push_back(stmt.value());
+    }
+  }
+  // Stmt 3: do cross-thread reduction
+  {
+    // Step 3.1. Create the parameters to the intrinsic
+    Array<PrimExpr> parameters;
+    parameters.reserve(reduction_loops.size() + 4);
+    // 1st argument: size
+    parameters.push_back(make_const(DataType::UInt(32), 1));
+    // 2nd argument: source
+    if (it_buffer.defined()) {
+      parameters.push_back(BufferLoad(it_buffer.value(), {Integer(0)}));
+    } else {
+      parameters.push_back(combiner_rhs);
+    }
+    // 3rd argument: predicate
+    parameters.push_back(const_true());
+    // 4th argument: destination
+    parameters.push_back(ct_buffer->data);
+    // next arguments: all the reduction threads
+    for (const ForNode* reduction_loop : reduction_loops) {
+      if (reduction_loop->thread_binding.defined()) {
+        parameters.push_back(reduction_loop->loop_var);
+      }
+    }
+    // Step 3.2. Create the block and the block-realize.
+    Array<IterVar> iter_vars{nullptr};
+    Array<PrimExpr> bindings{nullptr};
+    Array<BufferRegion> reads{nullptr};
+    if (it_buffer.defined()) {
+      iter_vars = Array<IterVar>{};
+      bindings = Array<PrimExpr>{};
+      reads = {it_buffer_region.value()};
+    } else {
+      iter_vars = block->iter_vars;
+      bindings = realize->iter_values;
+      reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)};
+    }
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/std::move(iter_vars),
+              /*reads=*/std::move(reads),
+              /*writes=*/{ct_buffer_region},
+              /*name_hint=*/block->name_hint + "_cross_thread",
+              /*body=*/
+              AttrStmt(/*node=*/reducer,
+                       /*attr_key=*/tir::attr::reduce_scope,
+                       /*value=*/make_zero(DataType::Handle()),
+                       /*body=*/
+                       Evaluate(Call(/*dtype=*/DataType::Handle(),
+                                     /*op=*/tir::builtin::tvm_thread_allreduce(),
+                                     /*args=*/std::move(parameters)))))));
+  }
+  // Stmt 4: write cross-thread reduction result to the original buffer
+  {
+    ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+    int n_iter = static_cast<int>(block->iter_vars.size());
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> bindings;
+    Map<Var, PrimExpr> var_map;
+    iter_vars.reserve(n_iter);
+    bindings.reserve(n_iter);
+    for (int i = 0; i < n_iter; ++i) {
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != kCommReduce) {
+        IterVar new_iter_var{nullptr};
+        {
+          ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
+          ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
+          n->var = Var(v);
+          new_iter_var = IterVar(n);
+        }
+        iter_vars.push_back(new_iter_var);
+        bindings.push_back(binding);
+        var_map.Set(iter_var->var, new_iter_var->var);
+      }
+    }
+    BufferStore update = Downcast<BufferStore>(block->body);
+    update = Downcast<BufferStore>(Substitute(std::move(update), var_map));
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(
+            /*iter_vars=*/std::move(iter_vars),
+            /*reads=*/{std::move(ct_buffer_region)},
+            /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))},
+            /*name_hint=*/block->name_hint + "_write_back",
+            /*body=*/
+            BufferStore(/*buffer=*/wb_buffer,
+                        /*value=*/BufferLoad(ct_buffer, {Integer(0)}),
+                        /*indices=*/update->indices))));
+  }
+  // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx
+  Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
+  for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) {
+    const ForNode* loop = *rit;
+    if (loop->thread_binding.defined()) {
+      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
+      n->body = std::move(new_stmt);
+      new_stmt = For(n);
+    }
+  }
+  return new_stmt;
+}
+
 /*!
  * \brief Detect cross-thread reduction pattern and then transform
  */
 class CrossThreadReductionTransformer : public StmtMutator {
  private:
   // Check if the input block needs cross-thread reduction.
-  bool NeedCrossThreadReduction(const BlockRealizeNode* block_realize) {
+  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* realize) {
     // Step 0. If the block is the root block, just return.
     if (block_stack_.empty()) {
-      return false;
+      return {};
     }
 
     // Step 1. If the block is not a reduction block, cross-thread reduction is not needed.
-    if (!IsReductionBlock(GetRef<BlockRealize>(block_realize), loop_range_map_,
+    if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
                           GetRef<Block>(block_stack_.back()), &analyzer_)) {
-      return false;
+      return {};
     }
 
     // Step 2. Collect all the vars that appear in the bindings of reduction block iters.
     std::unordered_set<const VarNode*> reduction_vars;
-    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(block_realize), nullptr, &reduction_vars);
+    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, &reduction_vars);
 
     // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters.
     // We call these loops "reduction-related".
     // Step 4. See whether at least one reduction-related loop is bound to thread axis in GPU - if
     // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to
     // thread axis, cross-thread reduction is not needed for the input block.
     bool need = false;
-    reduction_loops_.clear();
+    std::vector<const ForNode*> reduction_loops;
     for (const ForNode* loop : loop_stack_) {
       if (reduction_vars.count(loop->loop_var.get())) {
         // Step 3. Collect the loop.
-        reduction_loops_.push_back(loop);
+        reduction_loops.push_back(loop);
         // Step 4. See whether the loop is bound to some thread axis.
         if (loop->thread_binding.defined()) {
           need = true;
         }
       }
     }
-
-    return need;
+    return need ? reduction_loops : std::vector<const ForNode*>{};
   }
 
   // Given that the input block needs cross-thread reduction, check if cross-thread reduction can
   // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread
   // reduction).
-  void CheckCanApplyCrossThreadReduction(const BlockNode* block) {
-    const String& block_name = block->name_hint;
-
+  std::tuple<int, CommReducer, PrimExpr> CheckCanApplyCrossThreadReduction(
+      const BlockNode* block, const std::vector<const ForNode*>& reduction_loops) const {
     // Condition 1. The block being applied cross-thread reduction should write to single buffer.
-    int n_write_buffer = static_cast<int>(block->writes.size());
-    CHECK_EQ(n_write_buffer, 1) << "ValueError: Cross-thread reduction requires the block to only "
-                                   "write to single buffer. However, the block "
-                                << block_name << " writes to " << n_write_buffer << " buffer(s).";
+    CHECK_EQ(block->writes.size(), 1)
+        << "ValueError: Cross-thread reduction requires the block to only "
+           "write to single buffer. However, the block "
+        << block->name_hint << " writes to " << block->writes.size() << " buffer(s).";
 
     // Condition 2. All the reduction-related loops should be the deepest among all statements
     // outside the block (ignoring SeqStmt here).
     int n_deepest_reduction_loops = 0;
     for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) {
-      if ((*rit)->IsInstance<SeqStmtNode>()) {
-        // Skip SeqStmt.
-        continue;
-      }
-      if (std::find(reduction_loops_.begin(), reduction_loops_.end(),
-                    reinterpret_cast<const ForNode*>(*rit)) == reduction_loops_.end()) {
-        break;
+      const StmtNode* stmt = *rit;
+      if (stmt->IsInstance<ForNode>()) {
+        const ForNode* loop = static_cast<const ForNode*>(stmt);
+        if (std::find(reduction_loops.begin(), reduction_loops.end(), loop) ==
+            reduction_loops.end()) {
+          break;
+        }
+        ++n_deepest_reduction_loops;

Review comment:
      @junrushao1994 IMO this check should be applied to all kinds of Stmt except SeqStmt. My point is that the reduction loops should be the deepest *Stmts*, not merely the deepest loops. If we only apply the check to loops, a block or an IfThenElse could sit among the reduction loops, which is not allowed.
   
   It's a corner case, and I believe we won't encounter it most of the time.
   
   It seems I forgot to add such cases to the tests 😅.
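A minimal sketch of the stricter check suggested above follows. It assumes hypothetical, simplified statement kinds (`StmtKind`, `Stmt`, `loop_id`) instead of TVM's real `StmtNode` classes. It also assumes the check passes only when every reduction loop is counted before the walk stops; the real pass may word this condition differently. The walk skips SeqStmt, but any other non-reduction statement, including a Block or an IfThenElse, ends the run of deepest reduction loops.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical, simplified statement kinds; these are NOT TVM's StmtNode classes.
enum class StmtKind { kFor, kSeqStmt, kBlock, kIfThenElse };

struct Stmt {
  StmtKind kind;
  int loop_id;  // meaningful only when kind == StmtKind::kFor
};

// `statement_stack` lists statements from outermost to innermost. Its last
// entry is the block itself, so the walk starts one step above it.
// SeqStmt is skipped. Any other statement must be a reduction loop until the
// first non-reduction statement is reached. Returns true iff every reduction
// loop was counted before that point, i.e. the reduction loops are the
// deepest statements above the block.
bool ReductionLoopsAreDeepest(const std::vector<Stmt>& statement_stack,
                              const std::vector<int>& reduction_loop_ids) {
  if (statement_stack.empty()) {
    return reduction_loop_ids.empty();
  }
  int n_deepest_reduction_loops = 0;
  for (auto rit = statement_stack.rbegin() + 1; rit != statement_stack.rend(); ++rit) {
    if (rit->kind == StmtKind::kSeqStmt) {
      continue;  // SeqStmt does not separate the block from its loops
    }
    // A Block, an IfThenElse, or a non-reduction loop ends the deepest run.
    if (rit->kind != StmtKind::kFor ||
        std::find(reduction_loop_ids.begin(), reduction_loop_ids.end(), rit->loop_id) ==
            reduction_loop_ids.end()) {
      break;
    }
    ++n_deepest_reduction_loops;
  }
  return n_deepest_reduction_loops == static_cast<int>(reduction_loop_ids.size());
}
```

In this sketch, an IfThenElse placed between two reduction loops stops the count early, so the check fails. A loops-only walk could let such a statement through, which is the corner case described above.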







[GitHub] [tvm] junrushao1994 commented on pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#issuecomment-968244975


   @Hzfengsy @MasterJH5574 Should be good to go. Please take another look :-) 





[GitHub] [tvm] junrushao1994 commented on pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#issuecomment-966876231


   Did a pass over analysis and misc changes





[GitHub] [tvm] junrushao1994 commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#discussion_r748818568



##########
File path: src/tir/transforms/lower_cross_thread_reduction.cc
##########
@@ -0,0 +1,590 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_cross_thread_reduction.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/analysis.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check the dominant property of a block:
+ * the block is the only writer of its output, dominating the reader of its output buffers
+ * \param scope_block The scope block of the block to be checked
+ * \param block The block whose dominant property is to be checked
+ * \return A boolean indicating if the block is a dominant block
+ */
+bool IsDominantBlock(const Block& scope_block, const Block& block) {
+  // Step 1. Count the number of writers for each buffer written by the scope block.
+  std::unordered_map<const BufferNode*, int> buffer_writer_cnt;
+  PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) {
+    if (const auto* block = obj.as<BlockNode>()) {
+      for (const BufferRegion& buffer_region : block->writes) {
+        ++buffer_writer_cnt[buffer_region->buffer.get()];
+      }
+      return false;
+    }
+    return true;
+  });
+  // Step 2. Check whether `block` is the only writer of its outputs.
+  for (const BufferRegion& buffer_region : block->writes) {
+    ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get()));
+    if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Check whether the input block is a reduction block.
+ * \param block_realize The block to be checked
+ * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
+ * \param scope_block The scope block of the input block
+ * \param analyzer The analyzer
+ * \return A boolean indicating whether the input block is a reduction block.
+ * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
+ * check again.
+ */
+bool IsReductionBlock(const BlockRealize& block_realize, const Map<Var, Range>& loop_range_map,
+                      const Block& scope_block, arith::Analyzer* analyzer) {
+  const auto* block = block_realize->block.as<BlockNode>();
+  // Cond 1. The block has the `init` statement.
+  if (!block->init.defined()) {
+    return false;
+  }
+  // Cond 2. All the block bindings are quasi-affine expressions.
+  if (!IsAffineBinding(block_realize, loop_range_map, analyzer)) {
+    return false;
+  }
+  // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile,
+  // we collect all the reduction block vars.
+  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+    return false;
+  }
+  // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
+  // output buffers.
+  if (!IsDominantBlock(scope_block, GetRef<Block>(block))) {
+    return false;
+  }
+  // Cond 5. The reduction block vars are not used to index the output buffers.
+  return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block));
+}
+
+/*!
+ * \brief Create an intermediate buffer with specified name and data type
+ * \param name The specified name
+ * \param dtype The specified data type
+ * \return The created buffer
+ */
+Buffer CreateReductionBuffer(String name, const DataType& dtype) {
+  Var var(name, PointerType(PrimType(dtype), "local"));
+  return Buffer(var, dtype, {1}, {1}, PrimExpr(), std::move(name), 0, 0, kDefault);

Review comment:
      I double-checked. It's a bit confusing, but it turns out to be correct: it constructs scratchpad memory for both in-thread and cross-thread reduction. I refactored it a bit and renamed this method for clarity.
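The two roles of that scratchpad can be illustrated with a sequential C++ simulation of a sum reduction. This is a conceptual sketch only. `TwoStageSum` is a hypothetical name, the "threads" are plain loop indices rather than GPU threads, and it is not the code the pass emits. Stage 1 is the in-thread reduction into one scratchpad slot per thread. Stage 2 is a cross-thread tree combination of those partial results.

```cpp
#include <cstddef>
#include <vector>

// Sequential simulation (not GPU code) of a sum reduction that uses a
// per-thread scratchpad for two stages. Assumes num_threads >= 1.
float TwoStageSum(const std::vector<float>& data, int num_threads) {
  const std::size_t n = static_cast<std::size_t>(num_threads);
  // One scratchpad slot per simulated thread, set to the sum's init value 0.
  std::vector<float> scratch(n, 0.0f);
  // Stage 1 (in-thread): thread t accumulates elements t, t + n, t + 2n, ...
  for (std::size_t t = 0; t < n; ++t) {
    for (std::size_t i = t; i < data.size(); i += n) {
      scratch[t] += data[i];
    }
  }
  // Stage 2 (cross-thread): pairwise tree combination of the partial sums.
  for (std::size_t stride = 1; stride < n; stride *= 2) {
    for (std::size_t t = 0; t + stride < n; t += 2 * stride) {
      scratch[t] += scratch[t + stride];
    }
  }
  return scratch[0];  // slot 0 holds the final result
}
```

On real hardware, stage 2 would typically use shared memory or warp-level primitives rather than a plain loop. The one-element-per-thread scratchpad is what both stages have in common.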







[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#discussion_r735752658



##########
File path: src/tir/schedule/analysis.h
##########
@@ -323,6 +327,53 @@ struct ProducerConsumerSplit {
  */
 Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);
 
+/******** Reduction Block Related ********/
+
+/*!
+ * \brief Convert the `init` and `body` of the input block to BufferStores
+ * \tparam in_schedule Whether the function is called by schedule primitives
+ * \param self The schedule state
+ * \param block The block to be analyzed
+ * \return The BufferStores of the `init` and `body` of the input block
+ * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same
+ * buffer
+ */
+template <bool in_schedule>
+std::pair<BufferStore, BufferStore> GetBufferStoreNodes(const ScheduleState& self,

Review comment:
       Thanks! It's indeed a better name.







[GitHub] [tvm] junrushao1994 commented on pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#issuecomment-966072853


   Will do the review tomorrow





[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#discussion_r735751050



##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -552,6 +520,9 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
     } else {
       has_block_vars_of_other_types = true;
     }
+    if (set == nullptr) {

Review comment:
      No, this is not a bug.
   
   Before this PR, we assumed that neither input pointer was `nullptr`. After this PR, the input set pointer `data_par_vars` can be `nullptr`, in which case we simply don't collect variables into `data_par_vars`.
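The optional-output pattern described above can be sketched in simplified, self-contained C++. Everything here is hypothetical: `IterKind`, `BlockIter`, and `CollectVarsTouchedByIters` are illustrative stand-ins, not TVM's `IterVarType` or the real `GetVarsTouchedByBlockIters`. Each iterator is reduced to a single variable name, whereas the real function visits binding expressions. The point is the `set == nullptr` guard: a caller that passes `nullptr` for `data_par_vars` just skips that collection.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical, simplified iterator kinds; NOT TVM's IterVarType.
enum class IterKind { kDataPar, kReduction, kOther };

struct BlockIter {
  IterKind kind;
  std::string var;  // name of the loop variable bound to this block iterator
};

// Collects the variables bound to data-parallel and reduction iterators.
// Either output set may be nullptr in this sketch, in which case those
// variables are not collected. Returns true if some iterator is neither
// data-parallel nor reduction.
bool CollectVarsTouchedByIters(const std::vector<BlockIter>& iters,
                               std::unordered_set<std::string>* data_par_vars,
                               std::unordered_set<std::string>* reduce_vars) {
  bool has_other_kinds = false;
  for (const BlockIter& iter : iters) {
    std::unordered_set<std::string>* set = nullptr;
    if (iter.kind == IterKind::kDataPar) {
      set = data_par_vars;
    } else if (iter.kind == IterKind::kReduction) {
      set = reduce_vars;
    } else {
      has_other_kinds = true;
    }
    if (set == nullptr) {
      continue;  // output not requested, or iterator of another kind
    }
    set->insert(iter.var);
  }
  return has_other_kinds;
}
```

The same null check also tolerates a null `reduce_vars` here, although the comment above only describes `data_par_vars` as optional.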







[GitHub] [tvm] Hzfengsy commented on a change in pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#discussion_r747369463



##########
File path: src/tir/transforms/lower_cross_thread_reduction.cc
##########
@@ -0,0 +1,590 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_cross_thread_reduction.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/analysis.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check the dominant property of a block:
+ * the block is the only writer of its output, dominating the reader of its output buffers
+ * \param scope_block The scope block of the block to be checked
+ * \param block The block whose dominant property is to be checked
+ * \return A boolean indicating if the block is a dominant block
+ */
+bool IsDominantBlock(const Block& scope_block, const Block& block) {
+  // Step 1. Count the number of writers for each buffer written by the scope block.
+  std::unordered_map<const BufferNode*, int> buffer_writer_cnt;
+  PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) {
+    if (const auto* block = obj.as<BlockNode>()) {
+      for (const BufferRegion& buffer_region : block->writes) {
+        ++buffer_writer_cnt[buffer_region->buffer.get()];
+      }
+      return false;
+    }
+    return true;
+  });
+  // Step 2. Check whether `block` is the only writer of its outputs.
+  for (const BufferRegion& buffer_region : block->writes) {
+    ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get()));
+    if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Check whether the input block is a reduction block.
+ * \param block_realize The block to be checked
+ * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
+ * \param scope_block The scope block of the input block
+ * \param analyzer The analyzer
+ * \return A boolean indicating whether the input block is a reduction block.
+ * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
+ * check again.
+ */
+bool IsReductionBlock(const BlockRealize& block_realize, const Map<Var, Range>& loop_range_map,

Review comment:
       Can we reuse the current API?

##########
File path: src/tir/transforms/lower_cross_thread_reduction.cc
##########
@@ -0,0 +1,590 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_cross_thread_reduction.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/analysis.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check the dominant property of a block:
+ * the block is the only writer of its output, dominating the reader of its output buffers
+ * \param scope_block The scope block of the block to be checked
+ * \param block The block whose dominant property is to be checked
+ * \return A boolean indicating if the block is a dominant block
+ */
+bool IsDominantBlock(const Block& scope_block, const Block& block) {
+  // Step 1. Count the number of writers for each buffer written by the scope block.
+  std::unordered_map<const BufferNode*, int> buffer_writer_cnt;
+  PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) {
+    if (const auto* block = obj.as<BlockNode>()) {
+      for (const BufferRegion& buffer_region : block->writes) {
+        ++buffer_writer_cnt[buffer_region->buffer.get()];
+      }
+      return false;
+    }
+    return true;
+  });
+  // Step 2. Check whether `block` is the only writer of its outputs.
+  for (const BufferRegion& buffer_region : block->writes) {
+    ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get()));
+    if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Check whether the input block is a reduction block.
+ * \param block_realize The block to be checked
+ * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
+ * \param scope_block The scope block of the input block
+ * \param analyzer The analyzer
+ * \return A boolean indicating whether the input block is a reduction block.
+ * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
+ * check again.
+ */
+bool IsReductionBlock(const BlockRealize& block_realize, const Map<Var, Range>& loop_range_map,
+                      const Block& scope_block, arith::Analyzer* analyzer) {
+  const auto* block = block_realize->block.as<BlockNode>();
+  // Cond 1. The block has the `init` statement.
+  if (!block->init.defined()) {
+    return false;
+  }
+  // Cond 2. All the block bindings are quasi-affine expressions.
+  if (!IsAffineBinding(block_realize, loop_range_map, analyzer)) {
+    return false;
+  }
+  // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile,
+  // we collect all the reduction block vars.
+  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+    return false;
+  }
+  // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
+  // output buffers.
+  if (!IsDominantBlock(scope_block, GetRef<Block>(block))) {
+    return false;
+  }
+  // Cond 5. The reduction block vars are not used to index the output buffers.
+  return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block));
+}
+
+/*!
+ * \brief Create an intermediate buffer with specified name and data type
+ * \param name The specified name
+ * \param dtype The specified data type
+ * \return The created buffer
+ */
+Buffer CreateReductionBuffer(String name, const DataType& dtype) {
+  Var var(name, PointerType(PrimType(dtype), "local"));
+  return Buffer(var, dtype, {1}, {1}, PrimExpr(), std::move(name), 0, 0, kDefault);

Review comment:
       It seems that the shape and other attrs are incorrect







[GitHub] [tvm] junrushao1994 commented on pull request #9360: [TensorIR] Cross-Thread Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9360:
URL: https://github.com/apache/tvm/pull/9360#issuecomment-952547274


   I will do another round of review next week!

