You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/09 00:37:56 UTC

[tvm] branch main updated: [TIR] Add preserve_unit_iters option to blockize/tensorize (#13579)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8545297a5e [TIR] Add preserve_unit_iters option to blockize/tensorize (#13579)
8545297a5e is described below

commit 8545297a5e4a1b2b274b000850a94d95213fabd0
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Dec 8 16:37:51 2022 -0800

    [TIR] Add preserve_unit_iters option to blockize/tensorize (#13579)
    
    * [TIR] Add preserve_unit_iters option to blockize/tensorize
    
    * fix
---
 include/tvm/arith/iter_affine_map.h                |   5 +-
 include/tvm/tir/schedule/schedule.h                |  11 +-
 python/tvm/arith/iter_affine_map.py                |  15 +-
 python/tvm/tir/schedule/schedule.py                |  17 +-
 src/arith/iter_affine_map.cc                       |  31 ++-
 src/tir/schedule/concrete_schedule.cc              |  16 +-
 src/tir/schedule/concrete_schedule.h               |   6 +-
 src/tir/schedule/primitive.h                       |   6 +-
 src/tir/schedule/primitive/blockize_tensorize.cc   |  51 ++--
 src/tir/schedule/schedule.cc                       |   6 +-
 src/tir/schedule/traced_schedule.cc                |  20 +-
 src/tir/schedule/traced_schedule.h                 |   6 +-
 .../python/unittest/test_arith_iter_affine_map.py  |  29 +++
 .../test_meta_schedule_schedule_rule_mlt_intrin.py |  30 +--
 .../test_meta_schedule_schedule_rule_mlt_tc.py     |  41 +--
 .../unittest/test_meta_schedule_trace_apply.py     | 278 ++++++++++-----------
 .../python/unittest/test_tir_schedule_blockize.py  |  29 ++-
 17 files changed, 352 insertions(+), 245 deletions(-)

diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index 6b98d84fdf..0d8bd574ae 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -396,6 +396,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
  * \param predicate The predicate constraints on the input iterators
  * \param check_level The iter mapping checking level.
  * \param analyzer Analyzer used to get context information.
+ * \param simplify_trivial_iterators If true, iterators with extent of
+ *           1 will be replaced with a constant value.
  *
  * \return The result list has length len(bindings) + 1
         [0, len(bindings)): The iter map matching result. The inner list is of length 2.
@@ -407,7 +409,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
 Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                       const Map<Var, Range>& input_iters,
                                       const Array<Var>& sub_iters, const PrimExpr& predicate,
-                                      IterMapLevel check_level, arith::Analyzer* analyzer);
+                                      IterMapLevel check_level, arith::Analyzer* analyzer,
+                                      bool simplify_trivial_iterators = true);
 
 /*!
  * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr.
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 5dbc1b5af3..c4838f2eb8 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -563,21 +563,26 @@ class ScheduleNode : public runtime::Object {
   /*!
    * \brief Convert the subtree rooted at a specific loop into a block.
    * \param loop_rv the root of the subtree
+   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
    * \return the new block
    */
-  virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
+  virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0;
   /*!
    * \brief Tensorize the computation enclosed by loop with the tensor intrin.
    * \param loop_rv The loop to be tensorized
    * \param intrin Name of the tensor intrinsic
+   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
    */
-  virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0;
+  virtual void Tensorize(const LoopRV& loop_rv, const String& intrin,
+                         bool preserve_unit_iters = true) = 0;
   /*!
    * \brief Tensorize the computation enclosed by loop with the tensor intrin.
    * \param block_rv The block to be tensorized
    * \param intrin Name of the tensor intrinsic
+   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
    */
