You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/13 15:57:09 UTC

[tvm] branch main updated: [TIR][Schedule] Refactor Tensorize (#12070)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d9a07ccc7 [TIR][Schedule] Refactor Tensorize (#12070)
7d9a07ccc7 is described below

commit 7d9a07ccc70eef951bcfff0333c2f82cdc6a3b12
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jul 13 08:57:01 2022 -0700

    [TIR][Schedule] Refactor Tensorize (#12070)
    
    * Refactor blockize
    
    * Refactor tensorize
    
    * Address review comments
    
    * typo
    
    * rename variables according to review
---
 src/tir/schedule/primitive/blockize_tensorize.cc   | 853 ++++++++++-----------
 .../python/unittest/test_tir_schedule_blockize.py  | 322 ++++----
 2 files changed, 580 insertions(+), 595 deletions(-)

diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc
index 4ede2dd90d..9c3029ebf5 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -24,6 +24,20 @@
 namespace tvm {
 namespace tir {
 
+template <class T>
+bool UsesVar(const T& x, const Var& var) {
+  return UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; });
+}
+
+Range RangeFromExtent(const PrimExpr& extent) {
+  return Range::FromMinExtent(make_zero(extent->dtype), extent);
+}
+
+template <class T>
+T DeepCopy(const T& stmt) {
+  return Downcast<T>(LoadJSON(SaveJSON(stmt)));
+}
+
 /*!
  * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace
  * represented by the outer loops.
@@ -64,16 +78,16 @@ class SubspaceNotDivisibleError : public ScheduleError {
  *
  * \param iter_vars The input iterators
  * \param bindings The values of iter_vars
- * \param outer_loops Iterators outside the subspace.
- * \param inner_loops Iterators of the subspace
  * \param predicate The predicate constraint on the input iterators.
+ * \param outer_iters The iters of the outer space
+ * \param inner_iters The iters of the inner space
  * \return The result of the subspace division.
  */
 Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter_vars,
                                                       const Array<PrimExpr>& bindings,
+                                                      const PrimExpr& predicate,
                                                       const Array<Var>& outer_iters,
-                                                      const Array<Var>& inner_iters,
-                                                      const PrimExpr& predicate) {
+                                                      const Array<Var>& inner_iters) {
   if (!is_one(predicate)) return {};
   Array<Array<arith::IterMark>> res;
   std::unordered_set<const VarNode*> outer_loop_vars;
@@ -95,7 +109,7 @@ Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter
   auto use_inner_loop_vars = make_uses_var(inner_iters);
   arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1);
 
-  for (size_t i = 0; i < bindings.size(); ++i) {
+  for (int i = 0, n = bindings.size(); i < n; ++i) {
     bool outer = use_outer_loop_vars(bindings[i]);
     bool inner = use_inner_loop_vars(bindings[i]);
     arith::IterMark iter_mark;
@@ -122,531 +136,462 @@ Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter
 }
 
 /*!
- * \brief Generate the blockized init block.
- * \param block The original block with init.
- * \param inner_block_realize The block realize of the inner block after blockize.
- * \param inner_loops The inner loops after blockize.
- * \return The subtree of the init block and its outer loops.
+ * \brief Subspace division. The space is divided into two subspaces:
+ *  1. The subspace represented by the outer loops above `loop_sref` (exclusive).
+ *  2. The subspace represented by the inner loops below `loop_sref` (inclusive).
+ * \param realize The inner block
+ * \param block_sref The sref to the inner block
+ * \param loop_sref The loop that is the root of the second subspace.
+ * \param loops The loops that represents the second part of the subspace.
+ * \param analyzer The arithmetic analyzer to use.
  */
-Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize,
-                           const std::vector<const ForNode*>& inner_loops) {
-  Array<IterVar> init_block_iters;
-  Array<PrimExpr> init_bindings;
-  const Block& inner_block = inner_block_realize->block;
-
-  // Step 1: Collect data-parallel block iters
-  for (size_t i = 0; i < inner_block->iter_vars.size(); i++) {
-    const IterVar& iter_var = inner_block->iter_vars[i];
-    const PrimExpr& binding = inner_block_realize->iter_values[i];
-    if (iter_var->iter_type == IterVarType::kDataPar &&
-        UsesVar(block->init.value(),
-                [tgt_var = iter_var->var.get()](const VarNode* var) { return var == tgt_var; })) {
-      init_block_iters.push_back(iter_var);
-      init_bindings.push_back(binding);
+Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
+                                             const StmtSRef& block_sref,  //
+                                             const StmtSRef& loop_sref,   //
+                                             std::vector<const ForNode*>* loops,
+                                             arith::Analyzer* analyzer) {
+  Array<Var> inner_vars;
+  Array<Var> outer_vars;
+  Map<Var, Range> loop_var_domain;
+  bool inner = true;
+  for (StmtSRefNode* sref = block_sref->parent;    //
+       sref && sref->stmt->IsInstance<ForNode>();  //
+       sref = sref->parent) {
+    const ForNode* loop = static_cast<const ForNode*>(sref->stmt);
+    if (inner) {
+      loops->push_back(loop);
+      inner_vars.push_back(loop->loop_var);
+    } else {
+      outer_vars.push_back(loop->loop_var);
     }
-  }
-
-  // Step 2: Collect loops related to iters of the init block
-  std::vector<const ForNode*> init_loops;
-  for (const ForNode* inner_loop : inner_loops) {
-    for (const PrimExpr& init_binding : init_bindings) {
-      if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const VarNode* var) {
-            return var == tgt_var;
-          })) {
-        init_loops.push_back(inner_loop);
-        break;
-      }
+    loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+    if (sref == loop_sref.get()) {
+      inner = false;
     }
   }
-
-  // Step 3: Create new block iters for the init block
-  Map<Var, PrimExpr> subst_map;
-  for (size_t i = 0; i < init_block_iters.size(); i++) {
-    IterVar new_iter_var = init_block_iters[i];
-    Var old_var = new_iter_var->var;
-    Var new_var = old_var.copy_with_suffix("_init");
-    new_iter_var.CopyOnWrite()->var = new_var;
-    subst_map.Set(old_var, new_var);
-    init_block_iters.Set(i, std::move(new_iter_var));
-  }
-
-  // Step 4: Generate loop nests and the init block
-  Stmt new_init = BlockRealize(
-      /*iter_values=*/init_bindings,
-      /*predicate=*/inner_block_realize->predicate,
-      /*block=*/
-      Block{/*iter_vars=*/init_block_iters,
-            /*reads=*/{},
-            /*writes=*/block->writes,
-            /*name_hint=*/block->name_hint + "_init",
-            /*body=*/block->init.value(),
-            /*init=*/NullOpt});
-
-  // Step 5: Generate the parent loops for the init block
-  for (const ForNode* init_loop : init_loops) {
-    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*init_loop);
-    new_loop->loop_var = init_loop->loop_var.copy_with_suffix("");
-    subst_map.Set(init_loop->loop_var, new_loop->loop_var);
-    new_loop->body = std::move(new_init);
-    new_init = For(new_loop);
+  Array<Array<arith::IterMark>> result =
+      arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate,
+                            arith::IterMapLevel::Surjective, analyzer);
+  if (!result.empty()) {
+    return result;
   }
-
-  // Step 6: Substitute with new loop variables and block iters to prevent duplication of
-  // variables in the outer block.
-  new_init = Substitute(new_init, subst_map);
-
-  return new_init;
+  return TrivialSubspaceDivision(realize->block->iter_vars,
+                                 realize->iter_values,  //
+                                 realize->predicate,    //
+                                 outer_vars, inner_vars);
 }
 
 /*!
- * \brief A helper to collect the parent loops of the block. The loops are divided into two groups,
- * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the
- * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its
- * successor loops. It is possible that 'outer_loops' is empty.
+ * \brief Derive the block bindings for both inner and outer block
+ * \param iter_vars The original block iterators to the inner block
+ * \param division The subspace division.
+ * \param outer_iter_vars The outer block iterators.
+ * \param outer_bindings The outer block bindings.
+ * \param inner_iter_vars The inner block iterators.
+ * \param inner_bindings The inner block bindings.
+ * \return A substitution plan to the iterators in the original inner block.
  */
