You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/18 03:00:39 UTC

[GitHub] [tvm] spectrometerHBH opened a new pull request #9041: [TensorIR][M2a] Decompose-Reduction

spectrometerHBH opened a new pull request #9041:
URL: https://github.com/apache/tvm/pull/9041


   This PR is part of the TensorIR upstreaming effort (#7527), which adds the following schedule primitive:
   
   decompose_reduction
   Co-authored-by: Junru Shao <junrushao1994@gmail.com>
   Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
   Co-authored-by: Hongyi Jin <3231950289@qq.com>
   Co-authored-by: Wuwei Lin <wuwei@apache.org>
   Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] xqdan commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
xqdan commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713509277



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = loops.size() - 1; i >= 0; --i) {

Review comment:
       size_t i




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717223253



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;

Review comment:
       nit: don't set the predicate here, because we will mutate it later on line 271




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-928829921


   > @spectrometerHBH Sorry for the delay! I think I've finally completed the code review. Let me know if it makes sense to you
   
   Thanks! I might be slow to respond in the next 2 days


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r712003001



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}

Review comment:
       Probably we can use `std::find` to replace this function.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-928828584


   @spectrometerHBH Sorry for the delay! I think I've finally completed the code review. Let me know if it makes sense to you


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r720405611



##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -142,6 +142,8 @@ class ScheduleStateNode : public Object {
   /******** Property of blocks ********/
   /*! \brief Returns the BlockInfo correpsonding to the block sref */
   TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+  /*! \brief Recalculate the BlockInfo recursively under stmt */
+  TVM_DLL void UpdateSubtreeBlockInfo(const Stmt& stmt);

Review comment:
       The function just updates the BlockInfo, so I think it is always valid to call it.
   If we want to enforce a constraint, it should be that all loops/blocks in the subtree must have srefs.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716921095



##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -142,6 +142,8 @@ class ScheduleStateNode : public Object {
   /******** Property of blocks ********/
   /*! \brief Returns the BlockInfo correpsonding to the block sref */
   TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+  /*! \brief Recalculate the BlockInfo recursively under stmt */
+  TVM_DLL void UpdateSubtreeBlockInfo(const Stmt& stmt);

Review comment:
       Are there any additional constraints on this API? For example:
   - `stmt` must be a `Block`?
   - All loops/blocks in the subtree must not have any sref?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713628506



##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -142,6 +142,8 @@ class ScheduleStateNode : public Object {
   /******** Property of blocks ********/
   /*! \brief Returns the BlockInfo correpsonding to the block sref */
   TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+  /*! \brief Recalculate the BlockInfo recursively under stmt */
+  TVM_DLL void UpdateBlockInfo(const Stmt& stmt);

Review comment:
       Are we updating the BlockInfo in the entire subtree? If so, let's find a better name for this method. Candidate: UpdateSubtreeBlockInfo




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r712627111



##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -482,7 +482,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
   Array<PrimExpr> substitute_value;
   substitute_value.resize(loops.size());
   PrimExpr tot = fused_var;
-  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
+  for (int i = loops.size() - 1; i >= 0; i--) {

Review comment:
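       We shouldn't make this change: `loops.size()` is unsigned, so when `loops` is empty, `loops.size() - 1` wraps around to a huge value instead of -1, and converting that back to `int` is implementation-defined.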




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] xqdan commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
xqdan commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713741028



##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -482,7 +482,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
   Array<PrimExpr> substitute_value;
   substitute_value.resize(loops.size());
   PrimExpr tot = fused_var;
-  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
+  for (int i = loops.size() - 1; i >= 0; i--) {

Review comment:
       You are right, I didn't notice that `i` starts from `loops.size() - 1`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716306841



##########
File path: src/tir/schedule/primitive.h
##########
@@ -224,6 +224,18 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
  */
 TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
 /******** Schedule: Reduction ********/
+/*!
+ * \brief Decompose a reduction block into init block and update block, where the newly generated
+   init block will be before the specified loop.
+   1) The block is a reduction block.
+   2) The loop is the ancestor of the block.
+   3) The loop is not lower than all the loops related to reduce block var.
+ * \param block_rv The reduction block to be decomposed

Review comment:
       Let's copy the documentation from schedule.h and paste it here.

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);

Review comment:
       By doing this, we can remove line 217 above, since `init_realize->predicate` isn't modified between lines 217 and 271.
   https://github.com/apache/tvm/blob/ee6e57dc7788918a7380ac3fa36a52b0f549df50/src/tir/schedule/primitive/reduction.cc#L217
   ```suggestion
     init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops);
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717224898



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };

Review comment:
       nit:
   
   ```suggestion
     auto f = [&](const VarNode* var) { return discarded_loops.count(var); };
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717006103



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),

Review comment:
       Do we want to rename these block vars, or is simply reusing the same name good enough?
    
   ```suggestion
                            /*var=*/iter_var->var.copy_with_suffix(""),
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717173337



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper mutator that rewrites the scope of a reduction block being decomposed:
+ *        it inserts the decomposed init body right before the target loop, and replaces the
+ *        reduction block with an "_update" block that no longer carries an init statement.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  /*! \brief Mutate the loop, then prepend the decomposed init body if it is the target loop. */
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  /*!
+   * \brief Turn the reduction block into its "_update" counterpart by dropping its init.
+   *        The matched block's body is not visited further.
+   */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  /*!
+   * \brief Rebuild sequences through SeqStmt::Flatten, so that the SeqStmt created for the
+   *        target loop is spliced into its parent sequence instead of being nested in it.
+   */
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  /*! \brief The loop before which the decomposed init body is inserted */
+  For target_loop_;
+  /*! \brief The decomposed init body to insert */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The "_update" block that replaced old_reduction_block_ (undefined until found) */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not an ancestor of the
+ *        given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  /*!
+   * \param mod The IRModule the error is reported against
+   * \param loop The loop passed to decompose_reduction
+   * \param block The block passed to decompose_reduction
+   */
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    // {0} / {1} line up with LocationsOfInterest() below: the loop, then the block.
+    return "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+           "ancestor of block {1}.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when a reduction block var is bound to a loop that
+ *        is above the loop before which the init block would be inserted.
+ */
+class LoopHeightError : public ScheduleError {
+ public:
+  /*!
+   * \brief Check that no block var of type kCommReduce has a binding that uses the loop var of
+   *        a loop strictly above `loop_sref`.
+   * \param mod The IRModule, used for error reporting
+   * \param block The reduction block being decomposed
+   * \param realize The BlockRealize of `block`, whose iter_values are the block var bindings
+   * \param loops The loops above `block`, ordered from outermost to innermost
+   * \param loop_sref The loop before which the init block is going to be inserted
+   * \throws LoopHeightError if such a binding exists
+   */
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    // Collect the loop vars of every loop strictly above `loop_sref`. The target loop itself and
+    // the loops below it are allowed to be reduction loops, so they are not collected.
+    std::unordered_set<const VarNode*> higher_loop_vars;
+    for (const StmtSRef& higher_loop : loops) {
+      if (higher_loop.same_as(loop_sref)) {
+        break;
+      }
+      higher_loop_vars.insert(higher_loop->StmtAs<ForNode>()->loop_var.get());
+    }
+    if (higher_loop_vars.empty()) {
+      return;
+    }
+    auto uses_higher_loop = [&higher_loop_vars](const VarNode* var) {
+      return higher_loop_vars.count(var) != 0;
+    };
+    // Scan each reduction binding once against the whole set of higher loop vars
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      if (block->iter_vars[i]->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      if (UsesVar(realize->iter_values[i], uses_higher_loop)) {
+        const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+        throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+      }
+    }
+  }
+
+  /*!
+   * \param mod The IRModule the error is reported against
+   * \param loop The loop passed to decompose_reduction
+   * \param block The reduction block passed to decompose_reduction
+   */
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    // {0} / {1} line up with LocationsOfInterest() below: the loop, then the block.
+    return "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+           "related to reduce block var of block {1}";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief Rebuild the predicate for the init block realize by dropping every conjunct of `pred`
+ *        that mentions a discarded loop, i.e. a loop that is not in scope of the init block.
+ * \param pred The predicate of the original reduction block realize
+ * \param discarded_loops The loop vars that are not recreated around the init block
+ * \return The conjunction of the kept conjuncts (in their original order), or `true`
+ */
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto uses_discarded_loop = [&](const VarNode* var) {
+    return discarded_loops.find(var) != discarded_loops.end();
+  };
+  // Walk the `&&`-tree depth-first, left operand first, so conjuncts keep their order.
+  std::vector<PrimExpr> pending{std::move(pred)};
+  while (!pending.empty()) {
+    PrimExpr cur = pending.back();
+    pending.pop_back();
+    if (const auto* op = cur.as<AndNode>()) {
+      pending.push_back(op->b);
+      pending.push_back(op->a);
+      continue;
+    }
+    // A conjunct referring to a discarded loop anywhere (not only on the left-hand side of a
+    // comparison) cannot be evaluated around the init block, so it is dropped.
+    if (!UsesVar(cur, uses_discarded_loop)) {
+      new_pred = new_pred && cur;
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }

Review comment:
       Do we make the assumption here about the write region?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r712030876



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief Check whether `element` occurs in `list`, comparing by object identity (same_as).
+ * \param list The srefs to search
+ * \param element The sref to look for
+ * \return true if some entry of `list` is the very same object as `element`
+ */
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  auto cursor = list.begin();
+  const auto stop = list.end();
+  while (cursor != stop && !(*cursor).same_as(element)) {
+    ++cursor;
+  }
+  return cursor != stop;
+}
+
+/*!
+ * \brief A helper mutator that rewrites the scope of a reduction block being decomposed:
+ *        it inserts the decomposed init body right before the target loop, and replaces the
+ *        reduction block with an "_update" block that no longer carries an init statement.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  /*! \brief Mutate the loop, then prepend the decomposed init body if it is the target loop. */
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  /*!
+   * \brief Turn the reduction block into its "_update" counterpart by dropping its init.
+   *        The matched block's body is not visited further.
+   */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  /*!
+   * \brief Rebuild sequences through SeqStmt::Flatten, so that the SeqStmt created for the
+   *        target loop is spliced into its parent sequence instead of being nested in it.
+   */
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  /*! \brief The loop before which the decomposed init body is inserted */
+  For target_loop_;
+  /*! \brief The decomposed init body to insert */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The "_update" block that replaced old_reduction_block_ (undefined until found) */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not an ancestor of the
+ *        given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  /*!
+   * \param mod The IRModule the error is reported against
+   * \param loop The loop passed to decompose_reduction
+   * \param block The block passed to decompose_reduction
+   */
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    // {0} / {1} line up with LocationsOfInterest() below: the loop, then the block.
+    return "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+           "ancestor of block {1}.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when a reduction block var is bound to a loop that
+ *        is above the loop before which the init block would be inserted.
+ */
+class LoopHeightError : public ScheduleError {
+ public:
+  /*!
+   * \brief Check that no block var of type kCommReduce has a binding that uses the loop var of
+   *        a loop strictly above `loop_sref`.
+   * \param mod The IRModule, used for error reporting
+   * \param block The reduction block being decomposed
+   * \param realize The BlockRealize of `block`, whose iter_values are the block var bindings
+   * \param loops The loops above `block`, ordered from outermost to innermost
+   * \param loop_sref The loop before which the init block is going to be inserted
+   * \throws LoopHeightError if such a binding exists
+   */
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    // Collect the loop vars of every loop strictly above `loop_sref`. The target loop itself and
+    // the loops below it are allowed to be reduction loops, so they are not collected.
+    std::unordered_set<const VarNode*> higher_loop_vars;
+    for (const StmtSRef& higher_loop : loops) {
+      if (higher_loop.same_as(loop_sref)) {
+        break;
+      }
+      higher_loop_vars.insert(higher_loop->StmtAs<ForNode>()->loop_var.get());
+    }
+    if (higher_loop_vars.empty()) {
+      return;
+    }
+    auto uses_higher_loop = [&higher_loop_vars](const VarNode* var) {
+      return higher_loop_vars.count(var) != 0;
+    };
+    // Scan each reduction binding once against the whole set of higher loop vars
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      if (block->iter_vars[i]->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      if (UsesVar(realize->iter_values[i], uses_higher_loop)) {
+        const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+        throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+      }
+    }
+  }
+
+  /*!
+   * \param mod The IRModule the error is reported against
+   * \param loop The loop passed to decompose_reduction
+   * \param block The reduction block passed to decompose_reduction
+   */
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    // {0} / {1} line up with LocationsOfInterest() below: the loop, then the block.
+    return "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+           "related to reduce block var of block {1}";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief Rebuild the predicate for the init block realize by dropping every conjunct of `pred`
+ *        that mentions a discarded loop, i.e. a loop that is not in scope of the init block.
+ * \param pred The predicate of the original reduction block realize
+ * \param discarded_loops The loop vars that are not recreated around the init block
+ * \return The conjunction of the kept conjuncts (in their original order), or `true`
+ */
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto uses_discarded_loop = [&](const VarNode* var) {
+    return discarded_loops.find(var) != discarded_loops.end();
+  };
+  // Walk the `&&`-tree depth-first, left operand first, so conjuncts keep their order.
+  std::vector<PrimExpr> pending{std::move(pred)};
+  while (!pending.empty()) {
+    PrimExpr cur = pending.back();
+    pending.pop_back();
+    if (const auto* op = cur.as<AndNode>()) {
+      pending.push_back(op->b);
+      pending.push_back(op->a);
+      continue;
+    }
+    // A conjunct referring to a discarded loop anywhere (not only on the left-hand side of a
+    // comparison) cannot be evaluated around the init block, so it is dropped.
+    if (!UsesVar(cur, uses_discarded_loop)) {
+      new_pred = new_pred && cur;
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = loops.size() - 1; i >= 0; --i) {
+    const auto* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const auto* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);

Review comment:
       Specify the type here.
   ```suggestion
       const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief Check whether `element` occurs in `list`, comparing by object identity (same_as).
+ * \param list The srefs to search
+ * \param element The sref to look for
+ * \return true if some entry of `list` is the very same object as `element`
+ */
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  auto cursor = list.begin();
+  const auto stop = list.end();
+  while (cursor != stop && !(*cursor).same_as(element)) {
+    ++cursor;
+  }
+  return cursor != stop;
+}
+
+/*!
+ * \brief A helper mutator that rewrites the scope of a reduction block being decomposed:
+ *        it inserts the decomposed init body right before the target loop, and replaces the
+ *        reduction block with an "_update" block that no longer carries an init statement.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  /*! \brief Mutate the loop, then prepend the decomposed init body if it is the target loop. */
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  /*!
+   * \brief Turn the reduction block into its "_update" counterpart by dropping its init.
+   *        The matched block's body is not visited further.
+   */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  /*!
+   * \brief Rebuild sequences through SeqStmt::Flatten, so that the SeqStmt created for the
+   *        target loop is spliced into its parent sequence instead of being nested in it.
+   */
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  /*! \brief The loop before which the decomposed init body is inserted */
+  For target_loop_;
+  /*! \brief The decomposed init body to insert */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The "_update" block that replaced old_reduction_block_ (undefined until found) */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not an ancestor of the
+ *        given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  /*!
+   * \param mod The IRModule the error is reported against
+   * \param loop The loop passed to decompose_reduction
+   * \param block The block passed to decompose_reduction
+   */
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    // {0} / {1} line up with LocationsOfInterest() below: the loop, then the block.
+    return "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+           "ancestor of block {1}.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when a reduction block var is bound to a loop that
+ *        is above the loop before which the init block would be inserted.
+ */
+class LoopHeightError : public ScheduleError {
+ public:
+  /*!
+   * \brief Check that no block var of type kCommReduce has a binding that uses the loop var of
+   *        a loop strictly above `loop_sref`.
+   * \param mod The IRModule, used for error reporting
+   * \param block The reduction block being decomposed
+   * \param realize The BlockRealize of `block`, whose iter_values are the block var bindings
+   * \param loops The loops above `block`, ordered from outermost to innermost
+   * \param loop_sref The loop before which the init block is going to be inserted
+   * \throws LoopHeightError if such a binding exists
+   */
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    // Collect the loop vars of every loop strictly above `loop_sref`. The target loop itself and
+    // the loops below it are allowed to be reduction loops, so they are not collected.
+    std::unordered_set<const VarNode*> higher_loop_vars;
+    for (const StmtSRef& higher_loop : loops) {
+      if (higher_loop.same_as(loop_sref)) {
+        break;
+      }
+      higher_loop_vars.insert(higher_loop->StmtAs<ForNode>()->loop_var.get());
+    }
+    if (higher_loop_vars.empty()) {
+      return;
+    }
+    auto uses_higher_loop = [&higher_loop_vars](const VarNode* var) {
+      return higher_loop_vars.count(var) != 0;
+    };
+    // Scan each reduction binding once against the whole set of higher loop vars
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      if (block->iter_vars[i]->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      if (UsesVar(realize->iter_values[i], uses_higher_loop)) {
+        const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+        throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+      }
+    }
+  }
+
+  /*!
+   * \param mod The IRModule the error is reported against
+   * \param loop The loop passed to decompose_reduction
+   * \param block The reduction block passed to decompose_reduction
+   */
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    // {0} / {1} line up with LocationsOfInterest() below: the loop, then the block.
+    return "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+           "related to reduce block var of block {1}";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief Rebuild the predicate for the init block realize by dropping every conjunct of `pred`
+ *        that mentions a discarded loop, i.e. a loop that is not in scope of the init block.
+ * \param pred The predicate of the original reduction block realize
+ * \param discarded_loops The loop vars that are not recreated around the init block
+ * \return The conjunction of the kept conjuncts (in their original order), or `true`
+ */
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto uses_discarded_loop = [&](const VarNode* var) {
+    return discarded_loops.find(var) != discarded_loops.end();
+  };
+  // Walk the `&&`-tree depth-first, left operand first, so conjuncts keep their order.
+  std::vector<PrimExpr> pending{std::move(pred)};
+  while (!pending.empty()) {
+    PrimExpr cur = pending.back();
+    pending.pop_back();
+    if (const auto* op = cur.as<AndNode>()) {
+      pending.push_back(op->b);
+      pending.push_back(op->a);
+      continue;
+    }
+    // A conjunct referring to a discarded loop anywhere (not only on the left-hand side of a
+    // comparison) cannot be evaluated around the init block, so it is dropped.
+    if (!UsesVar(cur, uses_discarded_loop)) {
+      new_pred = new_pred && cur;
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = loops.size() - 1; i >= 0; --i) {
+    const auto* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const auto* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=body*/ Substitute(body, {{old_loop_var, new_loop_var}}));

Review comment:
       The `Substitute(...)` here incurs multiple visits of the AST. Could we collect all the loops first and only substitute once after that?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);

Review comment:
       Let's document what the two `false` arguments represent.
   ```suggestion
     const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref,  //
                                                    /*require_stage_pipeline=*/false,
                                                    /*require_subtree_compact_dataflow=*/false);
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH edited a comment on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-928829921






-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r711976931



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";

Review comment:
       We don't need this check: `block_sref` is obtained from a `BlockRV` in the concrete schedule, so `block_sref` is always defined.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r711996916



##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -364,6 +364,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
+  /*!
+   * \brief Decompose a reduction block into init block and update block, where the newly generated
+     init block will be before the specified loop.
+     1) The block is a reduction block.
+     2) The loop is the ancestor of the block.
+     3) The loop is not lower than all the loops related to reduce block var.

Review comment:
       ```suggestion
      * \brief Decompose a reduction block into init block and update block, where the newly generated
        init block will be before the specified loop. It requires that
        1) The input block is a reduction block.
        2) The input loop is an ancestor of the block.
        3) The input loop is above all the loops related to reduce block var.
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r715262220



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }

Review comment:
       It's not optimal in time complexity.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717022936



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,

Review comment:
       Do we need to handle the case where `min` and `extent` contain other loop variables?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716055130



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       The flag only affects the info binding calculation, I think?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r720393958



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       Consider `ScheduleStateNode::UpdateSubtreeblockInfo(stmt)`:
   If stmt is a Block, the outer loops of the Block are not passed into the function, so we should expect the Block to be a root block (i.e. it doesn't have outer loops).
   If stmt is a Loop, then the blocks under the loop cannot be judged to be root blocks.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r720435157



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),

Review comment:
       I have no preference.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-925944448


   @spectrometerHBH Can you address the comments and then we can start another round of review


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716925384



##########
File path: src/tir/schedule/state.cc
##########
@@ -421,6 +389,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {

Review comment:
       Update: I'm not sure I'm 100% correct after a second check...Let me know :-)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r720435157



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),

Review comment:
       I have no preference. Maybe the shorter one is better.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r711997862



##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -364,6 +364,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
+  /*!
+   * \brief Decompose a reduction block into init block and update block, where the newly generated
+     init block will be before the specified loop.
+     1) The block is a reduction block.
+     2) The loop is the ancestor of the block.
+     3) The loop is not lower than all the loops related to reduce block var.
+   * \param block_rv The reduction block to be decomposed
+   * \param loop_rv The position where init block is inserted

Review comment:
       ```suggestion
      * \param loop_rv The loop above which the init block is inserted
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH edited a comment on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-928829921






-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713606220



##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -364,6 +364,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
+  /*!
+   * \brief Decompose a reduction block into init block and update block, where the newly generated
+     init block will be before the specified loop.

Review comment:
       ```suggestion
      * \brief Decompose a reduction block into two separate blocks.
      * a) The init block, which is translated from the init statement of the reduction block;
      * b) The update block, which is the original block without init statement.
      * 
      * The init block is inserted right before the given loop.
      * 
      * The schedule primitive requires:
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717022306



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {

Review comment:
       nit: we don't need a const reference to an integer, because copying it directly may be equally fast
   
   ```suggestion
     for (int i : chosen_loops) {
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {

Review comment:
       nit: we don't need a const reference to an integer, because copying it directly may be faster, or at least equally fast
   
   ```suggestion
     for (int i : chosen_loops) {
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713617813



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       If I understand correctly, the `MakeBlockInfo` method may be problematic with respect to the `is_root_block` flag, because `srefs_` can be empty at the root of a subtree, not only at the root of the entire tree. Is that correct?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716920036



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       Correct me if I'm wrong, but `ScheduleStateNode::UpdateSubtreeBlockInfo(stmt)` calls `BlockInfoCollector::Collect`, where `stmt` can be the root block of a subtree whose binding might not be affine. However, `MakeBlockInfo` would assume `is_root_block == True` and treat this binding as affine. Is that correct?

##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -142,6 +142,8 @@ class ScheduleStateNode : public Object {
   /******** Property of blocks ********/
   /*! \brief Returns the BlockInfo correpsonding to the block sref */
   TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+  /*! \brief Recalculate the BlockInfo recursively under stmt */
+  TVM_DLL void UpdateSubtreeBlockInfo(const Stmt& stmt);

Review comment:
       Are there any additional constraints on this API? For example:
   - Must `stmt` be a `Block`?
   - Must all loops/blocks in the subtree be without any sref?

##########
File path: src/tir/schedule/state.cc
##########
@@ -421,6 +389,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {

Review comment:
       I just checked: `StateCreator`, including its members such as `PushSRef` and `VisitStmt`, is no longer used, so let's remove it entirely and inline its `Create` method into `ScheduleState::ScheduleState`.

##########
File path: src/tir/schedule/state.cc
##########
@@ -421,6 +389,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {

Review comment:
       Update: after a second check, I'm not 100% sure I'm correct... Let me know :-)

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }

Review comment:
       nitpick
   
   ```suggestion
     if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
       throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
     }
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;

Review comment:
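       nit: reserve the map's capacity up front, since it receives at most one entry per block iter var (bounded by `block->iter_vars.size()`).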
   
   ```suggestion
     std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
     block_var_map.reserve(block->iter_vars.size());
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),

Review comment:
       Do we want to rename these block vars with an `_init` suffix, or is simply reusing the same name good enough?
    
   ```suggestion
                            /*var=*/iter_var->var.copy_with_suffix(""),
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {

Review comment:
       nit: we don't need a const reference to an `int`, because copying it directly is just as fast (if not faster).
   
   ```suggestion
     for (int i : chosen_loops) {
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ *
+ * The scope root is rewritten in a single traversal:
+ *  - the decomposed init body is placed immediately before `target_loop`
+ *    (the two become consecutive statements of a SeqStmt);
+ *  - `old_reduction_block` is replaced by a copy whose `init` is removed and
+ *    whose name gets the "_update" suffix.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block. If
+   *         `old_reduction_block` is never reached during the traversal, the
+   *         second element is left default-constructed.
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    // std::make_pair takes its arguments by reference, so `new_reduction_block_` is copied
+    // inside make_pair, i.e. after the traversal in the first argument has already run.
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  /*! \brief Mutate the loop's subtree; if it is the target loop, prepend the init body. */
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    // Compare against the original node pointer: `mutated_stmt` may be a fresh node
+    // once the reduction block inside this loop has been replaced.
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  /*! \brief Replace the reduction block by an init-free copy named "<name>_update". */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      // The reduction block's body is not visited; only the block node itself changes.
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      // Recorded so that Replace() can hand the new block back to the caller.
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  /*!
+   * \brief Visit every statement of a sequence and rebuild it with SeqStmt::Flatten,
+   * presumably so that a SeqStmt produced by the ForNode visitor is merged into the
+   * enclosing sequence rather than nested inside it.
+   */
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  /*! \brief The loop before which `decomposed_body_` is inserted */
+  For target_loop_;
+  /*! \brief The statement to insert before `target_loop_` */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The replacement of `old_reduction_block_`, set when that block is visited */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not an ancestor
+ * of the given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    // {0} / {1} presumably refer to LocationsOfInterest() in order: the loop, then the block.
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The IR module the error is reported against */
+  IRModule mod_;
+  /*! \brief The loop passed to decompose_reduction */
+  For loop_;
+  /*! \brief The block passed to decompose_reduction */
+  Block block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not higher than
+ * every loop used by the bindings of the block's reduction (kCommReduce) block vars.
+ */
+class LoopHeightError : public ScheduleError {
+ public:
+  /*!
+   * \brief Check that no kCommReduce block var of `block` is bound to an expression that
+   * uses the loop var of a loop strictly above `loop_sref`.
+   * \param mod The IR module, used to construct the error
+   * \param block The reduction block being decomposed
+   * \param realize The BlockRealize of `block`; its iter_values are matched to
+   *        block->iter_vars by index
+   * \param loops The loops surrounding `block`, assumed ordered from outermost to innermost
+   *        (as GetLoops returns them); only the loops before `loop_sref` are checked
+   * \param loop_sref The loop above which the init block is to be inserted
+   * \throws LoopHeightError if a reduction block var depends on a loop above `loop_sref`
+   */
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops strictly higher than the target loop: stop once it is reached
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduction block var must not use the loop var of a higher loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          // The error reports the target loop (not the offending higher loop) and the block
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The IR module the error is reported against */
+  IRModule mod_;
+  /*! \brief The target loop passed to decompose_reduction */
+  For loop_;
+  /*! \brief The reduction block being decomposed */
+  Block block_;
+};
+
+/*!
+ * \brief Rebuild a block predicate, dropping the comparisons that refer to discarded loops.
+ * \param pred The original predicate. Unless it is constant true, it must be a chain of
+ *        `&&`-joined `lhs < rhs` comparisons; any other form fails the ICHECK below.
+ * \param discarded_loops Loop vars that are not used by the init block bindings
+ * \return A conjunction, seeded with `true`, of the comparisons `lhs < rhs` whose `lhs`
+ *         uses no var in `discarded_loops`. Only `lhs` is inspected; `rhs` is presumably
+ *         free of loop vars -- NOTE(review) confirm. Conjuncts are peeled from the right
+ *         end of `pred`, so their order in the result is reversed.
+ */
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  // True iff `var` is one of the discarded loop vars
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  // Each iteration peels one `lhs < rhs` conjunct off the right end of `pred`.
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      // The last (leftmost) conjunct
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {

Review comment:
       nit: we don't need const ref to integers because direct copy may be faster or at least equally fast
   
   ```suggestion
     for (int i : chosen_loops) {
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ *
+ * The scope root is rewritten in a single traversal:
+ *  - the decomposed init body is placed immediately before `target_loop`
+ *    (the two become consecutive statements of a SeqStmt);
+ *  - `old_reduction_block` is replaced by a copy whose `init` is removed and
+ *    whose name gets the "_update" suffix.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block. If
+   *         `old_reduction_block` is never reached during the traversal, the
+   *         second element is left default-constructed.
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    // std::make_pair takes its arguments by reference, so `new_reduction_block_` is copied
+    // inside make_pair, i.e. after the traversal in the first argument has already run.
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  /*! \brief Mutate the loop's subtree; if it is the target loop, prepend the init body. */
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    // Compare against the original node pointer: `mutated_stmt` may be a fresh node
+    // once the reduction block inside this loop has been replaced.
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  /*! \brief Replace the reduction block by an init-free copy named "<name>_update". */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      // The reduction block's body is not visited; only the block node itself changes.
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      // Recorded so that Replace() can hand the new block back to the caller.
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  /*!
+   * \brief Visit every statement of a sequence and rebuild it with SeqStmt::Flatten,
+   * presumably so that a SeqStmt produced by the ForNode visitor is merged into the
+   * enclosing sequence rather than nested inside it.
+   */
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  /*! \brief The loop before which `decomposed_body_` is inserted */
+  For target_loop_;
+  /*! \brief The statement to insert before `target_loop_` */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The replacement of `old_reduction_block_`, set when that block is visited */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not an ancestor
+ * of the given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    // {0} / {1} presumably refer to LocationsOfInterest() in order: the loop, then the block.
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The IR module the error is reported against */
+  IRModule mod_;
+  /*! \brief The loop passed to decompose_reduction */
+  For loop_;
+  /*! \brief The block passed to decompose_reduction */
+  Block block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not higher than
+ * every loop used by the bindings of the block's reduction (kCommReduce) block vars.
+ */
+class LoopHeightError : public ScheduleError {
+ public:
+  /*!
+   * \brief Check that no kCommReduce block var of `block` is bound to an expression that
+   * uses the loop var of a loop strictly above `loop_sref`.
+   * \param mod The IR module, used to construct the error
+   * \param block The reduction block being decomposed
+   * \param realize The BlockRealize of `block`; its iter_values are matched to
+   *        block->iter_vars by index
+   * \param loops The loops surrounding `block`, assumed ordered from outermost to innermost
+   *        (as GetLoops returns them); only the loops before `loop_sref` are checked
+   * \param loop_sref The loop above which the init block is to be inserted
+   * \throws LoopHeightError if a reduction block var depends on a loop above `loop_sref`
+   */
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops strictly higher than the target loop: stop once it is reached
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduction block var must not use the loop var of a higher loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          // The error reports the target loop (not the offending higher loop) and the block
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The IR module the error is reported against */
+  IRModule mod_;
+  /*! \brief The target loop passed to decompose_reduction */
+  For loop_;
+  /*! \brief The reduction block being decomposed */
+  Block block_;
+};
+
+/*!
+ * \brief Rebuild a block predicate, dropping the comparisons that refer to discarded loops.
+ * \param pred The original predicate. Unless it is constant true, it must be a chain of
+ *        `&&`-joined `lhs < rhs` comparisons; any other form fails the ICHECK below.
+ * \param discarded_loops Loop vars that are not used by the init block bindings
+ * \return A conjunction, seeded with `true`, of the comparisons `lhs < rhs` whose `lhs`
+ *         uses no var in `discarded_loops`. Only `lhs` is inspected; `rhs` is presumably
+ *         free of loop vars -- NOTE(review) confirm. Conjuncts are peeled from the right
+ *         end of `pred`, so their order in the result is reversed.
+ */
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  // True iff `var` is one of the discarded loop vars
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  // Each iteration peels one `lhs < rhs` conjunct off the right end of `pred`.
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      // The last (leftmost) conjunct
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,

Review comment:
       do we need to handle the case where `min` and `extent` contain some other loop variables?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ *
+ * The scope root is rewritten in a single traversal:
+ *  - the decomposed init body is placed immediately before `target_loop`
+ *    (the two become consecutive statements of a SeqStmt);
+ *  - `old_reduction_block` is replaced by a copy whose `init` is removed and
+ *    whose name gets the "_update" suffix.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block. If
+   *         `old_reduction_block` is never reached during the traversal, the
+   *         second element is left default-constructed.
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    // std::make_pair takes its arguments by reference, so `new_reduction_block_` is copied
+    // inside make_pair, i.e. after the traversal in the first argument has already run.
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  /*! \brief Mutate the loop's subtree; if it is the target loop, prepend the init body. */
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    // Compare against the original node pointer: `mutated_stmt` may be a fresh node
+    // once the reduction block inside this loop has been replaced.
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  /*! \brief Replace the reduction block by an init-free copy named "<name>_update". */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      // The reduction block's body is not visited; only the block node itself changes.
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      // Recorded so that Replace() can hand the new block back to the caller.
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  /*!
+   * \brief Visit every statement of a sequence and rebuild it with SeqStmt::Flatten,
+   * presumably so that a SeqStmt produced by the ForNode visitor is merged into the
+   * enclosing sequence rather than nested inside it.
+   */
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  /*! \brief The loop before which `decomposed_body_` is inserted */
+  For target_loop_;
+  /*! \brief The statement to insert before `target_loop_` */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The replacement of `old_reduction_block_`, set when that block is visited */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not an ancestor
+ * of the given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    // {0} / {1} presumably refer to LocationsOfInterest() in order: the loop, then the block.
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The IR module the error is reported against */
+  IRModule mod_;
+  /*! \brief The loop passed to decompose_reduction */
+  For loop_;
+  /*! \brief The block passed to decompose_reduction */
+  Block block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the given loop is not higher than
+ * every loop used by the bindings of the block's reduction (kCommReduce) block vars.
+ */
+class LoopHeightError : public ScheduleError {
+ public:
+  /*!
+   * \brief Check that no kCommReduce block var of `block` is bound to an expression that
+   * uses the loop var of a loop strictly above `loop_sref`.
+   * \param mod The IR module, used to construct the error
+   * \param block The reduction block being decomposed
+   * \param realize The BlockRealize of `block`; its iter_values are matched to
+   *        block->iter_vars by index
+   * \param loops The loops surrounding `block`, assumed ordered from outermost to innermost
+   *        (as GetLoops returns them); only the loops before `loop_sref` are checked
+   * \param loop_sref The loop above which the init block is to be inserted
+   * \throws LoopHeightError if a reduction block var depends on a loop above `loop_sref`
+   */
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops strictly higher than the target loop: stop once it is reached
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduction block var must not use the loop var of a higher loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          // The error reports the target loop (not the offending higher loop) and the block
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The IR module the error is reported against */
+  IRModule mod_;
+  /*! \brief The target loop passed to decompose_reduction */
+  For loop_;
+  /*! \brief The reduction block being decomposed */
+  Block block_;
+};
+
+/*!
+ * \brief Rebuild a block predicate, dropping the comparisons that refer to discarded loops.
+ * \param pred The original predicate. Unless it is constant true, it must be a chain of
+ *        `&&`-joined `lhs < rhs` comparisons; any other form fails the ICHECK below.
+ * \param discarded_loops Loop vars that are not used by the init block bindings
+ * \return A conjunction, seeded with `true`, of the comparisons `lhs < rhs` whose `lhs`
+ *         uses no var in `discarded_loops`. Only `lhs` is inspected; `rhs` is presumably
+ *         free of loop vars -- NOTE(review) confirm. Conjuncts are peeled from the right
+ *         end of `pred`, so their order in the result is reversed.
+ */
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  // True iff `var` is one of the discarded loop vars
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  // Each iteration peels one `lhs < rhs` conjunct off the right end of `pred`.
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      // The last (leftmost) conjunct
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,

Review comment:
       do we need to handle the case where `min` and `extent` contain some other loop variables? Are they handled at line 287?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  Block new_scope_root;
+  Block new_reduction_block;

Review comment:
       ```suggestion
     Block new_scope_root{nullptr};
     Block new_reduction_block{nullptr};
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  Block new_scope_root;
+  Block new_reduction_block;

Review comment:
       nit
   
   ```suggestion
     Block new_scope_root{nullptr};
     Block new_reduction_block{nullptr};
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }

Review comment:
       Are we making an assumption here about the write region?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;

Review comment:
       nit: don't set the predicate here, because we will mutate it later on line 271

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };

Review comment:
       nit:
   
   ```suggestion
     auto f = [&](const VarNode* var) { return discarded_loops.count(var); };
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;

Review comment:
       nit: swap the two statements below
   
   ```suggestion
     if (is_one(pred)) return new_pred;
     PrimExpr new_pred = Bool(true);
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());

Review comment:
       Do we have any checks on the rhs?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }

Review comment:
       Why do we need special handling of SeqStmt btw?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-922398848


   cc @tqchen @junrushao1994 @MasterJH5574 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r712627574



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = loops.size() - 1; i >= 0; --i) {

Review comment:
       Ditto (same as my earlier comment): cast `loops.size()` to `int` explicitly before subtracting 1.
   ```suggestion
     for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 merged pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 merged pull request #9041:
URL: https://github.com/apache/tvm/pull/9041


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r715306043



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       Yes. But it doesn't matter.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717024869



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  Block new_scope_root;
+  Block new_reduction_block;

Review comment:
       ```suggestion
     Block new_scope_root{nullptr};
     Block new_reduction_block{nullptr};
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  Block new_scope_root;
+  Block new_reduction_block;

Review comment:
       nit
   
   ```suggestion
     Block new_scope_root{nullptr};
     Block new_reduction_block{nullptr};
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716922693



##########
File path: src/tir/schedule/state.cc
##########
@@ -421,6 +389,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {

Review comment:
       I just checked: `StateCreator`, including its members such as `PushSRef` and `VisitStmt`, is no longer used, so let's remove it entirely and inline the `Create` method into `ScheduleState::ScheduleState`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713613369



##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -1223,6 +1223,70 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
 
     ########## Schedule: Reduction ##########
 
+    def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
+        """Decompose a reduction block into init block and update block, where the newly generated
+        init block will be before the specified loop.
+
+        1) The block is a reduction block.
+
+        2) The loop is the ancestor of the block.
+
+        3) The loop is not lower than all the loops related to reduce block var.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The reduction block to be decomposed
+        loop : LoopRV
+            The position where init block is inserted
+
+        Examples
+        --------
+        Before decompose-reduction, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_decompose(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+                        with tir.init():
+                            C[vi, vj] = 0.0
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Create the schedule and do decompose-reduction with specified loop:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_decompose)
+            C = sch.get_block("C")
+            i, j, k = sch.get_loops(C)
+            sch.decompose_reduction(C, i)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying decompose-reduction, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_decompose(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i in tir.serial(128):
+                    for j in tir.serial(128):
+                        with tir.block([128, 128]) as [vi, vj]:
+                            C[vi, vj] = 0.0
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        """
+        _ffi_api.ScheduleDecomposeReduction(self, block, loop)  # type: ignore # pylint: disable=no-member

Review comment:
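       Looks like we forgot to return the result here? The signature promises a `BlockRV`, but the value from `_ffi_api.ScheduleDecomposeReduction` is discarded.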




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r714278995



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";

Review comment:
       Yep, we don't need this check. We always assume that `block_sref` points to a block and is not nullable.

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = loops.size() - 1; i >= 0; --i) {
+    const auto* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const auto* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=body*/ Substitute(body, {{old_loop_var, new_loop_var}}));
+  }
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  Block new_scope_root;
+  Block new_reduction_block;
+  std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace(
+      GetRef<Block>(old_scope_root), GetRef<For>(loop), body, GetRef<Block>(block));
+  self->Replace(scope_root_sref, new_scope_root,
+                {{GetRef<Block>(old_scope_root), new_scope_root},
+                 {GetRef<Block>(block), new_reduction_block}});
+  self->UpdateBlockInfo(new_scope_root);
+  StmtSRef init_block_sref = self->stmt2ref.at(init_block.get());
+  return init_block_sref;

Review comment:
       nit: it's okay to return it directly
   
   ```suggestion
     return self->stmt2ref.at(init_block.get());
   ```

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -1223,6 +1223,70 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
 
     ########## Schedule: Reduction ##########
 
+    def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
+        """Decompose a reduction block into init block and update block, where the newly generated
+        init block will be before the specified loop.
+
+        1) The block is a reduction block.
+
+        2) The loop is the ancestor of the block.
+
+        3) The loop is not lower than all the loops related to reduce block var.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The reduction block to be decomposed
+        loop : LoopRV
+            The position where init block is inserted
+

Review comment:
       Looks like we forgot to document the return value in the docstring (a `Returns` section describing the newly created init block).

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);

Review comment:
       Also, it makes more sense not to use `const&` here




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717225041



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;

Review comment:
       nit: swap the two statements below
   
   ```suggestion
     if (is_one(pred)) return new_pred;
     PrimExpr new_pred = Bool(true);
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-928829921


   > @spectrometerHBH Sorry for the delay! I think I have finally completed the code review. Let me know if it makes sense to you.
   
   Thanks! I might be slow to respond in the next 2 days.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717004462



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;

Review comment:
       nit
   
   ```suggestion
     std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
     block_var_map.reserve(block->iter_vars.size());
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717022936



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,

Review comment:
       Do we need to handle the case where `min` and `extent` contain other loop variables? Are they handled in line 287?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-931841743


   It would be nice if we could get this merged this week or over the weekend :-)


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713560746



##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -364,6 +364,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
+  /*!
+   * \brief Decompose a reduction block into init block and update block, where the newly generated
+     init block will be before the specified loop.
+     1) The block is a reduction block.
+     2) The loop is the ancestor of the block.
+     3) The loop is not lower than all the loops related to reduce block var.
+   * \param block_rv The reduction block to be decomposed
+   * \param loop_rv The position where init block is inserted
+   * \return The init block
+   */
+  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;

Review comment:
       Did we discontinue our support for the case where `loop_rv == None`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r715349916



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       But why? Will we assume that the flag for the subtree root is true?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-922431314


   Sorry, I have a deadline on Monday... I will review after that.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717227641



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());

Review comment:
       Do we have any checks on the rhs?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713554606



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}

Review comment:
       +1

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";

Review comment:
       I'm not sure, but this seems like a user error rather than an internal error.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r711994987



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }

Review comment:
       There's a helper function in analysis.h which collects the variables contained in the bindings of data-parallel block iters and reduction block iters respectively. https://github.com/apache/tvm/blob/44b644c6a37266c6c49eaa7e8c87c7809b882da5/src/tir/schedule/analysis.h#L215-L226
   We can use it here to remove some repetitive code.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r711996916



##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -364,6 +364,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
+  /*!
+   * \brief Decompose a reduction block into init block and update block, where the newly generated
+     init block will be before the specified loop.
+     1) The block is a reduction block.
+     2) The loop is the ancestor of the block.
+     3) The loop is not lower than all the loops related to reduce block var.

Review comment:
       ```suggestion
      * \brief Decompose a reduction block into init block and update block, where the newly generated
        init block will be before the specified loop. It requires that
        1) The input block is a reduction block.
        2) The input loop is an ancestor of the block.
        3) The input loop is not lower than all the loops related to reduce block var.
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] xqdan commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
xqdan commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713509277



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  for (const StmtSRef& ele : list) {
+    if (ele.same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  ICHECK(block_sref.defined())
+      << "ValueError: 'decompose_reduction' expect a block as first argument, but get value 'None'";
+  const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const auto* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (!ListContainsElement(loops, loop_sref)) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  const StmtSRef& scope_root_sref = GetScopeRoot(self, block_sref, false, false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = loops.size() - 1; i >= 0; --i) {

Review comment:
       size_t i




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] xqdan commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
xqdan commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713509194



##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -482,7 +482,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
   Array<PrimExpr> substitute_value;
   substitute_value.resize(loops.size());
   PrimExpr tot = fused_var;
-  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
+  for (int i = loops.size() - 1; i >= 0; i--) {

Review comment:
       size_t i




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r713522208



##########
File path: src/tir/schedule/primitive/loop_transformation.cc
##########
@@ -482,7 +482,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
   Array<PrimExpr> substitute_value;
   substitute_value.resize(loops.size());
   PrimExpr tot = fused_var;
-  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
+  for (int i = loops.size() - 1; i >= 0; i--) {

Review comment:
       No, we should keep the `static_cast` instead. As a principle, `size_t` is evil and should be avoided unless we have good reasons. In this particular case, if `loops.size()` is 0, `loops.size() - 1` underflows to the maximum `size_t` value, and an unsigned loop index can never fail the `i >= 0` check, which may cause tremendous trouble. Even if reading the surrounding code suggests that `loops.size() > 0`, relying on that is still not as convenient as just casting to a signed integer.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-928828584


   @spectrometerHBH Sorry for the delay! I think I have finally completed a full pass of the code review. Let me know if it makes sense to you.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r715293249



##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -364,6 +364,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
+  /*!
+   * \brief Decompose a reduction block into init block and update block, where the newly generated
+     init block will be before the specified loop.
+     1) The block is a reduction block.
+     2) The loop is the ancestor of the block.
+     3) The loop is not lower than all the loops related to reduce block var.
+   * \param block_rv The reduction block to be decomposed
+   * \param loop_rv The position where init block is inserted
+   * \return The init block
+   */
+  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;

Review comment:
       Yes




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r717228066



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }

Review comment:
       Why do we need special handling of SeqStmt btw?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] spectrometerHBH commented on pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#issuecomment-926639300


   cc @junrushao1994 @MasterJH5574 @Hzfengsy 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716920036



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       Correct me if I am wrong, but `ScheduleStateNode::UpdateSubtreeblockInfo(stmt)` calls `BlockInfoCollector::Collect`, where `stmt` can be the root block of a subtree whose binding might not be affine; however, in `MakeBlockInfo` we would assume `is_root_block == True` and therefore treat this binding as affine. Is that correct?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9041: [TensorIR][M2a] Decompose-Reduction

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716935541



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }

Review comment:
       nitpick
   
   ```suggestion
     if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
       throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
     }
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org