-  virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0;
+  virtual void Tensorize(const BlockRV& block_rv, const String& intrin,
+                         bool preserve_unit_iters = true) = 0;
 
   /******** Schedule: Annotation ********/
   /*!
diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py
index 77d6f418b8..54dbcef325 100644
--- a/python/tvm/arith/iter_affine_map.py
+++ b/python/tvm/arith/iter_affine_map.py
@@ -173,7 +173,12 @@ def normalize_iter_map_to_expr(expr):
 
 
 def subspace_divide(
-    bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective
+    bindings,
+    input_iters,
+    sub_iters,
+    predicate=True,
+    check_level=IterMapLevel.Surjective,
+    simplify_trivial_iterators=True,
 ):
     """Detect if bindings can be written as
     [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
@@ -206,6 +211,10 @@ def subspace_divide(
     check_level : Union[str, IterMapLevel]
         Checking level of iteration mapping
 
+    simplify_trivial_iterators : bool
+        If true, iterators with extent of 1 will be replaced with a
+        constant value.
+
     Returns
     -------
     results : List[List[PrimExpr]]
@@ -218,7 +227,9 @@ def subspace_divide(
     """
     if isinstance(check_level, str):
         check_level = IterMapLevel.from_str(check_level)
-    return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level)
+    return _ffi_api.SubspaceDivide(
+        bindings, input_iters, sub_iters, predicate, check_level, simplify_trivial_iterators
+    )
 
 
 def inverse_affine_iter_map(iter_map, outputs):
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 91c42f2a8d..5ff9d71313 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2186,13 +2186,15 @@ class Schedule(Object):
     ########## Schedule: Blockize & Tensorize ##########
 
     @type_checked
-    def blockize(self, loop: LoopRV) -> BlockRV:
+    def blockize(self, loop: LoopRV, preserve_unit_iters: bool = True) -> BlockRV:
         """Convert the subtree rooted at a specific loop into a block.
 
         Parameters
         ----------
         loop : LoopRV
             The root of the subtree.
+        preserve_unit_iters : bool
+            Whether or not to preserve unit iterators in block bindings
 
         Returns
         -------
@@ -2257,10 +2259,15 @@ class Schedule(Object):
         block are divisible by the subspace represented by the loops starting at the given loop.
         """
 
-        return _ffi_api.ScheduleBlockize(self, loop)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.ScheduleBlockize(self, loop, preserve_unit_iters)  # type: ignore # pylint: disable=no-member
 
     @type_checked
-    def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None:
+    def tensorize(
+        self,
+        block_or_loop: Union[BlockRV, LoopRV],
+        tensor_intrin: str,
+        preserve_unit_iters: bool = True,
+    ) -> None:
         """Tensorize the computation enclosed by loop with the tensor intrinsic.
 
         Parameters
@@ -2269,6 +2276,8 @@ class Schedule(Object):
             The loop to be tensorized.
         tensor_intrin : str
             The tensor intrin or the name of the tensor intrin.
+        preserve_unit_iters : bool
+            Whether or not to preserve unit iterators in block bindings
 
         Examples
         --------
@@ -2402,7 +2411,7 @@ class Schedule(Object):
                         )
         """
         _ffi_api.ScheduleTensorize(  # type: ignore # pylint: disable=no-member
-            self, block_or_loop, tensor_intrin
+            self, block_or_loop, tensor_intrin, preserve_unit_iters
         )
 
     ########## Schedule: Annotation ##########
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index adba61632f..03a36e803b 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1812,18 +1812,26 @@ class SubspaceDivider {
     // extent of inner
     PrimExpr inner_extent;
 
+    // The kind of the division result.
+    enum class Kind {
+      kInner,  // Indicates the division result is totally in inner subspace.
+      kOuter,  // Indicates the division result is totally in outer subspace.
+      kMixed,  // Indicates the division result is mixed in both subspaces.
+    } kind;
+
     DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner,
-                   PrimExpr inner_extent)
+                   PrimExpr inner_extent, Kind kind = Kind::kMixed)
         : outer(std::move(outer)),
           inner(std::move(inner)),
           outer_extent(std::move(outer_extent)),
-          inner_extent(std::move(inner_extent)) {}
+          inner_extent(std::move(inner_extent)),
+          kind(kind) {}
 
     // whether the division result is totally in outer subspace
-    bool IsOuter() const { return is_one(inner_extent); }
+    bool IsOuter() const { return kind == Kind::kOuter; }
 
     // whether the division result is totally in inner subspace
-    bool IsInner() const { return is_one(outer_extent); }
+    bool IsInner() const { return kind == Kind::kInner; }
 
     IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); }
 
@@ -1832,13 +1840,13 @@ class SubspaceDivider {
     static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
       auto dtype = iter.dtype();
       return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), iter,
-                            extent);
+                            extent, Kind::kInner);
     }
 
     static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
       auto dtype = iter.dtype();
       return DivisionResult(iter, extent, IterSumExpr({}, make_const(dtype, 0)),
-                            make_const(dtype, 1));
+                            make_const(dtype, 1), Kind::kOuter);
     }
 
     // Special value to indicate the division is not possible
@@ -2066,9 +2074,11 @@ class SubspaceDivider {
 Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                       const Map<Var, Range>& input_iters,
                                       const Array<Var>& sub_iters, const PrimExpr& predicate,
-                                      IterMapLevel check_level, arith::Analyzer* analyzer) {
+                                      IterMapLevel check_level, arith::Analyzer* analyzer,
+                                      bool simplify_trivial_iterators) {
   if (!IterRangeSanityCheck(input_iters)) return Array<Array<IterMark>>();
-  auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer);
+  auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer,
+                           simplify_trivial_iterators);
   const Array<IterSumExpr>& maps = res->indices;
   if (maps.empty()) return {};
 
@@ -2096,10 +2106,11 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
 
 TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
     .set_body_typed([](const Array<PrimExpr>& bindings, const Map<Var, Range>& root_iters,
-                       const Array<Var>& sub_iters, const PrimExpr& predicate, int check_level) {
+                       const Array<Var>& sub_iters, const PrimExpr& predicate, int check_level,
+                       bool simplify_trivial_iterators) {
       arith::Analyzer ana;
       return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level),
-                            &ana);
+                            &ana, simplify_trivial_iterators);
     });
 
 class InverseAffineIterMapTransformer {
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index a0d29a00f8..7ae0185b42 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -690,25 +690,29 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
 }
 
 /******** Schedule: Blockize & Tensorize ********/
-BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) {
+BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
-  result = tir::Blockize(state_, this->GetSRef(loop_rv));
+  result = tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters);
   this->state_->DebugVerify();
   TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
   return CreateRV<BlockRV>(result);
 }
 
-void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) {
+void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
+                                     bool preserve_unit_iters) {
   TVM_TIR_SCHEDULE_BEGIN();
-  tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value());
+  tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(),
+                 preserve_unit_iters);
   this->state_->DebugVerify();
   TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
 }
 
-void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) {
+void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin,
+                                     bool preserve_unit_iters) {
   TVM_TIR_SCHEDULE_BEGIN();
-  tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value());
+  tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(),
+                 preserve_unit_iters);
   this->state_->DebugVerify();
   TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
 }
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 66fca10771..2381870760 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -137,9 +137,9 @@ class ConcreteScheduleNode : public ScheduleNode {
                     int offset) override;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
   /******** Schedule: Blockize & Tensorize ********/
-  BlockRV Blockize(const LoopRV& loop_rv) override;
-  void Tensorize(const BlockRV& block_rv, const String& intrin) override;
-  void Tensorize(const LoopRV& loop_rv, const String& intrin) override;
+  BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
+  void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;
+  void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override;
   /******** Schedule: Annotation ********/
   void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
   void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index af1988eaaf..38931aa271 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -452,18 +452,20 @@ TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, in
  * \brief Convert the subtree rooted at a specific loop into a block.
  * \param self The state of the schedule
  * \param loop_sref The root of the subtree
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
  * \return The new block
  */
-TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref);
+TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters);
 
 /*!
  * \brief Tensorize the computation enclosed by loop with the tensor intrinsic.
  * \param self The state of the schedule
  * \param block_or_loop_sref The block or loop to be tensorized.
  * \param intrin The tensor intrinsic.
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
  */
 TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
-                       const TensorIntrin& intrin);
+                       const TensorIntrin& intrin, bool preserve_unit_iters);
 
 /******** Schedule: Annotation ********/
 /*!
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc
index 80a653c544..6860927c4d 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -76,7 +76,7 @@ class SubspaceNotDivisibleError : public ScheduleError {
  *   1. The binding covers no inner loop vars.
  *   2. The binding covers only inner loop vars.
  *
- * The bindings are not required to be quasi-affine.
+ * The bindings are not required to be quasi-affine. Trivial block iters are always preserved.
  *
  * \param iter_vars The input iterators
  * \param bindings The values of iter_vars
@@ -146,12 +146,13 @@ Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter
  * \param loop_sref The loop that is the root of the second subspace.
  * \param loops The loops that represents the second part of the subspace.
  * \param analyzer The arithmetic analyzer to use.
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
  */
 Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
                                              const StmtSRef& block_sref,  //
                                              const StmtSRef& loop_sref,   //
                                              std::vector<const ForNode*>* loops,
-                                             arith::Analyzer* analyzer) {
+                                             arith::Analyzer* analyzer, bool preserve_unit_iters) {
   Array<Var> inner_vars;
   Array<Var> outer_vars;
   Map<Var, Range> loop_var_domain;
@@ -173,7 +174,8 @@ Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
   }
   Array<Array<arith::IterMark>> result =
       arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate,