-class LoopSubspaceCollector {
- public:
-  /*!
-   * \brief Collect the parent loops of the block and store the result in the corresponding fields.
-   * \param block_sref The sref to the target block.
-   * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop.
-   */
-  void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
-    bool inner = true;
-    for (StmtSRefNode* current_sref = block_sref->parent;
-         current_sref && current_sref->stmt->IsInstance<ForNode>();
-         current_sref = current_sref->parent) {
-      const auto* current_loop = current_sref->StmtAs<ForNode>();
-      ICHECK(current_loop);
-      if (inner) {
-        inner_loops.push_back(current_loop);
-        inner_loop_vars.push_back(current_loop->loop_var);
-      } else {
-        outer_loops.push_back(current_loop);
-        outer_loop_vars.push_back(current_loop->loop_var);
-      }
-      loop_var_domain.Set(current_loop->loop_var,
-                          Range::FromMinExtent(current_loop->min, current_loop->extent));
-      if (current_sref == loop_sref.get()) inner = false;
+Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars,                //
+                                      const Array<Array<arith::IterMark>>& division,  //
+                                      Array<IterVar>* outer_iter_vars,                //
+                                      Array<PrimExpr>* outer_bindings,                //
+                                      Array<IterVar>* inner_iter_vars,                //
+                                      Array<PrimExpr>* inner_bindings) {
+  using arith::IterMapExpr;
+  using arith::IterMapExprNode;
+  using arith::NormalizeIterMapToExpr;
+  Map<Var, PrimExpr> block_var_subst;
+  ICHECK_EQ(iter_vars.size() + 1, division.size());
+  for (int i = 0, n = iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = iter_vars[i];
+    arith::IterMark outer_mark = division[i][0];
+    arith::IterMark inner_mark = division[i][1];
+    IterMapExpr outer_binding = Downcast<IterMapExpr>(outer_mark->source);
+    IterMapExpr inner_binding = Downcast<IterMapExpr>(inner_mark->source);
+    // After computing the subspace division, bindings[i] can be written as
+    // outer_binding * inner_binding->extent + inner_binding
+    // The outer block will have binding: iter_outer -> outer_binding
+    // The inner block will have binding: iter_inner -> inner_binding
+    // The iter in the original block will be substituted with base + iter_inner where
+    // base == iter_outer * iter_inner_extent
+    if (is_one(inner_mark->extent)) {  // IsOuter
+      // extract this iter var to outer block directly
+      outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+      outer_iter_vars->push_back(iter_var);
+      continue;
     }
+    // create iter var for the outer block
+    IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
+                       /*var=*/iter_var->var.copy_with_suffix("_o"),
+                       /*iter_type=*/iter_var->iter_type);
+    outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+    outer_iter_vars->push_back(outer_iter);
+    // create iter var for the inner block
+    IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
+                       /*var=*/iter_var->var.copy_with_suffix("_i"),
+                       /*iter_type=*/iter_var->iter_type);
+    inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
+    inner_iter_vars->push_back(inner_iter);
+    // substitution
+    PrimExpr sub{nullptr};
+    if (is_one(outer_mark->extent)) {
+      sub = inner_iter->var;
+    } else {
+      sub = outer_iter * inner_mark->extent + inner_iter->var;
+    }
+    block_var_subst.Set(iter_var->var, sub);
   }
-  /*! \brief Outer loops which are ancestors of the separator. */
-  std::vector<const ForNode*> outer_loops;
-  /*! \brief Inner loops which are the separator itself or its successors. */
-  std::vector<const ForNode*> inner_loops;
-  /*! \brief Loop variables of the outer loops. */
-  Array<Var> outer_loop_vars;
-  /*! \brief Loop variables of the inner loops. */
-  Array<Var> inner_loop_vars;
-  /*! \brief Domain of the loop variables. */
-  Map<Var, Range> loop_var_domain;
-};
+  return block_var_subst;
+}
 
 /*!
- * \brief Check the bindings of the block iters can be divided by a subspace collected by the
- * collector.
- * \param mod The current IR module.
- * \param block_realize The block realize to be checked.
- * \param collector The collector which has collected the loops of the block.
- * \param analyzer The arithmetic analyzer.
- * \return The result of the subspace division.
- * \throws ScheduleError If the bindings are not divisible by the subspace.
+ * \brief Generate the inner block for blockization
+ * \param is_write_reduction Whether the write regions of the inner block are actually reduction.
+ * \param iter_vars IterVars used in the inner block.
+ * \param iter_values IterVar bindings used in the inner block.
+ * \param predicate The predicate of the inner block.
+ * \param block The inner block as a template to be created from. This method will modify its
+ * `iter_vars`, `init` and `reads` fields.
+ * \return The inner block created.
  */
