You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/10/02 22:35:10 UTC

[tvm] branch main updated: [TensorIR][M2a] Decompose-Reduction (#9041)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b3fe95  [TensorIR][M2a] Decompose-Reduction (#9041)
6b3fe95 is described below

commit 6b3fe95f3c5ff75f65d15d6fc2797f584c456a5a
Author: Bohan Hou <32...@users.noreply.github.com>
AuthorDate: Sat Oct 2 18:34:41 2021 -0400

    [TensorIR][M2a] Decompose-Reduction (#9041)
    
    This PR is part of the TensorIR upstreaming effort (#7527),
    which adds the `decompose-reduction` scheduling primitive.
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
 include/tvm/tir/schedule/schedule.h                |  16 +
 include/tvm/tir/schedule/state.h                   |   6 +
 python/tvm/tir/schedule/schedule.py                |  76 +++
 src/tir/schedule/concrete_schedule.cc              |   9 +
 src/tir/schedule/concrete_schedule.h               |   1 +
 src/tir/schedule/primitive.h                       |  17 +
 src/tir/schedule/primitive/reduction.cc            | 302 +++++++++
 src/tir/schedule/schedule.cc                       |   2 +
 src/tir/schedule/state.cc                          | 147 +++--
 src/tir/schedule/traced_schedule.cc                |  10 +
 src/tir/schedule/traced_schedule.h                 |   1 +
 .../python/unittest/test_tir_schedule_reduction.py | 679 ++++-----------------
 ...e_reduction.py => test_tir_schedule_rfactor.py} |   0
 13 files changed, 664 insertions(+), 602 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 9f48d9a..c4aa1c9 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -366,6 +366,22 @@ class ScheduleNode : public runtime::Object {
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
   /*!
+   * \brief Decompose a reduction block into two separate blocks.
+   * a) The init block, which is translated from the init statement of the reduction block;
+   * b) The update block, which is the original block without init statement.
+   *
+   * The init block is inserted right before the given loop.
+   *
+   * The schedule primitive requires:
+   * 1) The input block is a reduction block.
+   * 2) The input loop is an ancestor of the block.
+   * 3) No loop above the input loop is used in the binding of a reduction block var.
+   * \param block_rv The reduction block to be decomposed
+   * \param loop_rv The loop before which the init block is inserted
+   * \return The init block
+   */
+  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
+  /*!
    * \brief Factorize an associative reduction block by the specified loop.
    * \details An associative reduction cannot be parallelized directly,
    * because it leads to potential race condition during accumulation.
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 7cd1b00..201d78f 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -143,6 +143,12 @@ class ScheduleStateNode : public Object {
   /*! \brief Returns the BlockInfo correpsonding to the block sref */
   TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
   /*!
+   * \brief Recalculate, recursively, the BlockInfo of the blocks in the subtree under `stmt`.
+   * If `stmt` is a Block itself, its affine binding flag is not reset unless it has no
+   * block vars, since the affine flag depends on the outer scope of `stmt`.
+   */
+  TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt);
+  /*!
    * \brief Get the BlockScope correpsonding to the sref of scope root block
    * \param scope_root The block sref to be retrieved
    * \return The corresponding BlockScope
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 6e27015..09a52d2 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1224,6 +1224,82 @@ class Schedule(Object):
 
     ########## Schedule: Reduction ##########
 
+    def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
+        """Decompose a reduction block into two separate blocks.
+
+        a) The init block, which is translated from the init statement of the reduction block;
+
+        b) The update block, which is the original block without init statement.
+
+        The init block is inserted right before the given loop.
+
+        The schedule primitive requires:
+
+        1) The input block is a reduction block.
+
+        2) The input loop is an ancestor of the block.
+
+        3) No loop above the input loop is used in the binding of a reduction block var.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The reduction block to be decomposed
+        loop : LoopRV
+            The loop before which the init block is inserted.
+
+        Returns
+        -------
+        init_block : BlockRV
+            The init block
+
+        Examples
+        --------
+        Before decompose-reduction, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+                        with tir.init():
+                            C[vi, vj] = 0.0
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Create the schedule and do decompose-reduction with specified loop:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_decompose)
+            C = sch.get_block("C")
+            i, j, k = sch.get_loops(C)
+            sch.decompose_reduction(C, i)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying decompose-reduction, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i_init in tir.serial(128):
+                    for j_init in tir.serial(128):
+                        with tir.block([128, 128], "C_init") as [vi, vj]:
+                            C[vi, vj] = 0.0
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C_update") as [vi, vj, vk]:
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        """
+        return _ffi_api.ScheduleDecomposeReduction(self, block, loop)  # type: ignore # pylint: disable=no-member
+
     def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
         """Factorize an associative reduction block by the specified loop.
 
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 93eba52..4283907 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -501,6 +501,15 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde
 
 /******** Schedule: Reduction ********/
 
+BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
+  StmtSRef init_block_sref{nullptr};  // sref of the init block created by the primitive
+  TVM_TIR_SCHEDULE_BEGIN();
+  init_block_sref = tir::DecomposeReduction(state_, GetSRef(block_rv), GetSRef(loop_rv));
+  TVM_TIR_SCHEDULE_END("decompose-reduction", error_render_level_);
+  state_->DebugVerify();
+  return CreateRV<BlockRV>(init_block_sref);
+}
+
 BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index c9a9402..1f9aeec 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -115,6 +115,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   void ReverseComputeInline(const BlockRV& block) override;
   /******** Schedule: Reduction ********/
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
+  BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
   /******** Schedule: Block annotation ********/
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                     int offset) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 8d8acd2..057e845 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -234,6 +234,23 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
 TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
 /******** Schedule: Reduction ********/
 /*!
+ * \brief Decompose a reduction block into two separate blocks.
+ * a) The init block, which is translated from the init statement of the reduction block;
+ * b) The update block, which is the original block without init statement.
+ *
+ * The init block is inserted right before the given loop.
+ * The schedule primitive requires:
+ * 1) The input block is a reduction block.
+ * 2) The input loop is an ancestor of the block.
+ * 3) No loop above the input loop is used in the binding of a reduction block var.
+ * \param self The state of the schedule
+ * \param block_sref The sref of the reduction block to be decomposed
+ * \param loop_sref The sref of the loop before which the init block is inserted
+ * \return The sref of the init block
+ */
+TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                                    const StmtSRef& loop_sref);
+/*!
  * \brief Factor a reduction block by the specified loop
  * \details See python/tvm/tir/schedule/schedule.py
  * \param self The state of the schedule
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 677b643..0653f6e 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -21,6 +21,282 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper that rewrites a block scope: it inserts the decomposed init body right before
+ * the target loop and replaces the reduction block with a copy that has no init statement.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief Entry point: run the rewrite on `old_scope_root`
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop before which the decomposed init body is inserted
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    Block new_scope_root = Downcast<Block>(replacer(std::move(old_scope_root)));
+    return std::make_pair(std::move(new_scope_root), replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});  // init body goes right before the loop
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;  // the init statement now lives in the separate init block
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);  // merge a SeqStmt produced for the target loop
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+/*! \brief Error raised by decompose_reduction when the loop is not an ancestor of the block */
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expects the loop to be an ancestor of the block";
+  }
+
+  String DetailRenderTemplate() const final {
+    return "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+           "ancestor of block {1}.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The module, the input loop and the block reported by this error */
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*! \brief Error raised by decompose_reduction when the binding of a reduction block var uses a
+ *  loop that is higher than the given loop. */
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops strictly higher than the target loop (`loops` is ordered outer-first)
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduction block var must not use the loop var of a higher loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expects the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    return "ScheduleError: decompose_reduction expects the loop {0} to be higher than all the "
+           "loops related to reduce block var of block {1}";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  if (is_one(pred)) return Bool(true);
+  PrimExpr new_pred = Bool(true);
+  auto uses_discarded = [&](const VarNode* var) { return discarded_loops.count(var); };
+  // Keep the conjunct `lhs < rhs` only when `lhs` refers to no discarded loop variable
+  auto keep_if_live = [&](const PrimExpr& lhs, const PrimExpr& rhs) {
+    if (!UsesVar(lhs, uses_discarded)) new_pred = new_pred && (lhs < rhs);
+  };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  while ((rest && (lhs < rhs)).Match(pred)) {
+    keep_if_live(lhs.Eval(), rhs.Eval());
+    pred = rest.Eval();
+  }
+  bool is_single_lt = (lhs < rhs).Match(pred);
+  ICHECK(is_single_lt) << "Unexpected predicate for reduction block";
+  keep_if_live(lhs.Eval(), rhs.Eval());
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - loop is an ancestor of block, and block is a reduction block
+   *    - no loop higher than `loop` is used by the binding of a reduction block var
+   *  Mutate
+   *    - copy the loops (not higher than `loop`) used by data-parallel bindings, as "*_init"
+   *    - create the init block under those loops, and strip `init` from the update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check no loop above 'loop' is used by the binding of a reduction block var
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->block = Block(init_block);  // shares init_block's node; it is filled in below
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  block_var_map.reserve(block->iter_vars.size());
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix(""),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded. `chosen_loops` is filled innermost-first.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize
+  //         Conjuncts whose lhs uses a discarded loop var are dropped
+  init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops);
+  // Step 5. Create new loops above init block, wrapping from the innermost chosen loop outwards
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (int i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);  // redirect old loop vars to their "_init" copies
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  Block new_scope_root{nullptr};
+  Block new_reduction_block{nullptr};
+  std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace(
+      GetRef<Block>(old_scope_root), GetRef<For>(loop), body, GetRef<Block>(block));
+  self->Replace(scope_root_sref, new_scope_root,
+                {{GetRef<Block>(old_scope_root), new_scope_root},
+                 {GetRef<Block>(block), new_reduction_block}});
+  self->UpdateScopeBlockInfo(new_scope_root);  // recollect BlockInfo under the new scope root
+  return self->stmt2ref.at(init_block.get());  // NOTE: assumes Step 5 kept init_block's node
+}
+
 /******** Commutative Reducer ********/
 
 /*!
@@ -958,6 +1234,31 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
 
 /******** InstructionKind Registration ********/
 