-                            arith::IterMapLevel::Surjective, analyzer);
+                            arith::IterMapLevel::Surjective, analyzer,
+                            /*simplify_trivial_iterators=*/!preserve_unit_iters);
   if (!result.empty()) {
     return result;
   }
@@ -191,6 +193,7 @@ Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
  * \param outer_bindings The outer block bindings.
  * \param inner_iter_vars The inner block iterators.
  * \param inner_bindings The inner block bindings.
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
  * \return A substitution plan to the iterators in the original inner block.
  */
 Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars,                //
@@ -198,7 +201,7 @@ Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars,
                                       Array<IterVar>* outer_iter_vars,                //
                                       Array<PrimExpr>* outer_bindings,                //
                                       Array<IterVar>* inner_iter_vars,                //
-                                      Array<PrimExpr>* inner_bindings) {
+                                      Array<PrimExpr>* inner_bindings, bool preserve_unit_iters) {
   using arith::IterMapExpr;
   using arith::IterMapExprNode;
   using arith::NormalizeIterMapToExpr;
@@ -427,7 +430,8 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector<const ForNode*>& loops) {
 }
 
 BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
-                          Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
+                          Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer,
+                          bool preserve_unit_iters) {
   TVM_SREF_TO_FOR(loop_sref);
   // Step 1: Check and get the only block under `loop`.
   BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref);
@@ -436,7 +440,7 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
   // Step 2: Derive subspace division
   std::vector<const ForNode*> loops;
   Array<Array<arith::IterMark>> division =
-      SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer);
+      SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer, preserve_unit_iters);
   if (division.empty()) {
     throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()), block);
   }
@@ -450,7 +454,8 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
   Map<Var, PrimExpr> block_var_subst =                       //
       DeriveBlockBinding(block->iter_vars, division,         //
                          &outer_iter_vars, &outer_bindings,  //
-                         &inner_iter_vars, &inner_bindings);
+                         &inner_iter_vars, &inner_bindings,  //
+                         preserve_unit_iters);
   // Step 4: Do var substitution to adjust to the new block bindings
   Map<Var, arith::IntSet> inner_iter_dom;
   for (const IterVar& iter : inner_iter_vars) {
@@ -494,10 +499,11 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
                 : Optional<Stmt>(NullOpt)));
 }
 
-StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) {
   arith::Analyzer analyzer;
   Map<Block, Block> block_sref_reuse;
-  BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer);
+  BlockRealize blockized =
+      BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters);
   self->Replace(loop_sref, blockized, block_sref_reuse);
   StmtSRef result = self->stmt2ref.at(blockized->block.get());
   StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
@@ -507,7 +513,8 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
   return result;
 }
 
-void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin) {
+void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin,
+               bool preserve_unit_iters) {
   // Step 1: Blockize the subtree rooted at the given loop if needed
   BlockRealize block_realize{nullptr};
   Optional<Block> old_block = NullOpt;
@@ -517,7 +524,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
   } else if (sref->stmt->IsInstance<ForNode>()) {
     arith::Analyzer analyzer;
     Map<Block, Block> block_sref_reuse;
-    block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer);
+    block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters);
   } else {
     LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: "
                << GetRef<Stmt>(sref->stmt);
@@ -617,16 +624,17 @@ struct BlockizeTraits : public UnpackedInstTraits<BlockizeTraits> {
 
  private:
   static constexpr size_t kNumInputs = 1;
-  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumAttrs = 1;
   static constexpr size_t kNumDecisions = 0;
 
-  static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
-    return sch->Blockize(loop_rv);
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Bool preserve_unit_iters) {
+    return sch->Blockize(loop_rv, preserve_unit_iters.operator bool());
   }
 
-  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
+  static String UnpackedAsPython(Array<String> outputs, String loop_rv, Bool preserve_unit_iters) {
     PythonAPICall py("blockize");
     py.Input("loop", loop_rv);
+    py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
     py.SingleOutput(outputs);
     return py.Str();
   }
@@ -641,24 +649,27 @@ struct TensorizeTraits : public UnpackedInstTraits<TensorizeTraits> {
 
  private:
   static constexpr size_t kNumInputs = 1;
-  static constexpr size_t kNumAttrs = 1;
+  static constexpr size_t kNumAttrs = 2;
   static constexpr size_t kNumDecisions = 0;
 
-  static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin) {
+  static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin,
+                                      Bool preserve_unit_iters) {
     if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
-      sch->Tensorize(GetRef<BlockRV>(block), intrin);
+      sch->Tensorize(GetRef<BlockRV>(block), intrin, preserve_unit_iters.operator bool());
     } else if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
-      sch->Tensorize(GetRef<LoopRV>(loop), intrin);
+      sch->Tensorize(GetRef<LoopRV>(loop), intrin, preserve_unit_iters.operator bool());
     } else {
       LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: "
                  << block_or_loop_rv->GetTypeKey();
     }
   }
 
-  static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv, String intrin) {
+  static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv, String intrin,
+                                 Bool preserve_unit_iters) {
     PythonAPICall py("tensorize");
     py.Input("block_or_loop", block_or_loop_rv);
     py.Input("tensor_intrin", intrin);
+    py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
     return py.Str();
   }
 
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 3fe81c9f43..d008f3639c 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -211,11 +211,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
     .set_body_method<Schedule>(&ScheduleNode::Blockize);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
-    .set_body_typed([](Schedule self, ObjectRef rv, String intrin) {
+    .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) {
       if (const auto* block_rv = rv.as<BlockRVNode>()) {
-        self->Tensorize(GetRef<BlockRV>(block_rv), intrin);
+        self->Tensorize(GetRef<BlockRV>(block_rv), intrin, preserve_unit_iters);
       } else if (const auto* loop_rv = rv.as<LoopRVNode>()) {
-        self->Tensorize(GetRef<LoopRV>(loop_rv), intrin);
+        self->Tensorize(GetRef<LoopRV>(loop_rv), intrin, preserve_unit_iters);
       } else {
         LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
                    << ". Its value is: " << rv;
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 010730f66c..00941b4857 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -442,34 +442,36 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
 
 /******** Schedule: Blockize & Tensorize ********/
 
-BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) {
-  BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv);
+BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
+  BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv, preserve_unit_iters);
   static const InstructionKind& kind = InstructionKind::Get("Blockize");
   trace_->Append(/*inst=*/Instruction(
       /*kind=*/kind,
       /*inputs=*/{loop_rv},
-      /*attrs=*/{},
+      /*attrs=*/{Bool(preserve_unit_iters)},
       /*outputs=*/{new_block}));
   return new_block;
 }
 