-Array<Array<arith::IterMark>> CheckSubspaceDivisible(const IRModule& mod,
-                                                     const BlockRealize& block_realize,
-                                                     const LoopSubspaceCollector& collector,
-                                                     arith::Analyzer* analyzer) {
-  const Block& block = block_realize->block;
-
-  Array<Array<arith::IterMark>> division = arith::SubspaceDivide(
-      block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars,
-      block_realize->predicate, arith::IterMapLevel::Surjective, analyzer);
-
-  if (division.empty()) {
-    // If we can't do perfect subspace division, check if it is a trivial case of subspace division.
-    // In this case, we can still blockize.
-    division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values,
-                                       collector.outer_loop_vars, collector.inner_loop_vars,
-                                       block_realize->predicate);
-  }
-  if (division.empty()) {
-    throw SubspaceNotDivisibleError(mod, GetRef<For>(collector.inner_loops.back()), block);
+BlockRealize GenerateInner(bool is_write_reduction,
+                           const Array<IterVar>& iter_vars,     //
+                           const Array<PrimExpr>& iter_values,  //
+                           const PrimExpr& predicate,           //
+                           Block block) {
+  BlockNode* n = block.CopyOnWrite();
+  n->iter_vars = iter_vars;
+  n->init = NullOpt;
+  if (is_write_reduction) {
+    Array<BufferRegion> reads;
+    reads.reserve(block->writes.size() + block->reads.size());
+    reads.insert(reads.end(), block->writes.begin(), block->writes.end());
+    reads.insert(reads.end(), block->reads.begin(), block->reads.end());
+    n->reads = std::move(reads);
   }
-  return division;
+  return BlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate,
+                      /*block=*/block);
 }
 
 /*!
- * \brief The binding extractor to compute the bindings of the outer and the inner blocks after
- * blockize.
+ * \brief Generate the init stmt for the outer block
+ * \param block_init The original init stmt, used as the body of the generated init block.
+ * \param inner_realize The block realize of the inner block after blockize.
+ * \param loops The inner loops after blockize.
+ * \return The subtree of the init block and its outer loops.
  */
-class BlockizedBindingExtractor {
- public:
-  /*!
-   * \brief Extract bindings for blockize.
-   * \param iter_vars The iter vars of the original inner block.
-   * \param division The result of the subspace division.
-   */
-  void ExtractBindings(const Array<IterVar>& iter_vars,
-                       const Array<Array<arith::IterMark>>& division, arith::Analyzer* analyzer) {
-    ICHECK_EQ(iter_vars.size() + 1, division.size());
-    for (size_t i = 0; i < iter_vars.size(); ++i) {
-      const IterVar& iter_var = iter_vars[i];
-      arith::IterMark outer_mark = division[i][0];
-      arith::IterMark inner_mark = division[i][1];
-      const auto* outer_binding =
-          TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode);
-      const auto* inner_binding =
-          TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode);
-
-      // After computing the subspace division, bindings[i] can be written as
-      // outer_binding * inner_binding->extent + inner_binding
-      // The outer block will have binding: iter_outer -> outer_binding
-      // The inner block will have binding: iter_inner -> inner_binding
-      // The iter in the original block will be substituted with base + iter_inner where
-      // base == iter_outer * iter_inner_extent
-
-      if (is_one(division[i][1]->extent)) {  // IsOuter
-        // extract this iter var to outer block directly
-        outer_bindings.push_back(
-            arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
-        outer_iter_vars.push_back(iter_var);
-      } else {
-        // create iter var for the outer block
-        const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, division[i][0]->extent),
-                                /*var=*/iter_var->var.copy_with_suffix("_o"),
-                                /*iter_type=*/iter_var->iter_type);
-        outer_bindings.push_back(
-            arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
-        outer_iter_vars.push_back(outer_var);
-        PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent;
-        // create iter var for the inner block
-        IterVar new_iter(Range::FromMinExtent(0, division[i][1]->extent), Var(iter_var->var),
-                         iter_var->iter_type, iter_var->thread_tag, iter_var->span);
-        inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom));
-        analyzer->Bind(new_iter->var, new_iter->dom);
-        inner_iter_vars.push_back(new_iter);
-        inner_bindings.push_back(
-            arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(inner_binding)));
-        inner_iter_subst_map.Set(iter_var->var, base + new_iter->var);
+Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize,
+                       const std::vector<const ForNode*>& loops, String block_name) {
+  const Block& inner_block = inner_realize->block;
+  Map<Var, PrimExpr> subst_map;
+  // Step 1: Create new block vars for the block inside the init stmt of outer block
+  // An iter is used in the block if
+  // 1) It is data parallel
+  // 2) It is used in the original init block
+  Array<IterVar> iter_vars;
+  Array<PrimExpr> iter_values;
+  ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size());
+  int n = inner_block->iter_vars.size();
+  iter_vars.reserve(n);
+  iter_values.reserve(n);
+  for (int i = 0; i < n; ++i) {
+    const IterVar& old_iter_var = inner_block->iter_vars[i];
+    const PrimExpr& iter_value = inner_realize->iter_values[i];
+    if (old_iter_var->iter_type == IterVarType::kDataPar &&
+        UsesVar(block_init, old_iter_var->var)) {
+      ObjectPtr<IterVarNode> new_iter_var = make_object<IterVarNode>(*old_iter_var.get());
+      new_iter_var->var = new_iter_var->var.copy_with_suffix("_init");
+      subst_map.Set(old_iter_var->var, new_iter_var->var);
+      iter_vars.push_back(IterVar(new_iter_var));
+      iter_values.push_back(iter_value);
+    }
+  }
+  // Step 2: Generate the block inside init stmt of outer block
+  Stmt stmt = BlockRealize(
+      /*iter_values=*/iter_values,
+      /*predicate=*/inner_realize->predicate,
+      /*block=*/
+      Block(/*iter_vars=*/iter_vars,
+            /*reads=*/{},
+            /*writes=*/inner_block->writes,
+            /*name_hint=*/block_name,
+            /*body=*/block_init,
+            /*init=*/NullOpt));
+  // Step 3. Create the loop nest on top of the block
+  for (const ForNode* loop : loops) {
+    bool is_init_loop = false;
+    for (const PrimExpr& init_binding : iter_values) {
+      if (UsesVar(init_binding, loop->loop_var)) {
+        is_init_loop = true;
+        break;
       }
     }
+    if (is_init_loop) {
+      ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
+      new_loop->loop_var = loop->loop_var.copy_with_suffix("");
+      new_loop->body = std::move(stmt);
+      subst_map.Set(loop->loop_var, new_loop->loop_var);
+      stmt = For(new_loop);
+    }
   }