+struct DecomposeReductionTraits : public UnpackedInstTraits<DecomposeReductionTraits> {
+  static constexpr const char* kName = "DecomposeReduction";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 2;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, LoopRV loop) {
+    return sch->DecomposeReduction(block, loop);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, String loop) {
+    PythonAPICall call("decompose_reduction");
+    call.Input("block", block);
+    call.Input("loop", loop);
+    call.SingleOutput(outputs);
+    return call.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> {
   static constexpr const char* kName = "RFactor";
   static constexpr bool kIsPure = false;
@@ -984,6 +1285,7 @@ struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> {
 };
 
 TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits);
+TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits);
 
 /******** FFI ********/
 
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 4262a09..84a37c3 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -155,6 +155,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
     .set_body_method<Schedule>(&ScheduleNode::ReverseComputeInline);
 /******** (FFI) Reduction ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction")
+    .set_body_method<Schedule>(&ScheduleNode::DecomposeReduction);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
     .set_body_method<Schedule>(&ScheduleNode::RFactor);
 /******** (FFI) Block annotation ********/
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 4604add..faeb0b9 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -169,34 +169,16 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new
 }
 
 /**************** Creation ****************/
-
-/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
-class StateCreator : private StmtVisitor {
+/*! \brief A helper class to update BlockInfo for a ScheduleStateNode */
+class BlockInfoCollector : private StmtVisitor {
  public:
-  /*!
-   * \brief The entry function
-   * \param self The schedule state to be completed
-   */
-  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
-    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
-    ScheduleStateNode* self = n.get();
-    // Set `n->mod`
-    n->mod = std::move(mod);
-    // Set `n->debug_mask`
-    n->debug_mask = debug_mask;
-    // Set `n->stmt2ref` and `n->block_info`
-    StateCreator creator(self);
-    for (const auto& kv : n->mod->functions) {
-      const BaseFunc& base_func = kv.second;
-      if (const auto* func = base_func.as<PrimFuncNode>()) {
-        creator.VisitStmt(func->body);
-      }
-    }
-    return n;
+  static void Collect(ScheduleStateNode* self, const Stmt& stmt) {
+    BlockInfoCollector collector(self);
+    collector.VisitStmt(stmt);
   }
 
  private:
-  explicit StateCreator(ScheduleStateNode* self)
+  explicit BlockInfoCollector(ScheduleStateNode* self)
       : self_(self), srefs_{}, block2realize_{}, block_frames_{} {
     block_frames_.emplace({});
   }
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */
+  StmtSRef PopSRef() {
+    StmtSRef sref = srefs_.back();
     srefs_.pop_back();
     return sref;
   }
@@ -238,7 +206,10 @@ class StateCreator : private StmtVisitor {
             .first->second;
     // Set `affine_binding`
     if (is_root_block) {
-      info.affine_binding = true;
+      // If the block doesn't have outer loops and BlockRealize,
+      // then we set the affine binding flag as true only if the block has no block vars
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root);
+      if (block->iter_vars.empty()) info.affine_binding = true;
     } else {
       info.affine_binding =
           IsAffineBinding(/*realize=*/block2realize_.at(scope_root->stmt),
@@ -385,7 +356,7 @@ class StateCreator : private StmtVisitor {
     analyzer_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
     PushSRef(loop);
     VisitStmt(loop->body);
-    PopAndRecordSRef();
+    PopSRef();
   }
 
   void VisitStmt_(const BlockRealizeNode* realize) final {
@@ -395,7 +366,7 @@ class StateCreator : private StmtVisitor {
     // Recursive visit
     PushSRef(block);
     VisitStmt(block->body);  // `block->init` is not visited
-    StmtSRef sref = PopAndRecordSRef();
+    StmtSRef sref = PopSRef();
     // Create BlockInfo for the block
     MakeBlockInfo(sref);
     // Update parent scope
@@ -409,7 +380,7 @@ class StateCreator : private StmtVisitor {
     SetSeqIndexInChildren(self_, seq_stmt);
   }
 
-  /*! \brief The result ScheduleStateNode */
+  /*! \brief The ScheduleStateNode we are operating on */
   ScheduleStateNode* self_;
   /*! \brief The stack frame used to indicate the current scope */
   std::vector<StmtSRef> srefs_;
@@ -421,6 +392,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*! \brief The entry function: builds `stmt2ref` and `block_info` for every PrimFunc
+   * \param mod The IRModule to be scheduled
+   * \param debug_mask The debug mask stored into the new state
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mask`
+    n->debug_mask = debug_mask;
+    // Set `n->stmt2ref` (via StateCreator), then `n->block_info` (via BlockInfoCollector)
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+        BlockInfoCollector::Collect(self, func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self) : self_(self) {}
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   * \note The new sref's parent is the previous top of the stack (nullptr if it was empty)
+   */
+  void PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  void PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    const BlockNode* block = realize->block.get();
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` for the children, after their srefs have been recorded by the visit
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+};
+
 /**************** Constructor ****************/
 
 ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
@@ -1034,6 +1085,10 @@ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
   return it->second;
 }
 
+void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) {
+  BlockInfoCollector::Collect(this, stmt);  // reuses the collector run at state creation
+}
+
 TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) {
   const BlockInfo& info = self->GetBlockInfo(block_sref);
   return {Bool(info.affine_binding),  //
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 6f67959..cc48f2b 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -236,6 +236,16 @@ void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
 
 /******** Schedule: Reduction ********/
 
+BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
+  // Run the concrete primitive first, then log it so the trace can replay it
+  BlockRV new_block_rv = ConcreteScheduleNode::DecomposeReduction(block_rv, loop_rv);
+  static const InstructionKind& kind = InstructionKind::Get("DecomposeReduction");
+  Instruction inst(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, /*attrs=*/{},
+                   /*outputs=*/{new_block_rv});
+  trace_->Append(/*inst=*/inst);
+  return new_block_rv;
+}
+
 BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
   BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis);
   static const InstructionKind& kind = InstructionKind::Get("RFactor");
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index fb89783..fae5ca8 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -82,6 +82,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void ComputeInline(const BlockRV& block_rv) final;
   void ReverseComputeInline(const BlockRV& block_rv) final;
   /******** Schedule: Reduction ********/
+  BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final;
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
   /******** Schedule: Block annotation ********/
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py
index d79338a..8460b5c 100644
--- a/tests/python/unittest/test_tir_schedule_reduction.py
+++ b/tests/python/unittest/test_tir_schedule_reduction.py
@@ -28,607 +28,174 @@ from tvm.tir.schedule.testing import verify_trace_roundtrip
 
 
 @T.prim_func
-def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [128, 128])
-    B = T.match_buffer(b, [128, 128])
-    C = T.match_buffer(c, [128, 128])
-
-    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
-        with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
-            T.bind(vi, i0)
-            T.bind(vj, i1)
-            T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner))
-            T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
-            T.writes([C[vi, vj]])
+def rowsum_blockized(a: T.handle, b: T.handle) -> None:  # B[i, j] = sum over k of A[i, j, k]
+    B = T.match_buffer(b, [32, 4])
+    A = T.match_buffer(a, [32, 4, 128])
+    for i0, i2_0 in T.grid(32, 16):  # reduction axis k (extent 128) split as ko * 8 + i2_1
+        with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]:
+            T.bind(io, i0)
+            T.bind(ko, i2_0)
             with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