-void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) {
-  ConcreteScheduleNode::Tensorize(loop_rv, intrin);
+void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
+                                   bool preserve_unit_iters) {
+  ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters);
   static const InstructionKind& kind = InstructionKind::Get("Tensorize");
   trace_->Append(/*inst=*/Instruction(
       /*kind=*/kind,
       /*inputs=*/{loop_rv},
-      /*attrs=*/{intrin},
+      /*attrs=*/{intrin, Bool(preserve_unit_iters)},
       /*outputs=*/{}));
 }
 
-void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) {
-  ConcreteScheduleNode::Tensorize(block_rv, intrin);
+void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin,
+                                   bool preserve_unit_iters) {
+  ConcreteScheduleNode::Tensorize(block_rv, intrin, preserve_unit_iters);
   static const InstructionKind& kind = InstructionKind::Get("Tensorize");
   trace_->Append(/*inst=*/Instruction(
       /*kind=*/kind,
       /*inputs=*/{block_rv},
-      /*attrs=*/{intrin},
+      /*attrs=*/{intrin, Bool(preserve_unit_iters)},
       /*outputs=*/{}));
 }
 
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index cea2096d20..80257f644f 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -96,9 +96,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                     int offset) final;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final;
   /******** Schedule: Blockize & Tensorize ********/
-  BlockRV Blockize(const LoopRV& loop_rv) final;
-  void Tensorize(const BlockRV& block_rv, const String& intrin) final;
-  void Tensorize(const LoopRV& loop_rv, const String& intrin) final;
+  BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
+  void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final;
+  void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) final;
   /******** Schedule: Annotation ********/
   void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
   void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 6a2fdbbb3f..7ae5c58a95 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -670,6 +670,35 @@ def test_subspace_division():
     assert len(res) == 0
 
 
+def test_subspace_divide_trivial_iters():
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
+
+    # trivial 1.1
+    res = tvm.arith.subspace_divide(
+        [x * 16 + y], var_dom([(x, 1), (y, 16)]), [y], simplify_trivial_iterators=False
+    )
+    res = convert_division(res)
+    assert len(res) == 2
+    tvm.ir.assert_structural_equal(res[0][0], x)
+    tvm.ir.assert_structural_equal(res[0][1], y)
+
+    # trivial 1.2
+    res = tvm.arith.subspace_divide(
+        [x, y],
+        var_dom([(x, 1), (y, 1)]),
+        [y],
+        simplify_trivial_iterators=False,
+    )
+    res = convert_division(res)
+    assert len(res) == 3
+    tvm.ir.assert_structural_equal(res[0][0], x)
+    tvm.ir.assert_structural_equal(res[0][1], 0)
+    tvm.ir.assert_structural_equal(res[1][0], 0)
+    tvm.ir.assert_structural_equal(res[1][1], y)
+
+
 def test_complex():
     n0 = create_iter("n0", 2)
     n1 = create_iter("n1", 4)
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
index e70f7cb2c6..54f342c3a5 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
@@ -74,16 +74,16 @@ def test_vnni_conv2d_nchwc():
         for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1):
             for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1):
                 with T.block("conv2d_NCHWc_int8_o"):
-                    n = T.axis.spatial(1, 0)
+                    n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
                     oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3)
                     oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3)
                     ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2)
-                    oc_block_o = T.axis.spatial(1, 0)
-                    kh = T.axis.reduce(1, 0)
-                    kw = T.axis.reduce(1, 0)
+                    oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1)
+                    kh = T.axis.reduce(1, i5_1 + i5_0)
+                    kw = T.axis.reduce(1, i6_0 + i6_1)
                     ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1)
                     ic_f_inner = T.axis.reduce(4, i8_0 + i8_1)
-                    ic_s_inner_o = T.axis.reduce(1, 0)
+                    ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0)
                     T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
                     T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16])
                     T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
@@ -119,16 +119,16 @@ def test_vnni_conv2d_nchwc():
         for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1):
             for i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1):
                 with T.block("conv2d_NCHWc_int8_o"):
-                    n = T.axis.spatial(1, 0)
+                    n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
                     oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3)
                     oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3)
                     ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2)
-                    oc_block_o = T.axis.spatial(1, 0)
-                    kh = T.axis.reduce(1, 0)
-                    kw = T.axis.reduce(1, 0)
+                    oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1)
+                    kh = T.axis.reduce(1, i5_1 + i5_0)
+                    kw = T.axis.reduce(1, i6_0 + i6_1)
                     ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1)
                     ic_f_inner = T.axis.reduce(4, i8_0 + i8_1)
-                    ic_s_inner_o = T.axis.reduce(1, 0)
+                    ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0)
                     T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
                     T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16])
                     T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
@@ -162,16 +162,16 @@ def test_vnni_conv2d_nchwc():
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1):
             with T.block("conv2d_NCHWc_int8_o"):
-                n = T.axis.spatial(1, 0)
+                n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
                 oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3)
                 oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3)
                 ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2)
-                oc_block_o = T.axis.spatial(1, 0)
-                kh = T.axis.reduce(1, 0)
-                kw = T.axis.reduce(1, 0)
+                oc_block_o = T.axis.spatial(1, i4_0_2 + i4_0_3 + i4_0_0 + i4_0_1)
+                kh = T.axis.reduce(1, i5_1 + i5_0)
+                kw = T.axis.reduce(1, i6_0 + i6_1)
                 ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1)
                 ic_f_inner = T.axis.reduce(4, i8_0 + i8_1)
-                ic_s_inner_o = T.axis.reduce(1, 0)
+                ic_s_inner_o = T.axis.reduce(1, i9_0_1 + i9_0_0)
                 T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
                 T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16])
                 T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"})
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index acc626b904..73b2c990f0 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -117,7 +117,7 @@ def test_matmul_relu():
                             for ax0_0, ax1_0 in T.grid(2, 1):
                                 with T.block("B_reindex_shared_wmma.matrix_b_o"):
                                     v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0)
-                                    v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                                    v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
@@ -152,7 +152,7 @@ def test_matmul_relu():
                     for ax0_0, ax1_0 in T.grid(2, 1):
                         with T.block("C_reindex_shared_wmma.accumulator_o"):
                             v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
                             T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
@@ -396,7 +396,8 @@ def test_conv2d():
                         for ax2_0_1 in T.serial(18):
                             for ax0_0, ax1_0 in T.grid(1, 1):
                                 with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
-                                    v0_o, v1_o = T.axis.remap("SS", [ax0_0_1_ax1_0_1_fused, ax2_0_1])
+                                    v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0)
+                                    v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0)
                                     T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