-  Map<Var, PrimExpr> inner_iter_subst_map;
-  /*! \brief Iters of the outer block. */
-  Array<IterVar> outer_iter_vars;
-  /*! \brief Iters of the outer block. */
-  Array<IterVar> inner_iter_vars;
-  /*! \brief Binding values of the outer block. */
-  Array<PrimExpr> outer_bindings;
-  /*! \brief Binding values of the inner block. */
-  Array<PrimExpr> inner_bindings;
-  /*! \brief The domain of the inner block iters. */
-  Map<Var, arith::IntSet> inner_iter_dom_map;
-};
+  // Step 4: Substitute the iter vars and loop vars
+  return Substitute(stmt, subst_map);
+}
 
 /*!
- * \brief Replacer for the inner block after blockize. Inner block iters will be replaced with
- * base + inner_iter and the expressions after substituion will be simplified if possible.
+ * \brief Substitute variables in the stmt, do simplification and track block substitution
+ * \param stmt The stmt to be substituted.
+ * \param sub The substitution map.
+ * \param block_sref_reuse Output map recording each block replaced during the substitution.
+ * \param analyzer The analyzer for arithmetic simplification.
+ * \return The substituted stmt.
  */
-class InnerIterReplacer : public StmtExprMutator {
- public:
-  /*!
-   * \brief The constructor
-   * \param subst_map The substitution map of the inner block iters.
-   * \param analyzer The arithmetic analyzer.
-   * \param block_sref_reuse The map to save the block reuse information.
-   */
-  InnerIterReplacer(Map<Var, PrimExpr> subst_map, arith::Analyzer* analyzer,
-                    Map<Block, Block>* block_sref_reuse)
-      : subst_map_(std::move(subst_map)),
-        analyzer_(analyzer),
-        block_sref_reuse_(block_sref_reuse) {}
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    auto it = subst_map_.find(GetRef<Var>(op));
-    if (it != subst_map_.end()) {
-      return (*it).second;
+Stmt Substitute(const Stmt& stmt, const Map<Var, PrimExpr>& sub,
+                Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
+  struct Replacer : public StmtExprMutator {
+    explicit Replacer(const Map<Var, PrimExpr>& sub, Map<Block, Block>* block_sref_reuse,
+                      arith::Analyzer* analyzer)
+        : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {}
+
+    PrimExpr VisitExpr(const PrimExpr& op) final {
+      PrimExpr result = StmtExprMutator::VisitExpr(op);
+      if (!result.same_as(op)) {
+        return analyzer_->Simplify(result);
+      }
+      return result;
     }
-    return StmtExprMutator::VisitExpr_(op);
-  }
 
-  PrimExpr VisitExpr(const PrimExpr& op) final {
-    PrimExpr result = StmtExprMutator::VisitExpr(op);
-    if (!result.same_as(op)) {
-      return analyzer_->Simplify(result);
+    PrimExpr VisitExpr_(const VarNode* op) final {
+      if (Optional<PrimExpr> e = sub_.Get(GetRef<Var>(op))) {
+        return e.value();
+      }
+      return StmtExprMutator::VisitExpr_(op);
     }
-    return result;
-  }
 
-  Stmt VisitStmt_(const BlockNode* op) final {
-    Stmt result = StmtExprMutator::VisitStmt_(op);
-    if (!result.same_as(GetRef<Stmt>(op))) {
-      block_sref_reuse_->Set(GetRef<Block>(op), Downcast<Block>(result));
+    Stmt VisitStmt_(const BlockNode* op) final {
+      Block src = GetRef<Block>(op);
+      Block tgt = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+      if (!src.same_as(tgt)) {
+        block_sref_reuse_->Set(src, tgt);
+      }
+      return tgt;
     }
-    return result;
-  }
 
- private:
-  Map<Var, PrimExpr> subst_map_;
-  arith::Analyzer* analyzer_;
-  Map<Block, Block>* block_sref_reuse_;
-};
+    const Map<Var, PrimExpr>& sub_;
+    Map<Block, Block>* block_sref_reuse_;
+    arith::Analyzer* analyzer_;
+  };
+  return Replacer(sub, block_sref_reuse, analyzer)(stmt);
+}
 
 /*!
- * \brief Compute the access region of the outer block by relaxing the inner loops.
- * \param buffer_region The original buffer region.
- * \param The range of the inner loops.
- * \return The new buffer region.
+ * \brief Relax the variables for the given regions
+ * \param regions The regions to be relaxed.
+ * \param dom_map The variables to be relaxed
+ * \return The relaxed regions
  */
-BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region,
-                                      const Map<Var, arith::IntSet>& inner_iter_relaxed_range) {
-  Array<Range> new_region;
-  new_region.reserve(buffer_region->region.size());
-  Array<arith::IntSet> relaxed_int_set =
-      arith::EvalSet(buffer_region->region, inner_iter_relaxed_range);
-  ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size());
-  for (size_t i = 0; i < buffer_region->region.size(); i++) {
-    Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]);
-    new_region.push_back(relaxed_int_set[i].CoverRange(max_range));
+Array<BufferRegion> EvalSetRegions(const Array<BufferRegion>& regions,
+                                   const Map<Var, arith::IntSet>& dom_map) {
+  Array<BufferRegion> results;
+  results.reserve(regions.size());
+  for (const BufferRegion& buffer_region : regions) {
+    const Buffer& buffer = buffer_region->buffer;
+    Array<arith::IntSet> relaxed = arith::EvalSet(buffer_region->region, dom_map);
+    ICHECK_EQ(relaxed.size(), buffer->shape.size());
+    int ndim = buffer->shape.size();
+    Array<Range> new_region;
+    new_region.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i])));
+    }
+    results.push_back(BufferRegion(buffer, new_region));
   }
-  return BufferRegion(buffer_region->buffer, std::move(new_region));
+  return results;
 }
 
 /*!
- * \brief Generate the outer block after blockize.
- * \param extractor The binding extractor which has extracted the blockized bindings.
- * \param block The original inner block.
- * \param inner_block_realize The block realize of the inner block after blockize.
- * \param inner_loops The inner loops after blockize.
- * \param predicate The outer predicate of the subspace division.
- * \return The block realize of the outer block after blockize.
+ * \brief Create the loop nest on top of the given stmt.
+ * \param stmt The stmt to be wrapped.
+ * \param loops The loop nests
+ * \return The wrapped stmt.
  */
-BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor,
-                                         const Block& block, BlockRealize inner_block_realize,
-                                         const std::vector<const ForNode*>& inner_loops,
-                                         PrimExpr predicate) {
-  // Step 1: Generate the init block if needed
-  Optional<Stmt> new_init = NullOpt;
-  if (block->init.defined()) {
-    new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops);
-  }
-
-  // Step 2: Compute the access regions of the outer block by relaxing the inner loops
-  Array<BufferRegion> new_reads = block->reads;
-  Array<BufferRegion> new_writes = block->writes;
-
-  auto f_mutate = [&](const BufferRegion& buffer_region) {
-    return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_dom_map);
-  };
-  new_reads.MutateByApply(f_mutate);
-  new_writes.MutateByApply(f_mutate);
-
-  // Step 3: Generate the body of the outer block. The body of the outer block is the inner block
-  // realize and its surrounding loops.
-  Stmt outer_block_body = inner_block_realize;
-  for (const ForNode* loop : inner_loops) {
+Stmt MakeLoopNest(Stmt stmt, const std::vector<const ForNode*>& loops) {
+  for (const ForNode* loop : loops) {
     ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
-    new_loop->body = std::move(outer_block_body);
-    outer_block_body = For(new_loop);
+    new_loop->body = std::move(stmt);
+    stmt = For(new_loop);
   }
-
-  // Step 4: Generate the outer block and block realize.
-  return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings),
-                      /*predicate=*/std::move(predicate),
-                      /*block=*/
-                      Block(/*iter_vars=*/std::move(extractor.outer_iter_vars),  //
-                            /*reads=*/std::move(new_reads),                      //
-                            /*writes=*/std::move(new_writes),                    //
-                            /*name_hint=*/block->name_hint + "_o",               //
-                            /*body=*/std::move(outer_block_body),                //
-                            /*init=*/std::move(new_init)));
+  return stmt;
 }
 
-StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
+                          Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
-  arith::Analyzer analyzer;
-
-  // Step 1: Check the loop has a single child BlockRealize on the sref tree.
+  // Step 1: Check and get the only block under `loop`.
   BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref);
   Block block = block_realize->block;
   StmtSRef block_sref = self->stmt2ref.at(block.get());
-
-  // Step 2: Collect loops inside and outside loop_sref.
-  LoopSubspaceCollector collector;
-  collector.Collect(block_sref, loop_sref);
-
-  // Step 3: Calculate subspace division for the inner loops.
+  // Step 2: Derive subspace division
+  std::vector<const ForNode*> loops;
   Array<Array<arith::IterMark>> division =
-      CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer);
-
-  // Step 4: Generate bindings for the outer block and the inner block based on the result of
-  // the subspace division.
-  BlockizedBindingExtractor extractor;
-  extractor.ExtractBindings(block->iter_vars, division, &analyzer);
-  const PrimExpr& outer_pred = division.back()[0]->extent;
-  const PrimExpr& inner_pred = division.back()[1]->extent;
-
-  // Step 5: Substitute the iter vars in the original block with the inner iters after the subspace
-  // division
-  Map<Block, Block> block_sref_reuse;
-  InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer,
-                             &block_sref_reuse);
-  Block new_block = Downcast<Block>(replacer(block));
-
-  // Step 6: Generate the inner block.
-  bool outer_reduction = false;  // whether there are outer reduction iter vars.
-  for (const IterVar& iter_var : extractor.outer_iter_vars) {
-    if (iter_var->iter_type == kCommReduce) {
-      outer_reduction = true;
-    }
+      SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer);
+  if (division.empty()) {
+    throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()), block);
   }
-  BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite();
-  inner_block_realize->iter_values = extractor.inner_bindings;
-  inner_block_realize->predicate = inner_pred;
-  inner_block_realize->block = new_block;
-  BlockNode* inner_block = inner_block_realize->block.CopyOnWrite();
-  inner_block->iter_vars = extractor.inner_iter_vars;
-  inner_block->init = NullOpt;
-  /* Add write regions to read regions if
-   * 1. there are outer reduction iter vars.
-   * 2. the init block is defined for current block.
-   */
-  if (outer_reduction && block->init.defined()) {
-    Array<BufferRegion> new_reads;
-    for (const BufferRegion& write_access : inner_block->writes) {
-      new_reads.push_back(write_access);
-    }
-    for (const BufferRegion& read_access : inner_block->reads) {
-      new_reads.push_back(read_access);
+  PrimExpr outer_predicate = division.back()[0]->extent;
+  PrimExpr inner_predicate = division.back()[1]->extent;
+  // Step 3: Derive block bindings for both the outer and the inner block.
+  Array<IterVar> outer_iter_vars;
+  Array<IterVar> inner_iter_vars;
+  Array<PrimExpr> outer_bindings;
+  Array<PrimExpr> inner_bindings;
+  Map<Var, PrimExpr> block_var_subst =                       //
+      DeriveBlockBinding(block->iter_vars, division,         //
+                         &outer_iter_vars, &outer_bindings,  //
+                         &inner_iter_vars, &inner_bindings);
+  // Step 4: Do var substitution to adjust to the new block bindings
+  Map<Var, arith::IntSet> inner_iter_dom;
+  for (const IterVar& iter : inner_iter_vars) {
+    inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom));
+    analyzer->Bind(iter->var, iter->dom);
+  }
+  Block block_subst =
+      Downcast<Block>(Substitute(block, block_var_subst, block_sref_reuse, analyzer));
+  // Step 5: Generate the inner block. Its write regions are marked as reduction only if both:
+  // 1. The original block has init stmt.
+  // 2. There are outer reduction iter vars.
+  bool has_outer_reduction = false;
+  if (block_subst->init.defined()) {
+    for (const IterVar& iter_var : outer_iter_vars) {
+      if (iter_var->iter_type == kCommReduce) {
+        has_outer_reduction = true;
+        break;
+      }
     }
-    inner_block->reads = std::move(new_reads);
   }
-  block_sref_reuse.Set(block, inner_block_realize->block);
-
+  BlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction,
+                                             /*iter_vars=*/inner_iter_vars,
+                                             /*iter_values=*/inner_bindings,
+                                             /*predicate=*/inner_predicate,
+                                             /*block=*/block_subst);
+  block_sref_reuse->Set(block, inner_realize->block);
   // Step 6: Generate the outer block.
-  BlockRealize outer_realize =
-      GenerateBlockizedOuterBlock(extractor, new_block, GetRef<BlockRealize>(inner_block_realize),
-                                  collector.inner_loops, outer_pred);
-  // Step 7: Do the actual replacement
-  self->Replace(loop_sref, outer_realize, block_sref_reuse);
-
-  // Step 8: Update the cached flags
-  StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get());
-  StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false);
+  return BlockRealize(
+      /*iter_values=*/std::move(outer_bindings),
+      /*predicate=*/std::move(outer_predicate),
+      /*block=*/
+      Block(/*iter_vars=*/std::move(outer_iter_vars),
+            /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom),
+            /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom),
+            /*name_hint=*/block_subst->name_hint + "_o",
+            /*body=*/MakeLoopNest(inner_realize, loops),
+            /*init=*/
+            block_subst->init.defined()  //
+                ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops,
+                                    block_subst->name_hint + "_init")
+                : Optional<Stmt>(NullOpt)));
+}
+
+StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+  arith::Analyzer analyzer;
+  Map<Block, Block> block_sref_reuse;
+  BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer);
+  self->Replace(loop_sref, blockized, block_sref_reuse);
+  StmtSRef result = self->stmt2ref.at(blockized->block.get());
+  StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
   bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
   self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
   self->block_info[scope_root].affine_binding = scope_block_affine_binding;