+                for i1 in T.serial(0, 4):  # the outer block's init is itself a loop of blocks
+                    with T.block([4], "B_init") as [ii_init]:
+                        T.bind(ii_init, i1)
+                        B[io, ii_init] = 0.0
+            for i1_1, i2_1 in T.grid(4, 8):
+                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
+                    T.bind(ii, i1_1)
+                    T.bind(k, ko * 8 + i2_1)
+                    B[io, ii] = B[io, ii] + A[io, ii, k]
 
 
 @T.prim_func
-def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None:
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:  # C = A * B^T (B is read as B[vj, vk])
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
-    C_rf = T.alloc_buffer([4, 128, 128])
-
-    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
-        with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [
-            vi2_inner_inner,
-            vi,
-            vj,
-            vi2_outer,
-            vi2_inner_outer,
-        ]:
-            T.bind(vi2_inner_inner, i2_inner_inner)
-            T.bind(vi, i0)
-            T.bind(vj, i1)
-            T.bind(vi2_outer, i2_outer)
-            T.bind(vi2_inner_outer, i2_inner_outer)
-            with T.init():
-                C_rf[vi2_inner_inner, vi, vj] = 0.0
-            C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
-                A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)]
-                * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)]
-            )
-
-    for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4):
-        with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [
-            vi2_inner_inner_1,
-            vi_1,
-            vj_1,
-        ]:
-            T.bind(vi2_inner_inner_1, i2_inner_inner_1)
-            T.bind(vi_1, i0_1)
-            T.bind(vj_1, i1_1)
-            with T.init():
-                C[vi_1, vj_1] = 0.0
-            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
-
 