@@ -408,7 +409,8 @@ def test_conv2d():
                                             PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                             for ax0_0, ax1_0 in T.grid(1, 1):
                                 with T.block("weight_reindex_shared_wmma.matrix_b_o"):
-                                    v0_o, v1_o = T.axis.remap("SS", [ax2_0_1, ax0_0_0_ax1_0_0_fused])
+                                    v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0)
+                                    v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0)
                                     T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
@@ -442,7 +444,8 @@ def test_conv2d():
                                             conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
                     for ax0_0, ax1_0 in T.grid(1, 1):
                         with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                            v0_o, v1_o = T.axis.remap("SS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused])
+                            v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0)
+                            v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0)
                             T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
@@ -560,7 +563,7 @@ def test_matmul_relu_pipeline():
                             for ax0_0, ax1_0 in T.grid(2, 1):
                                 with T.block("A_reindex_shared_wmma.matrix_a_o"):
                                     v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0)
-                                    v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1)
+                                    v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0)
                                     T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
@@ -572,7 +575,7 @@ def test_matmul_relu_pipeline():
                                             A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                             for ax0_0, ax1_0 in T.grid(1, 2):
                                 with T.block("B_reindex_shared_wmma.matrix_b_o"):
-                                    v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1)
+                                    v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax0_0)
                                     v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
@@ -706,7 +709,7 @@ def test_matmul_relu_global():
                         for ax2_0_1 in T.serial(2):
                             for ax0_0, ax1_0 in T.grid(1, 2):
                                 with T.block("A_reindex_shared_wmma.matrix_a_o"):
-                                    v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2)
+                                    v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0)
                                     v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0)
                                     T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
@@ -754,7 +757,7 @@ def test_matmul_relu_global():
                                             C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
                     for ax0_0, ax1_0 in T.grid(1, 4):
                         with T.block("C_reindex_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2)
+                            v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0)
                             v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0)
                             T.reads(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
@@ -875,7 +878,7 @@ def test_padded_matmul_relu():
                             for ax0_0, ax1_0 in T.grid(2, 1):
                                 with T.block("B_reindex_shared_wmma.matrix_b_o"):
                                     v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0)
-                                    v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                                    v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
@@ -910,7 +913,7 @@ def test_padded_matmul_relu():
                     for ax0_0, ax1_0 in T.grid(2, 1):
                         with T.block("C_reindex_shared_wmma.accumulator_o"):
                             v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
-                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
                             T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
@@ -1001,7 +1004,7 @@ def test_conv_1x1():
                         for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1):
                             for ax0_0_1, ax1_0_1 in T.grid(1, 4):
                                 with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
-                                    v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused)
+                                    v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1)
                                     v1_o = T.axis.spatial(4, ax1_0_1)
                                     T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
@@ -1014,10 +1017,8 @@ def test_conv_1x1():
                                             PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                             for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 1):
                                 with T.block("weight_reindex_shared_wmma.matrix_b_o"):
-                                    v0 = T.axis.spatial(1, 0)
-                                    v1 = T.axis.spatial(1, 0)
-                                    v2_o = T.axis.spatial(4, ax2_0)
-                                    v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused)
+                                    v0, v1, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0])
+                                    v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0)
                                     T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                     T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
@@ -1029,8 +1030,8 @@ def test_conv_1x1():
                                             weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
                             for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 4, 1, 1):
                                 with T.block("conv2d_nhwc_o"):
-                                    v0 = T.axis.reduce(1, 0)
-                                    v1 = T.axis.reduce(1, 0)
+                                    v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1)
+                                    v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0)
                                     v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3)
                                     v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3)
                                     v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2)
@@ -1053,8 +1054,8 @@ def test_conv_1x1():
                                             conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32")
                     for ax0_0, ax1_0 in T.grid(1, 1):
                         with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused)
-                            v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused)
+                            v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0)
+                            v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax1_0)
                             T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py
index c8e6bf6a0c..9a62207fa2 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -635,26 +635,26 @@ class Conv2dInt8_tensorcore_scheduled:
     def main(p0: T.Buffer[(16, 56, 56, 64), "int8"], p1: T.Buffer[(256, 1, 1, 64), "int8"], p2: T.Buffer[(1, 1, 1, 256), "int32"], p3: T.Buffer[(1, 1, 1, 256), "int32"], p4: T.Buffer[(1, 1, 1, 256), "int64"], p5: T.Buffer[(1, 1, 1, 256), "int64"], p6: T.Buffer[(1, 1, 1, 256), "int64"], p7: T.Buffer[(), "int32"], p8: T.Buffer[1, "int32"], p9: T.Buffer[(16, 56, 56, 256), "int32"], compute: T.Buffer[(16, 56, 56, 256), "uint8"]) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        a0 = T.var("int32")
-        a1 = T.var("int32")
-        b0 = T.var("int32")
-        b1 = T.var("int32")
-        c0 = T.var("int32")
-        c1 = T.var("int32")
-        d0 = T.var("int32")
-        d0_1 = T.var("int32")
-        d0_2 = T.var("int32")
-        d0_3 = T.var("int32")
-        d1 = T.var("int32")
-        d1_1 = T.var("int32")
-        d1_2 = T.var("int32")
-        d1_3 = T.var("int32")
-        s0 = T.var("int32")
-        s0_1 = T.var("int32")
-        s0_2 = T.var("int32")
-        s1 = T.var("int32")
-        s1_1 = T.var("int32")
-        s1_2 = T.var("int32")
+        A_s0 = T.var("int32")
+        A_s0_1 = T.var("int32")
+        A_s0_2 = T.var("int32")
+        A_s0_3 = T.var("int32")
+        A_s1 = T.var("int32")
+        A_s1_1 = T.var("int32")
+        A_s1_2 = T.var("int32")
+        A_s1_3 = T.var("int32")
+        B_s0 = T.var("int32")
+        B_s1 = T.var("int32")
+        C_s0 = T.var("int32")
+        C_s0_1 = T.var("int32")
+        C_s0_2 = T.var("int32")
+        C_s0_3 = T.var("int32")
+        C_s0_4 = T.var("int32")
+        C_s1 = T.var("int32")
+        C_s1_1 = T.var("int32")
+        C_s1_2 = T.var("int32")
+        C_s1_3 = T.var("int32")
+        C_s1_4 = T.var("int32")
         # body
         # with T.block("root")
         conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared")
@@ -666,83 +666,81 @@ class Conv2dInt8_tensorcore_scheduled:
         for ax2_0_0_ax3_0_0_fused in T.thread_binding(3136, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":512, "pragma_unroll_explicit":1}):
             for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="vthread.x"):
                 for ax2_0_2_ax3_0_2_fused in T.thread_binding(16, thread="threadIdx.x"):