-  return outer_block_sref;
-}
-
-/*!
- * \brief Update the map from the buffers in the desc to the impl of the tensor
- * intrinsic.
- * \param intrinsic The tensor intrinsic.
- * \param buffer_map The map to be updated.
- */
-void RemapTensorIntrinBuffers(
-    const TensorIntrin& intrinsic,
-    std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_map) {
-  ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size());
-  for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) {
-    const Var& lhs_var = intrinsic->desc->params[i];
-    const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var];
-    const Var& rhs_var = intrinsic->impl->params[i];
-    const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var];
-    (*buffer_map)[rhs_buffer] = lhs_buffer;
-  }
+  return result;
 }
 
-void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
-               const TensorIntrin& intrinsic) {
-  /*!
-   * Check:
-   *   - Check buffer binding, including type, alignment, shape and etc.
-   *   - Check the sub AST is equal to the desc function.
-   *
-   * Mutate:
-   *   - Blockize the sub AST (please refer blockize for details)
-   *   - Bind buffers
-   *   - Mutate the impl of the tensor intrinsic by replacing its buffers with new
-   *     buffers created via match buffer region.
-   *   - Replace the sub tree with the mutated function.
-   */
-  const BlockRealize& desc_block_realize = Downcast<BlockRealize>(intrinsic->desc->body);
-  const BlockRealize& impl_block_realize = Downcast<BlockRealize>(intrinsic->impl->body);
-  Block impl_block = impl_block_realize->block;
-
+void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin) {
   // Step 1: Blockize the subtree rooted at the given loop if needed
-  StmtSRef block_sref{nullptr};
-  if (block_or_loop_sref->StmtAs<ForNode>()) {
-    block_sref = Blockize(self, block_or_loop_sref);
+  BlockRealize block_realize{nullptr};
+  Optional<Block> old_block = NullOpt;
+  if (sref->stmt->IsInstance<BlockNode>()) {
+    block_realize = GetBlockRealize(self, sref);
+    old_block = block_realize->block;
+  } else if (sref->stmt->IsInstance<ForNode>()) {
+    arith::Analyzer analyzer;
+    Map<Block, Block> block_sref_reuse;
+    block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer);
   } else {
-    ICHECK(block_or_loop_sref->StmtAs<BlockNode>());
-    block_sref = block_or_loop_sref;
+    LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: "
+               << GetRef<Stmt>(sref->stmt);
+    throw;
   }
-  const BlockRealize& block_realize = GetBlockRealize(self, block_sref);
-
-  // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence
-  // between buffers in the block and the desc.
+  PrimFunc intrin_desc = intrin->desc;
+  PrimFunc intrin_impl = DeepCopy(intrin->impl);
+  // Step 2: Structural pattern matching
   TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
-  comparator.VisitStmt(block_realize, desc_block_realize);
-
-  // Step 3: Find the correspondence between buffers in the current AST and the impl of
-  // the tensor intrinsic
-  // Step 3.1: Map from intrinsic func buffer to desc func buffer
-  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> intrin_buffer_map;
-  RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map);
-  // Step 3.2: Map form intrinsic func buffer to current AST buffer
-  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map;
-  for (const auto& pair : intrin_buffer_map) {
-    auto it = comparator.rhs_buffer_map_.find(pair.second);
-    ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second;
-    buffer_map[pair.first] = it->second;
+  comparator.VisitStmt(block_realize, intrin_desc->body);
+  // Step 3: Prepare necessary mapping
+  // 1) Buffer mapping from intrin impl buffers to intrin desc buffers.
+  // 2) Buffer mapping from intrin impl buffers to buffers in the current AST.
+  // 3) Mapping impl buffers to their accessed regions.
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2desc;
+  ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size());
+  for (int i = 0, n = intrin_desc->params.size(); i < n; ++i) {
+    const Buffer& desc = intrin_desc->buffer_map[intrin_desc->params[i]];
+    const Buffer& impl = intrin_impl->buffer_map[intrin_impl->params[i]];
+    impl2desc[impl] = desc;
   }
-
-  // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor
-  // intrin to make them subregions of the buffer in the original IR.
-  std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> buffer_region_map;
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2cur;
+  for (const auto& pair : impl2desc) {
+    const Buffer& impl = pair.first;
+    const Buffer& desc = pair.second;
+    ICHECK(comparator.rhs_buffer_map_.count(desc));
+    impl2cur[impl] = comparator.rhs_buffer_map_[desc];
+  }
+  std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> impl2region;
+  Block impl_block = Downcast<BlockRealize>(intrin_impl->body)->block;
   for (const BufferRegion& read : impl_block->reads) {
-    buffer_region_map.emplace(read->buffer, read->region);
+    impl2region.emplace(read->buffer, read->region);
   }
   for (const BufferRegion& write : impl_block->writes) {
-    buffer_region_map.emplace(write->buffer, write->region);
+    impl2region.emplace(write->buffer, write->region);
   }
+  // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor
+  // intrin to make them subregions of the buffer in the original IR.
   Array<MatchBufferRegion> match_buffer_regions;
-  match_buffer_regions.reserve(intrinsic->impl->params.size());
-  for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) {
-    const auto& param = intrinsic->impl->params[i];
-    const auto& buffer = intrinsic->impl->buffer_map.at(param);
-    const auto& source = buffer_map.at(buffer);
-    // add the detected base indices to each buffer access region of the tensor intrinsic
-    Region old_region = buffer_region_map.at(buffer);
-    const auto& indices_base = comparator.buffer_indices_.at(source);
+  match_buffer_regions.reserve(intrin_impl->params.size());
+  for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) {
+    const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]);
+    const Buffer& cur = impl2cur.at(impl);
+    const Array<Range>& old_region = impl2region.at(impl);
+    const std::vector<PrimExpr>& indices_base = comparator.buffer_indices_.at(cur);
     int offset = static_cast<int>(indices_base.size()) - static_cast<int>(old_region.size());
     ICHECK(offset >= 0);
-    Region new_region;
-    new_region.reserve(source->shape.size());
+    Array<Range> new_region;
+    new_region.reserve(cur->shape.size());
     for (int i = 0; i < offset; i++) {
-      new_region.push_back(Range::FromMinExtent(indices_base[i], 1));
+      PrimExpr min = indices_base[i];
+      PrimExpr extent = make_const(min.dtype(), 1);
+      new_region.push_back(Range::FromMinExtent(min, extent));
     }
     for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