-@T.prim_func
-def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, [256, 256])
-    B = T.match_buffer(b, [256, 256])
-    D = T.match_buffer(d, [256, 256])
-    C = T.alloc_buffer([256, 256])
-
-    with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+    with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
         with T.init():
             C[vi, vj] = 0.0
-        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
-
-    with T.block([256, 256], "D") as [vi, vj]:
-        D[vi, vj] = C[vi, vj]
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
 
 
 @T.prim_func
-def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128, 128))
-    C = T.match_buffer(c, (128, 128))
-
-    with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
-        with T.init():
-            C[vi, vj] = 0.0
-        C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj]
-
-
-@T.prim_func
-def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
+def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
-    D = T.match_buffer(d, [128, 128])
-
-    for k, i, j in T.grid(128, 128, 128):
-        with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]:
-            T.bind(ck, k)
-            T.bind(ci, i)
-            T.bind(cj, j)
-            with T.init():
-                C[ci, cj] = 0.0
-            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
-        with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]:
-            T.bind(dk, k)
-            T.bind(di, i)
-            T.bind(dj, j)
-            with T.init():
-                D[di, dj] = 0.0
-            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
 
+    with T.block([128, 128], "init") as [vi, vj]:  # decomposed at i: init precedes the reduction
+        C[vi, vj] = 0.0
 