-                    for ax0_0, ax1_0 in T.grid(1, 1):
-                        for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1):
-                            with T.block("conv2d_nhwc_o_init"):
-                                v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init)
-                                v3_o = T.axis.spatial(16, ax3_0_4_init + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init)
-                                T.reads()
-                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-                                T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1})
-                                C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[d1, d0], scope="wmma.accumulator", offset_factor=16)
-                                T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // d1 // 16 * (d1 // 16) + C.elem_offset % d1 // 16, T.float32(0), dtype="handle"))
-                        for ax4_0_0 in T.serial(2):
-                            for ax0_ax1_fused_0 in T.serial(16):
-                                for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
-                                    for ax0_ax1_fused_2 in T.vectorized(16):
-                                        with T.block("pad_temp_reindex_shared"):
-                                            v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 32)
-                                            v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32)
-                                            T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
-                                            T.writes(pad_temp_reindex_shared[v0, v1])
-                                            T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]]})
-                                            pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]
-                            for ax0_ax1_ax2_ax3_fused_0 in T.serial(8):
-                                for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
-                                    for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8):
-                                        with T.block("p1_reindex_shared"):
-                                            v0 = T.axis.spatial(1, 0)
-                                            v1 = T.axis.spatial(1, 0)
-                                            v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) // 32)
-                                            v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32)
-                                            T.reads(p1[v2, v0, v1, v3])
-                                            T.writes(p1_reindex_shared[v0, v1, v2, v3])
-                                            T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]]})
-                                            p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3]
-                            for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1):
-                                for ax0_0_1, ax1_0_1 in T.grid(1, 2):
-                                    with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"):
-                                        v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2)
-                                        v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1)
-                                        T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                                        A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[s1, s0], scope="shared", offset_factor=16)
-                                        C_1 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_1, d0_1], scope="wmma.matrix_a", offset_factor=16)
-                                        T.evaluate(T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // d1_1 // 16 * (d1_1 // 16) + C_1.elem_offset % d1_1 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A.data, A.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle"))
-                                for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2):
-                                    with T.block("p1_reindex_shared_wmma.matrix_b_o"):
+                    for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1):
+                        with T.block("conv2d_nhwc_o_init"):
+                            v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init)
+                            v3_o = T.axis.spatial(16, ax3_0_4_init + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init)
+                            T.reads()
+                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
+                            T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1})
+                            C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0, C_s1], scope="wmma.accumulator", offset_factor=16)
+                            T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.float32(0), dtype="handle")
+                    for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2):
+                        for ax0_ax1_fused_0 in T.serial(16):
+                            for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
+                                for ax0_ax1_fused_2 in T.vectorized(16):
+                                    with T.block("pad_temp_reindex_shared"):
+                                        v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 32)
+                                        v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32)
+                                        T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
+                                        T.writes(pad_temp_reindex_shared[v0, v1])
+                                        T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]]})
+                                        pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]
+                        for ax0_ax1_ax2_ax3_fused_0 in T.serial(8):
+                            for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
+                                for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8):
+                                    with T.block("p1_reindex_shared"):
                                         v0 = T.axis.spatial(1, 0)
                                         v1 = T.axis.spatial(1, 0)
-                                        v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2)
-                                        v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0)
-                                        T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-                                        T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-                                        A_1 = T.match_buffer(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[s1_1, s0_1], scope="shared", offset_factor=16)
-                                        C_2 = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[d1_2, d0_2], scope="wmma.matrix_b", offset_factor=16)
-                                        T.evaluate(T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // d1_2 // 16 * (d1_2 // 16) + C_2.elem_offset % d1_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, s1_1 * 16, 1, dtype="handle"), s1_1, "col_major", dtype="handle"))
-                                for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1):
-                                    with T.block("conv2d_nhwc_o_update"):
-                                        v0 = T.axis.reduce(1, 0)
-                                        v1 = T.axis.reduce(1, 0)
-                                        v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4)
-                                        v3_o = T.axis.spatial(16, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3)
-                                        v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2)
-                                        T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16])
-                                        T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
-                                        T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1})
-                                        A_2 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[a1, a0], scope="wmma.matrix_a", offset_factor=16)
-                                        B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[b1, b0], scope="wmma.matrix_b", offset_factor=16)
-                                        C_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[c1, c0], scope="wmma.accumulator", offset_factor=16)
-                                        T.evaluate(T.tvm_mma_sync(C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, A_2.data, A_2.elem_offset // a1 // 16 * (a1 // 16) + A_2.elem_offset % a1 // 16, B.data, B.elem_offset // b1 // 16 * (b1 // 16) + B.elem_offset % b1 // 16, C_3.data, C_3.elem_offset // c1 // 16 * (c1 // 16) + C_3.elem_offset % c1 // 16, dtype="handle"))
+                                        v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) // 32)
+                                        v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32)
+                                        T.reads(p1[v2, v0, v1, v3])
+                                        T.writes(p1_reindex_shared[v0, v1, v2, v3])
+                                        T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]]})
+                                        p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3]
+                        for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1):
+                            for ax0_0_1, ax1_0_1 in T.grid(1, 2):
+                                with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1)
+                                    v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1)
+                                    T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0, A_s1], scope="shared", offset_factor=16)
+                                    C_1 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[C_s0_1, C_s1_1], scope="wmma.matrix_a", offset_factor=16)
+                                    T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // C_s0_1 // 16 * (C_s0_1 // 16) + C_1.elem_offset % C_s0_1 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A.data, A.elem_offset, A_s0 * 16, 1, dtype="handle"), A_s0, "row_major", dtype="handle")
+                            for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2):
+                                with T.block("p1_reindex_shared_wmma.matrix_b_o"):
+                                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                                    v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax2_0)
+                                    v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0)
+                                    T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
+                                    T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
+                                    A_1 = T.match_buffer(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0_1, A_s1_1], scope="shared", offset_factor=16)
+                                    C_2 = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[C_s0_2, C_s1_2], scope="wmma.matrix_b", offset_factor=16)
+                                    T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // C_s0_2 // 16 * (C_s0_2 // 16) + C_2.elem_offset % C_s0_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, A_s0_1 * 16, 1, dtype="handle"), A_s0_1, "col_major", dtype="handle")
+                            for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1):
+                                with T.block("conv2d_nhwc_o_update"):
+                                    v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1)
+                                    v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0)
+                                    v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4)
+                                    v3_o = T.axis.spatial(16, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3)
+                                    v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2)
+                                    T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16])
+                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1})
+                                    A_2 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0_2, A_s1_2], scope="wmma.matrix_a", offset_factor=16)
+                                    B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[B_s0, B_s1], scope="wmma.matrix_b", offset_factor=16)
+                                    C_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0_3, C_s1_3], scope="wmma.accumulator", offset_factor=16)
+                                    T.tvm_mma_sync(C_3.data, C_3.elem_offset // C_s0_3 // 16 * (C_s0_3 // 16) + C_3.elem_offset % C_s0_3 // 16, A_2.data, A_2.elem_offset // A_s0_2 // 16 * (A_s0_2 // 16) + A_2.elem_offset % A_s0_2 // 16, B.data, B.elem_offset // B_s0 // 16 * (B_s0 // 16) + B.elem_offset % B_s0 // 16, C_3.data, C_3.elem_offset // C_s0_3 // 16 * (C_s0_3 // 16) + C_3.elem_offset % C_s0_3 // 16, dtype="handle")
                     for ax0_0, ax1_0 in T.grid(1, 1):
                         with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                            v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2)