-      new_region.push_back(Range::FromMinExtent(indices_base[i + offset], old_region[i]->extent));
+      PrimExpr min = indices_base[i + offset];
+      PrimExpr extent = old_region[i]->extent;
+      new_region.push_back(Range::FromMinExtent(min, extent));
     }
-    match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, new_region)));
+    match_buffer_regions.push_back(MatchBufferRegion(impl, BufferRegion(cur, new_region)));
   }
-
   // Step 5: Replace the subtree in the original IR with the tensor intrin impl.
-  ObjectPtr<BlockNode> new_block_ptr = make_object<BlockNode>(*block_realize->block.get());
-  new_block_ptr->body = impl_block->body;
-  ICHECK(new_block_ptr->match_buffers.empty());
-  new_block_ptr->match_buffers = std::move(match_buffer_regions);
-  Block new_block(new_block_ptr);
-
-  self->Replace(block_sref, new_block, {{block_realize->block, new_block}});
-
+  {
+    BlockNode* block = block_realize.CopyOnWrite()->block.CopyOnWrite();
+    block->body = impl_block->body;
+    block->match_buffers = std::move(match_buffer_regions);
+  }
+  if (old_block.defined()) {
+    self->Replace(sref, block_realize->block, {{old_block.value(), block_realize->block}});
+  } else {
+    self->Replace(sref, block_realize, {});
+  }
   // Step 6: Update the cached flags.
-  StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
-  self->UpdateScopeBlockInfo(static_cast<const BlockNode*>(scope_root->stmt)->body);
+  StmtSRef result = self->stmt2ref.at(block_realize->block.get());
+  StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
+  self->UpdateScopeBlockInfo(scope_root->StmtAs<BlockNode>()->body);
 }
 
 /******** InstructionKind Registration ********/
diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py
index 481421cfdf..6d13281320 100644
--- a/tests/python/unittest/test_tir_schedule_blockize.py
+++ b/tests/python/unittest/test_tir_schedule_blockize.py
@@ -15,12 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
-import sys
-import pytest
 import tvm
 import tvm.testing
-from tvm.script import tir as T
 from tvm import tir
+from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
 
 # fmt: off
@@ -33,177 +31,219 @@ def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128
             vi, vj = T.axis.remap("SS", [i, j])
             B[vi, vj] = A[vi, vj] * 2.0
 
-
-@T.prim_func
-def single_elementwise_blockized1(
-    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
-) -> None:
-    with T.block("blockized_B"):
-        vio = T.axis.spatial(1, 0)
-        vjo = T.axis.spatial(1, 0)
-        T.reads(A[0:128, 0:128])
-        T.writes(B[0:128, 0:128])
-        for i, j in T.grid(128, 128):
-            with T.block("B"):
-                vi, vj = T.axis.remap("SS", [i, j])
-                T.reads(A[vi, vj])
-                T.writes(B[vi, vj])
-                B[vi, vj] = A[vi, vj] * T.float32(2)
+# fmt: on
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
 
 
-@T.prim_func
-def single_elementwise_blockized2(
-    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
-) -> None:
-    for i in T.serial(128):
+def test_blockize_outer():
+    @T.prim_func
+    def after_blockize_outer(
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer[(128, 128), "float32"],
+    ) -> None:
         with T.block("blockized_B"):
-            vi = T.axis.spatial(128, i)
+            vio = T.axis.spatial(1, 0)
             vjo = T.axis.spatial(1, 0)
-            T.reads(A[vi, 0:128])
-            T.writes(B[vi, 0:128])
-            for j in T.serial(128):
-                with T.block("B"):
-                    vj = T.axis.remap("S", [j])
-                    T.reads(A[vi, vj])
-                    T.writes(B[vi, vj])
-                    B[vi, vj] = A[vi, vj] * T.float32(2)
-
-
-@T.prim_func
-def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
-    B = T.alloc_buffer([128, 128], dtype="float32")
-    for i, j in T.grid(128, 128):
-        with T.block("B"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(A[vi, vj])
-            T.writes(B[vi, vj])
-            B[vi, vj] = A[vi, vj] * T.float32(2)
-    for i, j in T.grid(128, 128):
-        with T.block("C"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(B[vi, vj])
-            T.writes(C[vi, vj])
-            C[vi, vj] = B[vi, vj] + T.float32(1)
-
-
-@T.prim_func
-def two_elementwise_blockized(
-    A: T.Buffer[(128, 128), "float32"],
-    C: T.Buffer[(128, 128), "float32"]
-) -> None:
-    B = T.alloc_buffer([128, 128], dtype="float32")
-    for i_0, j_0 in T.grid(8, 8):
-        with T.block("blockized_B"):
-            vio, vjo = T.axis.remap("SS", [i_0, j_0])
-            T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
-            T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
-            for i_1, j_1 in T.grid(16, 16):
+            for i, j in T.grid(128, 128):
                 with T.block("B"):
-                    vi, vj = T.axis.remap("SS", [i_1, j_1])
-                    T.reads(A[vio * 16 + vi, vjo * 16 + vj])
-                    T.writes(B[vio * 16 + vi, vjo * 16 + vj])
-                    B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2)
-        with T.block("blockized_C"):
-            vio, vjo = T.axis.remap("SS", [i_0, j_0])
-            T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
-            T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
-            for ax0, ax1 in T.grid(16, 16):
-                with T.block("C"):
-                    vi, vj = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(B[vio * 16 + vi, vjo * 16 + vj])
-                    T.writes(C[vio * 16 + vi, vjo * 16 + vj])
-                    C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 16 + vj] + T.float32(1)
-
-
-@T.prim_func
-def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
-    for k, i in T.grid(128, 128):
-        with T.block("B"):
-            vk, vi = T.axis.remap("RS", [k, i])
-            with T.init():
-                B[vi] = 0.0
-            B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
-    with T.block("blockized_B"):
-        vko = T.axis.R(1, 0)
-        vio = T.axis.S(1, 0)
-        with T.init():
-            for i1 in T.serial(0, 128):
-                with T.block("B_init"):
-                    vi_init = T.axis.S(128, i1)
-                    B[vi_init] = T.float32(0)
-        for i0, i1_1 in T.grid(128, 128):
-            with T.block("B"):
-                vk, vi = T.axis.remap("RS", [i0, i1_1])
-                B[vi] = B[vi] + A[vi, vk]
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] * 2.0
 
-
-# fmt: off
-# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
-
-def test_blockize_outer():
     func = single_elementwise
-    # schedule
     s = tir.Schedule(func, debug_mask="all")
-    B = s.get_block("B")
-    x, y = s.get_loops(B)
+    x, _ = s.get_loops(s.get_block("B"))
     s.blockize(x)