-@T.prim_func
-def square_sum(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    C = T.match_buffer(c, [16])
-
-    with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
-        with T.init():
-            C[b] = 0.0
-        C[b] = C[b] + A[b, i, j] * A[b, i, j]
+    with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
 
 
 @T.prim_func
-def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    C = T.match_buffer(c, [16])
-    C_rf = T.alloc_buffer([16, 256])
-
-    for i0, i1, i2 in T.grid(16, 256, 256):
-        with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
-            T.bind(vi2, i2)
-            T.bind(b, i0)
-            T.bind(i, i1)
-            with T.init():
-                C_rf[b, vi2] = 0.0
-            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
+def matmul_decompose1(a: T.handle, b: T.handle) -> None:  # NOTE(review): a rowsum, not a matmul
+    A = T.match_buffer(a, [32, 4, 128], elem_offset=0, align=128, offset_factor=1)
+    B = T.match_buffer(b, [32, 4], elem_offset=0, align=128, offset_factor=1)
 
-    for i0_1, i2_1 in T.grid(16, 256):
-        with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
-            T.bind(vi2_1, i2_1)
-            T.bind(b_1, i0_1)
-            with T.init():
-                C[b_1] = 0.0
-            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
-
-
-@T.prim_func
-def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    D = T.match_buffer(d, [16])
-    C = T.alloc_buffer([16])
-
-    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
-        with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
-            T.bind(b, i0)
-            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
-            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
-            T.reads([C[b], A[b, i, j]])
-            T.writes([C[b]])
-            with T.init():
-                C[b] = 0.0
-            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
-    for i0_1 in T.serial(0, 16):
-        with T.block([16], "D") as [b_1]:
-            T.bind(b_1, i0_1)
-            T.reads([C[b_1]])
-            T.writes([D[b_1]])
-            D[b_1] = T.sqrt(C[b_1], dtype="float32")
+    for i0 in T.serial(0, 32):  # rowsum_blockized decomposed at io: init gets its own loop
+        with T.block([32], "blockized_B_init") as [io]:
+            for i1 in T.serial(0, 4):
+                with T.block([4], "B_init") as [ii]:
+                    B[io, ii] = T.float32(0)
+    for i0, i2_o in T.grid(32, 16):
+        with T.block([32, T.reduce_axis(0, 16)], "blockized_B_update") as [io, ko]:
+            for i1, i2_i in T.grid(4, 8):
+                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
+                    T.bind(ii, i1)
+                    T.bind(k, ((ko * 8) + i2_i))
+                    B[io, ii] = B[io, ii] + A[io, ii, k]
 
 
 @T.prim_func
-def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    D = T.match_buffer(d, [16])
-    C = T.alloc_buffer([16])
-    C_rf = T.alloc_buffer([1, 16])
-
-    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
-        with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [
-            vi1_i2_fused_inner,
-            b,
-            i,
-            j,
-        ]:
-            T.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
-            T.bind(b, i0)
-            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
-            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
-            with T.init():
-                C_rf[vi1_i2_fused_inner, b] = 0.0
-            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])
-
-    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
-        with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]:
-            T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
-            T.bind(b_1, i0_1)
-            with T.init():
-                C[b_1] = 0.0
-            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
+def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
 