-                            v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2)
+                            v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0)
+                            v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax1_0)
                             T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            A_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[d1_3, d0_3], scope="wmma.accumulator", offset_factor=16)
-                            C_4 = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[s1_2, s0_2], scope="shared", offset_factor=16)
-                            T.evaluate(T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // d1_3 // 16 * (d1_3 // 16) + A_3.elem_offset % d1_3 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_4.data, C_4.elem_offset, s1_2 * 16, 2, dtype="handle"), s1_2, "row_major", dtype="handle"))
+                            A_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[A_s0_3, A_s1_3], scope="wmma.accumulator", offset_factor=16)
+                            C_4 = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0_4, C_s1_4], scope="shared", offset_factor=16)
+                            T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // A_s0_3 // 16 * (A_s0_3 // 16) + A_3.elem_offset % A_s0_3 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_4.data, C_4.elem_offset, C_s0_4 * 16, 2, dtype="handle"), C_s0_4, "row_major", dtype="handle")
                 for ax0, ax1_0 in T.grid(128, 2):
                     for ax1_1 in T.thread_binding(16, thread="threadIdx.x"):
                         with T.block("conv2d_nhwc_reindex_shared"):
@@ -1145,45 +1143,44 @@ def get_conv2d_vnni_mod(intrin_id):
             conv2d_NCHWc_int8 = T.alloc_buffer([1, 128, 7, 7, 16], dtype="int32")
             for i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused in T.parallel(128, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}):
                 for i2_1, i3_1, i4_0_1 in T.grid(7, 1, 1):
-                    for i5_0, i6_0 in T.grid(1, 1):
-                        for i1_2_init, i2_2_init, i3_2_init, i1_3_init, i2_3_init, i3_3_init in T.grid(1, 1, 1, 1, 1, 7):
-                            with T.block("conv2d_NCHWc_int8_o_init"):
-                                n = T.axis.spatial(1, 0)
-                                oc_chunk = T.axis.spatial(128, i1_2_init + i1_3_init + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32)
-                                oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init)
-                                ow = T.axis.spatial(7, i3_1 * 7 + i3_2_init * 7 + i3_3_init)
-                                oc_block_o = T.axis.spatial(1, 0)
-                                T.reads()
-                                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16])
-                                for i4_1 in T.vectorized(16):
-                                    with T.block("conv2d_NCHWc_int8_init"):
-                                        oc_block_i_init = T.axis.spatial(16, i4_1)
-                                        T.reads()
-                                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init])
-                                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0
-                        for i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1):
-                            with T.block("conv2d_NCHWc_int8_o_update"):
-                                n = T.axis.spatial(1, 0)
-                                oc_chunk = T.axis.spatial(128, i1_2 + i1_3 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32)
-                                oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3)
-                                ow = T.axis.spatial(7, i3_1 * 7 + i3_2 * 7 + i3_3)
-                                oc_block_o = T.axis.spatial(1, 0)
-                                kh = T.axis.reduce(1, 0)
-                                kw = T.axis.reduce(1, 0)
-                                ic_outer = T.axis.reduce(32, i7_0 * 8 + i7_1)
-                                ic_f_inner = T.axis.reduce(4, i8_1 + i8_0)
-                                ic_s_inner_o = T.axis.reduce(1, 0)
-                                T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
-                                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16])
-                                A = T.match_buffer(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], [4], dtype="uint8", offset_factor=1)
-                                B = T.match_buffer(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4], [16, 4], dtype="int8", offset_factor=1)
-                                C = T.match_buffer(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], [16], dtype="int32", offset_factor=1)
-                                A_u8x4: T.uint8x4 = A[0:4]
-                                A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
-                                B_i8x64: T.int8x64 = B[0, 0:64]
-                                B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
-                                C_i32x16: T.int32x16 = C[0:16]
-                                C[0:16] = T.call_llvm_pure_intrin(intrin_id, T.uint32(0), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16")
+                    for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i4_0_2_init, i0_3_init, i1_3_init, i2_3_init, i3_3_init, i4_0_3_init in T.grid(1, 1, 1, 1, 1, 1, 1, 1, 7, 1):
+                        with T.block("conv2d_NCHWc_int8_o_init"):
+                            n = T.axis.spatial(1, i0_3_init + i0_2_init)
+                            oc_chunk = T.axis.spatial(128, i1_2_init + i1_3_init + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32)
+                            oh = T.axis.spatial(7, i2_1 + i2_2_init + i2_3_init)
+                            ow = T.axis.spatial(7, i3_1 * 7 + i3_2_init * 7 + i3_3_init)
+                            oc_block_o = T.axis.spatial(1, i4_0_3_init + i4_0_1 + i4_0_2_init)
+                            T.reads()
+                            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16])
+                            for i4_1 in T.vectorized(16):
+                                with T.block("conv2d_NCHWc_int8_init"):
+                                    oc_block_i_init = T.axis.spatial(16, i4_1)
+                                    T.reads()
+                                    T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init])
+                                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0
+                    for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 7, 1):
+                        with T.block("conv2d_NCHWc_int8_o_update"):
+                            n = T.axis.spatial(1, i0_3 + i0_2)
+                            oc_chunk = T.axis.spatial(128, i1_2 + i1_3 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused // 32 * 32 + i0_0_i1_0_i2_0_i3_0_i4_0_0_i0_1_i1_1_fused % 32)
+                            oh = T.axis.spatial(7, i2_1 + i2_2 + i2_3)
+                            ow = T.axis.spatial(7, i3_1 * 7 + i3_2 * 7 + i3_3)
+                            oc_block_o = T.axis.spatial(1, i4_0_3 + i4_0_1 + i4_0_2)
+                            kh = T.axis.reduce(1, i5_0 + i5_1)
+                            kw = T.axis.reduce(1, i6_1 + i6_0)
+                            ic_outer = T.axis.reduce(32, i7_0 * 8 + i7_1)
+                            ic_f_inner = T.axis.reduce(4, i8_1 + i8_0)
+                            ic_s_inner_o = T.axis.reduce(1, i9_0_0 + i9_0_1)
+                            T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4])
+                            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16])
+                            A = T.match_buffer(p0[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], [4], dtype="uint8", offset_factor=1)
+                            B = T.match_buffer(p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4], [16, 4], dtype="int8", offset_factor=1)
+                            C = T.match_buffer(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16], [16], dtype="int32", offset_factor=1)
+                            A_u8x4: T.uint8x4 = A[0:4]
+                            A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                            B_i8x64: T.int8x64 = B[0, 0:64]
+                            B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+                            C_i32x16: T.int32x16 = C[0:16]
+                            C[0:16] = T.call_llvm_pure_intrin(T.uint32(intrin_id), T.uint32(0), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16")
                     for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7):
                         for ax4_fused in T.vectorized(16):
                             with T.block("T_cast_8"):