-    print(s.mod['main'].script())
-    tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1)
+    tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_outer)
     verify_trace_roundtrip(sch=s, mod=func)
 
 
 def test_blockize_inner():
+    @T.prim_func
+    def after_blockize_inner(
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer[(128, 128), "float32"],
+    ) -> None:
+        for i in T.serial(128):
+            with T.block("blockized_B"):
+                vi = T.axis.spatial(128, i)
+                vjo = T.axis.spatial(1, 0)
+                for j in T.serial(128):
+                    with T.block("B"):
+                        vj = T.axis.remap("S", [j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
     func = single_elementwise
-    # schedule
     s = tir.Schedule(func, debug_mask="all")
-    B = s.get_block("B")
-    x, y = s.get_loops(B)
+    _, y = s.get_loops(s.get_block("B"))
     s.blockize(y)
-    tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized2)
+    tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_inner)
     verify_trace_roundtrip(sch=s, mod=func)
 
 
 def test_two_elementwise_blockize_reverse_compute_at():
-    func = two_elementwise
+    @T.prim_func
+    def before_blockize_rca(
+        A: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
+    ) -> None:
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        for i, j in T.grid(8, 8):
+            with T.block("B_o"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                for i_1, j_1 in T.grid(16, 16):
+                    with T.block("B"):
+                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+                        T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
+                        T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
+                        B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
+            for ax0, ax1 in T.grid(16, 16):
+                with T.block("C"):
+                    vi = T.axis.spatial(128, i * 16 + ax0)
+                    vj = T.axis.spatial(128, j * 16 + ax1)
+                    T.reads(B[vi, vj])
+                    T.writes(C[vi, vj])
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def after_blockize_rca(
+        A: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
+    ) -> None:
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        for i, j in T.grid(8, 8):
+            with T.block("B_o"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                for i_1, j_1 in T.grid(16, 16):
+                    with T.block("B"):
+                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+                        T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
+                        T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
+                        B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
+            with T.block("C_o"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                for ax0, ax1 in T.grid(16, 16):
+                    with T.block("C"):
+                        vi_i, vj_i = T.axis.remap("SS", [ax0, ax1])
+                        T.reads(B[vi * 16 + vi_i, vj * 16 + vj_i])
+                        T.writes(C[vi * 16 + vi_i, vj * 16 + vj_i])
+                        C[vi * 16 + vi_i, vj * 16 + vj_i] = B[vi * 16 + vi_i, vj * 16 + vj_i] + 1.0
+
+    func = before_blockize_rca
     s = tir.Schedule(func, debug_mask="all")
-    B = s.get_block("B")
-    C = s.get_block("C")
-    x, y = s.get_loops(B)
-    xo, xi = s.split(x, factors=[None, 16])
-    yo, yi = s.split(y, factors=[None, 16])
-    s.reorder(xo, yo, xi, yi)
-    s.blockize(xi)
-    s.reverse_compute_at(C, yo)
-    s.blockize(s.get_loops(C)[-2])
-    tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized)
+    _, _, x, _ = s.get_loops(s.get_block("C"))
+    s.blockize(x)
+    tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_rca)
     verify_trace_roundtrip(sch=s, mod=func)
 
 
 def test_two_elementwise_blockize_compute_at():
-    func = two_elementwise
+    @T.prim_func
+    def before_blockize_compute_at(
+        A: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
+    ) -> None:
+        # body
+        # with T.block("root")
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0 in T.grid(8, 8):
+            for ax0, ax1 in T.grid(16, 16):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i_0 * 16 + ax0)
+                    vj = T.axis.spatial(128, j_0 * 16 + ax1)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * 2.0
+            with T.block("C_o"):
+                vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
+                T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+                T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+                for i_1, j_1 in T.grid(16, 16):
+                    with T.block("C"):
+                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+                        T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+                        T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+                        C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
+                            B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + 1.0
+                        )
+
+    @T.prim_func
+    def after_blockize_compute_at(
+        A: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
+    ) -> None:
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0 in T.grid(8, 8):
+            with T.block("B_o"):
+                vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
+                T.reads(A[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+                T.writes(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+                for ax0, ax1 in T.grid(16, 16):
+                    with T.block("B"):
+                        vi_i, vj_i = T.axis.remap("SS", [ax0, ax1])
+                        T.reads(A[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+                        T.writes(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+                        B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
+                            A[vi_o * 16 + vi_i, vj_o * 16 + vj_i] * 2.0
+                        )
+            with T.block("C_o"):
+                vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
+                T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+                T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+                for i_1, j_1 in T.grid(16, 16):
+                    with T.block("C"):
+                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+                        T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+                        T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+                        C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
+                            B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + 1.0
+                        )
+
+    func = before_blockize_compute_at
     s = tir.Schedule(func, debug_mask="all")
-    B = s.get_block("B")
-    C = s.get_block("C")
-    x, y = s.get_loops(C)
-    xo, xi = s.split(x, factors=[None, 16])
-    yo, yi = s.split(y, factors=[None, 16])
-    s.reorder(xo, yo, xi, yi)
-    s.blockize(xi)
-    s.compute_at(B, yo)
-    s.blockize(s.get_loops(B)[-2])
-    tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized)
+    _, _, x, _ = s.get_loops(s.get_block("B"))
+    s.blockize(x)
+    tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_compute_at)
     verify_trace_roundtrip(sch=s, mod=func)
 
 
 def test_blockize_init_loops():
+    @T.prim_func
+    def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
+        for k, i in T.grid(128, 128):
+            with T.block("B"):
+                vk, vi = T.axis.remap("RS", [k, i])
+                with T.init():
+                    B[vi] = 0.0
+                B[vi] = B[vi] + A[vi, vk]
+
+    @T.prim_func
+    def after_rowsum_blockize(
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer[(128,), "float32"],
+    ) -> None:
+        with T.block("blockized_B"):
+            vko = T.axis.R(1, 0)
+            vio = T.axis.S(1, 0)
+            with T.init():
+                for i1 in T.serial(0, 128):
+                    with T.block("B_init"):
+                        vi_init = T.axis.S(128, i1)
+                        B[vi_init] = T.float32(0)
+            for i0, i1_1 in T.grid(128, 128):
+                with T.block("B"):
+                    vk, vi = T.axis.remap("RS", [i0, i1_1])
+                    B[vi] = B[vi] + A[vi, vk]
+
     s = tir.Schedule(rowsum, debug_mask="all")
     k, _ = s.get_loops(s.get_block("B"))
     s.blockize(k)
-    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized)
+    tvm.ir.assert_structural_equal(s.mod["main"], after_rowsum_blockize)
     verify_trace_roundtrip(sch=s, mod=rowsum)