-    for i0_2 in T.serial(0, 16):
-        with T.block([16], "D") as [b_2]:
-            T.bind(b_2, i0_2)
-            D[b_2] = T.sqrt(C[b_2], dtype="float32")
+    for i0, i1 in T.grid(128, 128):  # decomposed at k: init sits just before the k loop
+        with T.block([128, 128], "update_init") as [vi_init, vj_init]:
+            C[vi_init, vj_init] = T.float32(0)
+        for i2 in T.serial(0, 128):
+            with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [vi, vj, vk]:
+                C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
 
 
 @T.prim_func
-def element_wise(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128, 128))
-
-    with T.block([128, 128], "B") as [vi, vj]:
-        B[vi, vj] = A[vi, vj] * 2.0
-
-
-@T.prim_func
-def rowsum(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi] = 0.0
-        B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    for i, k in T.grid(128, 16):
-        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-            T.bind(vi, i)
-            T.bind(vk, T.floordiv(k * k, 2))
-            with T.init():
-                B[vi] = 0.0
-            B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_dominant(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128, 128))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi, vk] = 0.0
-        B[vi, vk] = B[vi, vk] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    for i in T.serial(0, 128):
-        for k in T.parallel(0, 128):
-            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-                T.bind(vi, i)
-                T.bind(vk, k)
-                with T.init():
-                    B[vi] = 0.0
-                B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi] = 1.0
-        B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi] = 0.0
-        B[vi] = B[vi] - A[vi, vk]
-
-
-@T.prim_func
-def rowsum_transformed(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    for io, ii_ko_fused, ki in T.grid(32, 128, 4):
-        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-            T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32))
-            T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki)
-            with T.init():
-                B[vi] = 0.0
-            B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, [128])
-    B = T.match_buffer(b, [])
-
-    with T.block([T.reduce_axis(0, 128)], "B") as [k]:
-        with T.init():
-            B[()] = 0.0
-        B[()] = B[()] + A[k]
-
-
-@T.prim_func
-def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, [128])
-    B = T.match_buffer(b, [])
-    B_rf = T.alloc_buffer([128])
-
-    with T.block([128], "B_rf") as [vi0]:
-        with T.init():
-            B_rf[vi0] = 0.0
-        B_rf[vi0] = B_rf[vi0] + A[vi0]
-
-    with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]:
-        with T.init():
-            B[()] = 0.0
-        B[()] = B[()] + B_rf[vi0_1]
-
-
-@T.prim_func
-def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16, 16))
-    C = T.alloc_buffer((16, 16))
-    D = T.alloc_buffer((16, 16))
-    E = T.alloc_buffer((16, 16))
-    F = T.match_buffer(f, (16, 16))
-
-    for i in T.serial(0, 16):
-        for j1 in T.serial(0, 16):
-            for k1o, k1i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
-                    T.bind(ci, i)
-                    T.bind(cj, j1)
-                    T.bind(ck, k1o * 4 + k1i)
-                    with T.init():
-                        C[ci, cj] = 0.0
-                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
-            for k2o, k2i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
-                    T.bind(di, i)
-                    T.bind(dj, j1)
-                    T.bind(dk, k2o * 4 + k2i)
-                    with T.init():
-                        D[di, dj] = 0.0
-                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
-        for j2 in T.serial(0, 16):
-            for k3o, k3i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
-                    T.bind(ei, i)
-                    T.bind(ej, j2)
-                    T.bind(ek, k3o * 4 + k3i)
-                    with T.init():
-                        E[ei, ej] = 0.0
-                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
-            for k4o, k4i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
-                    T.bind(fi, i)
-                    T.bind(fj, j2)
-                    T.bind(fk, k4o * 4 + k4i)
-                    with T.init():
-                        F[fi, fj] = 0.0
-                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
-
+def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
 
-@T.prim_func
-def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
-    A = T.match_buffer(a, [16, 16, 16])
-    C = T.alloc_buffer([16, 16])
-    D = T.alloc_buffer([16, 16])
-    E = T.alloc_buffer([16, 16])
-    F = T.match_buffer(f, [16, 16])
-    C_rf = T.alloc_buffer([16, 16, 4])
-
-    for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
-        with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
-            T.bind(vk1o, k1o)
-            T.bind(ci, i)
-            T.bind(cj, j1)
-            T.bind(vk1i, k1i)
+    for i, k, j in T.grid(128, 128, 128):  # reduction loop k encloses data-parallel loop j
+        with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
             with T.init():