@@ -1740,8 +1737,8 @@ class Conv2dInt8_with_predicate_scheduled:
                             for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2):
                                 for ax0_0_1, ax1_0_1 in T.grid(1, 1):
                                     with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"):
-                                        v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2)
-                                        v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1)
+                                        v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1)
+                                        v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax1_0_1)
                                         T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                         T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                         T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"})
@@ -1753,10 +1750,9 @@ class Conv2dInt8_with_predicate_scheduled:
                                                 pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                                 for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1):
                                     with T.block("p1_reindex_shared_wmma.matrix_b_o"):
-                                        v0 = T.axis.spatial(1, 0)
-                                        v1 = T.axis.spatial(1, 0)
+                                        v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                         v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0)
-                                        v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1)
+                                        v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax3_0)
                                         T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                         T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                         T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"})
@@ -1768,8 +1764,8 @@ class Conv2dInt8_with_predicate_scheduled:
                                                 p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
                                 for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2):
                                     with T.block("conv2d_nhwc_o"):
-                                        v0 = T.axis.reduce(1, 0)
-                                        v1 = T.axis.reduce(1, 0)
+                                        v0 = T.axis.reduce(1, ax0_2 + ax0_0 + ax0_1)
+                                        v1 = T.axis.reduce(1, ax1_1 + ax1_2 + ax1_0)
                                         v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4)
                                         v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4)
                                         v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2)
@@ -1789,10 +1785,10 @@ class Conv2dInt8_with_predicate_scheduled:
                                                 T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i])
                                                 T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                                 T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
-                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "int32") * T.cast(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i], "int32")
+                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i])
                         for ax0_0, ax1_0 in T.grid(1, 2):
                             with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                                v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2)
+                                v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0)
                                 v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0)
                                 T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                 T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
@@ -2478,7 +2474,7 @@ def test_conv2d_int8_tensorcore():
             l311,
             l312,
         ) = sch.get_loops(block=b296)
-        b313 = sch.decompose_reduction(block=b296, loop=l302)
+        b313 = sch.decompose_reduction(block=b296, loop=l300)
         sch.unannotate(block_or_loop=b313, ann_key="meta_schedule.auto_tensorize")
         sch.annotate(
             block_or_loop=b313,
@@ -2723,7 +2719,7 @@ def test_conv2d_int8_vnni():
             l188,
             l189,
         ) = sch.get_loops(block=b165)
-        b190 = sch.decompose_reduction(block=b165, loop=l172)
+        b190 = sch.decompose_reduction(block=b165, loop=l170)
         sch.unannotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize")
         sch.annotate(block_or_loop=b190, ann_key="meta_schedule.auto_tensorize", ann_val="")
         b191 = sch.get_block(name="conv2d_NCHWc_int8_o_init", func_name="main")
diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py
index 12836cdb9e..a68170009b 100644
--- a/tests/python/unittest/test_tir_schedule_blockize.py
+++ b/tests/python/unittest/test_tir_schedule_blockize.py
@@ -20,6 +20,7 @@ import tvm.testing
 from tvm import tir
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
+import pytest
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -247,7 +248,8 @@ def test_blockize_init_loops():
     verify_trace_roundtrip(sch=s, mod=rowsum)
 
 
-def test_blockize_outer_int64_shape():
+@pytest.mark.parametrize("preserve_unit_iters", [True, False])
+def test_blockize_outer_int64_shape(preserve_unit_iters):
     @T.prim_func
     def single_elementwise_int64(
         A: T.Buffer[(T.int64(16), T.int64(128)), "float32"],
@@ -275,10 +277,31 @@ def test_blockize_outer_int64_shape():
                             vi_i, vj_o * T.int64(16) + vj_i
                         ] + T.float32(1)
 
+    @T.prim_func
+    def after_single_elementwise_int64_blockize_preserve_unit_iters(
+        A: T.Buffer[(T.int64(16), T.int64(128)), "float32"],
+        B: T.Buffer[(T.int64(16), T.int64(128)), "float32"],
+    ) -> None:
+        for i0, j0 in T.grid(T.int64(1), T.int64(8)):
+            with T.block("B_o"):
+                vi_o = T.axis.spatial(T.int64(1), i0)
+                vj_o = T.axis.spatial(T.int64(8), j0)
+                for i1, j1 in T.grid(T.int64(16), T.int64(16)):
+                    with T.block("B"):
+                        vi_i, vj_i = T.axis.remap("SS", [i1, j1])
+                        B[vi_i, vj_o * T.int64(16) + vj_i] = A[
+                            vi_i, vj_o * T.int64(16) + vj_i
+                        ] + T.float32(1)
+
     s = tir.Schedule(single_elementwise_int64, debug_mask="all")
     _, _, i1, _ = s.get_loops(s.get_block("B"))
-    s.blockize(i1)
-    tvm.ir.assert_structural_equal(s.mod["main"], after_single_elementwise_int64_blockize)
+    s.blockize(i1, preserve_unit_iters=preserve_unit_iters)
+    expected = (
+        after_single_elementwise_int64_blockize_preserve_unit_iters
+        if preserve_unit_iters
+        else after_single_elementwise_int64_blockize
+    )
+    tvm.ir.assert_structural_equal(s.mod["main"], expected)
     verify_trace_roundtrip(sch=s, mod=single_elementwise_int64)