-                C_rf[ci, cj, vk1o] = 0.0
-            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
-    for i_1 in T.serial(0, 16):
-        for j1_1 in T.serial(0, 16):
-            for k1o_1 in T.serial(0, 4):
-                with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
-                    T.bind(vk1o_1, k1o_1)
-                    T.bind(ci_1, i_1)
-                    T.bind(cj_1, j1_1)
-                    with T.init():
-                        C[ci_1, cj_1] = 0.0
-                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
-            for k2o, k2i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
-                    T.bind(di, i_1)
-                    T.bind(dj, j1_1)
-                    T.bind(dk, (k2o * 4) + k2i)
-                    with T.init():
-                        D[di, dj] = 0.0
-                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
-        for j2 in T.serial(0, 16):
-            for k3o, k3i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
-                    T.bind(ei, i_1)
-                    T.bind(ej, j2)
-                    T.bind(ek, (k3o * 4) + k3i)
-                    with T.init():
-                        E[ei, ej] = 0.0
-                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
-            for k4o, k4i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
-                    T.bind(fi, i_1)
-                    T.bind(fj, j2)
-                    T.bind(fk, (k4o * 4) + k4i)
-                    with T.init():
-                        F[fi, fj] = 0.0
-                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
+                C[vi, vj] = 0.0
+            T.bind(vi, i)
+            T.bind(vj, j)
+            T.bind(vk, k)
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@T.prim_func
+def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block([], "root"):
+        T.reads([])
+        T.writes([])
+        for i0_0 in T.serial(0, 16):  # i split by [16, 8]; decomposed at the inner part
+            for i0_1_init, i1_init in T.grid(8, 128):
+                with T.block([128, 128], "update_init") as [vi_init, vj_init]:
+                    T.bind(vi_init, ((i0_0 * 8) + i0_1_init))
+                    T.bind(vj_init, i1_init)
+                    C[vi_init, vj_init] = T.float32(0)
+            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):  # k split by [19, 7]
+                with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [
+                    vi,
+                    vj,
+                    vk,
+                ]:
+                    T.where((((i2_0 * 7) + i2_1) < 128))  # 19 * 7 = 133 > 128
+                    T.bind(vi, ((i0_0 * 8) + i0_1))
+                    T.bind(vj, i1)
+                    T.bind(vk, ((i2_0 * 7) + i2_1))
+                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
 
 
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
-def test_reduction_rfactor_matmul():
-    s = tir.Schedule(transformed_matmul, debug_mask="all")
-    update = s.get_block("update")
-    _, _, _, _, kii = s.get_loops(update)
-    rf_block = s.rfactor(kii, 0)
-    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
-    assert s.get(update).same_as(s.get(s.get_block("update")))
-    verify_trace_roundtrip(s, mod=transformed_matmul)
-
-
-def test_reduction_rfactor_square_sum():
-    s = tir.Schedule(square_sum, debug_mask="all")
-    C = s.get_block("C")
-    _, _, j = s.get_loops(C)
-    rf_block = s.rfactor(j, 1)
-    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
-    assert s.get(C).same_as(s.get(s.get_block("C")))
-    verify_trace_roundtrip(s, mod=square_sum)
-
-
-def test_reduction_rfactor_square_sum_square_root():
-    s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all")
-    C = s.get_block("C")
-    _, _, f_i = s.get_loops(C)
-    rf_block = s.rfactor(f_i, 0)
-    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
-    assert s.get(C).same_as(s.get(s.get_block("C")))
-    verify_trace_roundtrip(s, mod=transformed_square_sum_square_root)
-
-
-def test_reduction_rfactor_loop_multiple_children():
-    s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all")
-    k, _, _ = s.get_loops(s.get_block("C"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_stage_pipeline():
-    s = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all")
-    _, _, k = s.get_loops(s.get_block("C"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_reduction_block1():
-    s = tir.Schedule(element_wise, debug_mask="all")
-    i, _ = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(i, 0)
-
-
-def test_reduction_rfactor_not_reduction_block2():
-    s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_reduction_block3():
-    s = tir.Schedule(rowsum_not_dominant, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_serial_loop():
-    s = tir.Schedule(rowsum_not_serial, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_same_buffer_access():
-    s = tir.Schedule(matmul_not_same_buffer_access, debug_mask="all")
-    _, _, k = s.get_loops(s.get_block("C"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_factor_axis_range_fail():
-    s = tir.Schedule(transformed_matmul, debug_mask="all")
-    _, _, _, _, kii = s.get_loops(s.get_block("update"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(kii, 3)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(kii, -4)
-
-
-def test_reduction_rfactor_factor_axis_range():
-    s = tir.Schedule(transformed_matmul, debug_mask="all")
-    update = s.get_block("update")
-    _, _, _, _, kii = s.get_loops(update)
-    rf_block = s.rfactor(kii, -3)
-    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
-    assert s.get(update).same_as(s.get(s.get_block("update")))
-    verify_trace_roundtrip(s, mod=transformed_matmul)
-
-
-def test_reduction_rfactor_wrong_reduce_pattern1():
-    s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
+def test_reduction_decompose0():
+    s = tir.Schedule(matmul, debug_mask="all")
+    C = s.get_block("update")
+    i, _, _ = s.get_loops(C)
+    s.decompose_reduction(C, i)  # decompose at the outermost loop
+    tvm.ir.assert_structural_equal(matmul_decompose0, s.mod["main"])
+    verify_trace_roundtrip(s, mod=matmul)
 
 
-def test_reduction_rfactor_wrong_reduce_pattern2():
-    s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
+def test_reduction_decompose1():
+    s = tir.Schedule(rowsum_blockized, debug_mask="all")
+    blockized_B = s.get_block("blockized_B")
+    io, _ = s.get_loops(blockized_B)
+    s.decompose_reduction(blockized_B, io)  # the init being decomposed is itself a loop nest
+    tvm.ir.assert_structural_equal(matmul_decompose1, s.mod["main"])
+    verify_trace_roundtrip(s, mod=rowsum_blockized)
 
 
-def test_reduction_rfactor_wrong_loops1():
-    s = tir.Schedule(rowsum, debug_mask="all")
-    i, _ = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(i, 0)
+def test_reduction_decompose2():
+    s = tir.Schedule(matmul, debug_mask="all")
+    C = s.get_block("update")
+    _, _, k = s.get_loops(C)
+    s.decompose_reduction(C, k)  # innermost loop: init stays inside loops i and j
+    tvm.ir.assert_structural_equal(matmul_decompose2, s.mod["main"])
+    verify_trace_roundtrip(s, mod=matmul)
 
 
-def test_reduction_rfactor_wrong_loops2():
-    s = tir.Schedule(rowsum_transformed, debug_mask="all")
-    _, _, k_i = s.get_loops(s.get_block("B"))
+def test_reduction_decompose3():
+    s = tir.Schedule(matmul_decompose_fail3, debug_mask="all")
+    C = s.get_block("update")
+    _, _, k = s.get_loops(C)
     with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k_i, 0)
+        s.decompose_reduction(C, k)  # expected to raise: loop k encloses loop j
 
 
-def test_reduction_rfactor_zero_dim():
-    s = tir.Schedule(rowsum_zero_dim, debug_mask="all")
-    B = s.get_block("B")
-    (k,) = s.get_loops(B)
-    rf_block = s.rfactor(k, 0)
-    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("B_rf")))
-    assert s.get(B).same_as(s.get(s.get_block("B")))
-    verify_trace_roundtrip(s, mod=rowsum_zero_dim)
-
-
-def test_reduction_rfactor_outermost_loop_multiple_children_fail():  # pylint: disable=invalid-name
-    s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
-    _, _, k2o, k2i = s.get_loops(s.get_block("D"))
-    _, _, k3o, k3i = s.get_loops(s.get_block("E"))
-    _, _, k4o, k4i = s.get_loops(s.get_block("F"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k2o, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k2i, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k3o, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k3i, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k4o, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k4i, 0)
-
-
-def test_reduction_rfactor_outermost_loop_multiple_children():  # pylint: disable=invalid-name
-    s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
-    C = s.get_block("C")
-    _, _, k1o, _ = s.get_loops(C)
-    rf_block = s.rfactor(k1o, 2)
-    tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
-    assert s.get(C).same_as(s.get(s.get_block("C")))
-    verify_trace_roundtrip(s, mod=multiple_reduction_blocks)
+def test_reduction_decompose4():
+    s = tir.Schedule(matmul, debug_mask="all")
+    C = s.get_block("update")
+    i, _, k = s.get_loops(C)
+    _, ii = s.split(i, factors=[16, 8])
+    s.split(k, factors=[19, 7])  # 19 * 7 = 133 > 128, so a T.where guard is expected
+    s.decompose_reduction(C, ii)
+    tvm.ir.assert_structural_equal(matmul_decompose4, s.mod["main"])
+    verify_trace_roundtrip(s, mod=matmul)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_rfactor.py
similarity index 100%
copy from tests/python/unittest/test_tir_schedule_reduction.py
copy to tests/python/unittest/test_tir_schedule_rfactor.py