You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/07/20 09:58:21 UTC

[tvm] branch main updated: Fix #12039's broken cases (#12143)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7abdce2660 Fix #12039's broken cases (#12143)
7abdce2660 is described below

commit 7abdce26606551776e55a458622e23182e8ae9d4
Author: wrongtest <wr...@gmail.com>
AuthorDate: Wed Jul 20 17:58:11 2022 +0800

    Fix #12039's broken cases (#12143)
---
 src/arith/iter_affine_map.cc                       |  91 ++++++++----
 tests/python/unittest/test_arith_intset.py         |   7 +-
 .../python/unittest/test_arith_iter_affine_map.py  |  58 +++++++-
 .../unittest/test_meta_schedule_space_cpu.py       | 164 ++++++++++-----------
 .../unittest/test_meta_schedule_space_cuda.py      |  84 +++++------
 tests/python/unittest/test_tir_schedule_reorder.py |  30 +++-
 .../unittest/test_tir_schedule_split_fuse.py       |   8 +-
 .../test_tir_schedule_state_cached_flags.py        |   2 +-
 8 files changed, 281 insertions(+), 163 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index d2aa16ded1..83e2821c98 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -177,8 +177,12 @@ class IterMapRewriter : public ExprMutator {
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
-                           bool simplify_trivial_iterators, Array<String>* errors)
-      : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) {
+                           IterMapLevel check_level, bool simplify_trivial_iterators,
+                           Array<String>* errors)
+      : analyzer_(analyzer),
+        check_level_(check_level),
+        errors_(*errors),
+        padding_predicate_(const_false()) {
     for (auto kv : input_iters) {
       const Var& var = kv.first;
       const Range& vrng = kv.second;
@@ -419,6 +423,8 @@ class IterMapRewriter : public ExprMutator {
 
   // Internal analyzer
   Analyzer* analyzer_;
+  // Iter map check level
+  IterMapLevel check_level_;
   // Error messages for each unresolved expression.
   Array<String>& errors_;
   // The var map
@@ -651,7 +657,7 @@ class IterMapRewriter : public ExprMutator {
       if (predicate_induced_max.defined())
         predicate_induced_max = predicate_induced_max.value() - base;
     }
-    Optional<IterSumExpr> opt = TryFuseIters(expr);
+    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
     ICHECK(!opt.defined() || opt.value()->args.size() == 1);
     // scale should be 1
     if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
@@ -702,7 +708,7 @@ class IterMapRewriter : public ExprMutator {
   IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
     // We are normalizing a regular iter
     if (expr->args.size() < 1) return expr;
-    Optional<IterSumExpr> opt = TryFuseIters(expr);
+    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
     if (opt.defined()) {
       return opt.value();
     } else {
@@ -735,9 +741,10 @@ class IterMapRewriter : public ExprMutator {
    *    return a corresponding IterSumExpr with extra offset if needed.
    *    Try to normalize IterSum into a fused IterMark
    * \param expr The input sum.
+   * \param check_level The check level of iter mapping.
    * \return The sum with the fused IterMark and extra offset if succeed.
    */
-  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
+  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
     // select the iterators in order
     std::vector<bool> visited(expr->args.size(), false);
     std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -758,14 +765,42 @@ class IterMapRewriter : public ExprMutator {
     }
     // check if it can be remapped into a fused pattern.
     PrimExpr expected_extra_base = 0;
+    PrimExpr tail_extent = 0;
     PrimExpr expected_scale = base_scale.value();
     for (size_t i = 0; i < expr->args.size();) {
-      // find j such that expr->args[j] has expected scale
-      size_t j = i == 0 ? base_index : 0;
-      for (; j < expr->args.size(); ++j) {
-        if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      // find the position j such that expr->args[j] matches the expected scale
+      int j = i == 0 ? base_index : expr->args.size() - 1;
+
+      size_t matched_pos = expr->args.size();
+      PrimExpr matched_scale{nullptr};
+      bool is_exact_match{false};
+
+      for (; j >= 0; --j) {
+        if (visited[j]) {
+          continue;
+        }
+        const PrimExpr& cur_scale = expr->args[j]->scale;
+
+        // for bijective mapping, the matched scale must be equal to the expected scale
+        if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
+          matched_pos = j;
+          matched_scale = cur_scale;
+          is_exact_match = true;
+          break;
+        }
+        if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
+          // find the closest scale which is less than or equal to the expected scale
+          if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
+              analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
+            if (matched_pos == expr->args.size() ||
+                analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
+              matched_pos = j;
+              matched_scale = cur_scale;
+            }
+          }
+        }
       }
-      if (j == expr->args.size()) {
+      if (matched_pos == expr->args.size()) {
         return NullOpt;
       }
       // look for the longest constrained iter started from expr->args[j]
@@ -775,8 +810,8 @@ class IterMapRewriter : public ExprMutator {
       // otherwise we expect the scale of i to be 2*5=10
       Optional<IterSumExpr> constraint_to_match;
       for (const IterSumExpr& iter : constrained_iters_flattened_) {
-        if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
-          // find a predicate started from expr->args[j]
+        if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
+          // find a predicate started from match position
           if (!constraint_to_match ||
               constraint_to_match.value()->args.size() < iter->args.size()) {
             constraint_to_match = iter;
@@ -793,7 +828,7 @@ class IterMapRewriter : public ExprMutator {
           size_t k = 0;
           for (; k < expr->args.size(); ++k) {
             if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
-              if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale))
+              if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
                 break;
             }
           }
@@ -806,20 +841,25 @@ class IterMapRewriter : public ExprMutator {
         auto iter = sum_fuse_map_.find(constraint_to_match.value());
         ICHECK(iter != sum_fuse_map_.end());
         const IterMarkWithOffset& iter_matched = iter->second;
-        grouped_iters.emplace_back(iter_matched.mark, expected_scale);
-        expected_extra_base += iter_matched.offset * expected_scale;
-        expected_scale *= iter_matched.mark->extent;
+        grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
+        expected_extra_base += iter_matched.offset * matched_scale;
+        if (!is_exact_match) {
+          tail_extent += expected_scale - matched_scale;
+        }
+        expected_scale = matched_scale * iter_matched.mark->extent;
         // move forward
         i += constraint_to_match.value()->args.size();
       } else {
         // constraint_to_match not found, skip this iterator
-        visited[j] = true;
-        IterSplitExpr arg = expr->args[j];
-        arg.CopyOnWrite()->scale =
-            analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
+        visited[matched_pos] = true;
+        IterSplitExpr arg = expr->args[matched_pos];
+        arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
         flattened_iters.push_back(arg);
         grouped_iters.push_back(arg);
-        expected_scale *= expr->args[j]->extent;
+        if (!is_exact_match) {
+          tail_extent += expected_scale - matched_scale;
+        }
+        expected_scale = matched_scale * expr->args[matched_pos]->extent;
         ++i;
       }
     }
@@ -843,7 +883,8 @@ class IterMapRewriter : public ExprMutator {
                          expr->base + expected_extra_base);
     } else {
       // new iter, form a new mark
-      IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
+      IterMark mark =
+          IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
       sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
       flattened_map_[structured_form] = flattened_form;
       return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
@@ -1086,8 +1127,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
       constraints.begin(), constraints.end(),
       [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
 
-  IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators,
-                           &result->errors);
+  IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
+                           simplify_trivial_iterators, &result->errors);
   // Step0.0: rewrite constraints in the order from size-small ones to size-big ones
   for (const IterConstraint& constraint : constraints) {
     auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
@@ -1281,7 +1322,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
     } else if (sum->args.size() == 1) {
       return sum;
     }
-    auto opt_fused = TryFuseIters(sum);
+    auto opt_fused = TryFuseIters(sum, check_level_);
     if (!opt_fused) {
       ErrorLogger(this) << "Dividend  " << tvm::PrettyPrint(original_dividend)
                         << ", can't be written as a single fused IterSum";
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index ca9d1077fe..74b53442ec 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -323,10 +323,6 @@ def test_region_lower_bound_for_non_perfect_tile():
 
 
 def test_region_lower_bound_unfusable():
-    # This test is designed to trigger an error in DetectIterMap,
-    # resulting from a numerator which required multiple input
-    # variables.  The bug resulted in an exception being thrown,
-    # rather than a return value of None.
     var_dom = {
         tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
         tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
@@ -336,7 +332,8 @@ def test_region_lower_bound_unfusable():
         tvm.ir.Range.from_min_extent((i + j) // 2, 1),
     ]
     result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True)
-    assert result is None
+    assert result[0].min_value == 0
+    assert result[0].max_value == 5
 
 
 def test_union_lower_bound():
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 7bc5ead298..6a2fdbbb3f 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -61,7 +61,6 @@ def assert_iter_sum_pattern(
     )
     indices = res.indices
     assert len(indices) == len(keys), res.errors
-    print(indices)
     for i, input_iter in enumerate(keys):
         spec = expect_dict[input_iter]
         (
@@ -446,6 +445,13 @@ def test_predicate():
         predicate=xo * 129 + xi < 128,
     )
 
+    # strided iteration predicate
+    assert_iter_sum_pattern(
+        {xo * 16 + xi * 4: (10, 0, 4)},
+        var_dom([(xo, 3), (xi, 4)]),
+        predicate=xo * 4 + xi < 10,
+    )
+
 
 def convert_division(divisions):
     if divisions is None or len(divisions) == 0:
@@ -1010,5 +1016,55 @@ def test_padding():
     assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
 
 
+def test_overlapped_fuse():
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
+    a = tvm.tir.Var("x", "int32")
+    b = tvm.tir.Var("y", "int32")
+
+    # non-bijective fuse of two
+    assert_iter_sum_pattern(
+        {
+            x * 7 + y: (22, 0, 1),
+        },
+        var_dom([(x, 3), (y, 8)]),
+        check_level="surjective",
+    )
+    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")
+
+    # non-bijective fuse of three
+    assert_iter_sum_pattern(
+        {
+            x * 18 + y * 7 + z: (40, 0, 1),
+        },
+        var_dom([(x, 2), (y, 3), (z, 8)]),
+        check_level="surjective",
+    )
+    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective")
+
+    # negative scale fusion is not allowed
+    assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
+    assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
+
+    # with predicate
+    assert_iter_sum_pattern(
+        {
+            a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
+        },
+        var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
+        predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
+        check_level="surjective",
+    )
+
+    # stride=1 kernel
+    assert_iter_sum_pattern(
+        {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective"
+    )
+
+    # do not allow both strided and overlapped
+    assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 36f365e732..12aa150f57 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -48,11 +48,11 @@ def test_cpu_c1d():
             for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 1, 8):
                 for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                     with T.block("conv1d_nlc"):
-                        n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
-                        l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
-                        co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 + i2_2 + i2_3)
+                        n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0)
+                        l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + i1_2 * 2 + i1_3)
+                        co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2)
                         rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                        rc = T.axis.reduce(64, i4_0 + i4_1)
+                        rc = T.axis.reduce(64, i4_1 + i4_0)
                         T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
                         T.writes(conv1d_nlc_global[n, l, co])
                         T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -89,11 +89,11 @@ def test_cpu_c1d():
                             PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32")
                     for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                         with T.block("conv1d_nlc"):
-                            n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
-                            l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
-                            co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
+                            n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                            l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
+                            co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
                             rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                            rc = T.axis.reduce(64, i4_0 + i4_1)
+                            rc = T.axis.reduce(64, i4_1 + i4_0)
                             T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
                             T.writes(conv1d_nlc_global[n, l, co])
                             T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -107,7 +107,7 @@ def test_cpu_c1d():
                         T.reads(conv1d_nlc_global[v0, v1, v2])
                         T.writes(conv1d_nlc[v0, v1, v2])
                         conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2]
-                        
+
     @T.prim_func
     def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
         # function attr dict
@@ -119,11 +119,11 @@ def test_cpu_c1d():
             T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
             for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                 with T.block("conv1d_nlc"):
-                    n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
-                    l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
-                    co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
+                    n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                    l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
+                    co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
                     rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                    rc = T.axis.reduce(64, i4_0 + i4_1)
+                    rc = T.axis.reduce(64, i4_1 + i4_0)
                     T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], weight[rl, rc, co])
                     T.writes(conv1d_nlc[n, l, co])
                     T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -201,11 +201,11 @@ def test_cpu_c2d():
                 for i3_1 in T.serial(8):
                     for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4):
                         with T.block("conv2d_nhwc"):
-                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                            h = T.axis.spatial(112, ((i1_0 + i1_1) * 2 + i1_2) * 8 + i1_3)
-                            w = T.axis.spatial(112, i2_0 * 28 + i2_1 + i2_2 + i2_3)
-                            co = T.axis.spatial(64, (i3_0 * 8 + i3_1 + i3_2) * 4 + i3_3)
-                            rh = T.axis.reduce(7, i4_0 + i4_1)
+                            n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                            h = T.axis.spatial(112, i1_0 * 16 + i1_1 * 16 + i1_2 * 8 + i1_3)
+                            w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1 + i2_2)
+                            co = T.axis.spatial(64, i3_0 * 32 + i3_1 * 4 + i3_2 * 4 + i3_3)
+                            rh = T.axis.reduce(7, i4_1 + i4_0)
                             rw = T.axis.reduce(7, i5_0 + i5_1)
                             rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                             T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -243,11 +243,11 @@ def test_cpu_c2d():
             for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 4, 2):
                 for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4):
                     with T.block("conv2d_nhwc"):
-                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1_1 + i0_0)
-                        h = T.axis.spatial(112, ((i1_0 + i1_1_1) * 2 + i1_2) * 8 + i1_3)
-                        w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 + i2_2 + i2_3)
-                        co = T.axis.spatial(64, (i3_0 * 8 + i3_1_1 + i3_2) * 4 + i3_3)
-                        rh = T.axis.reduce(7, i4_0 + i4_1)
+                        n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2)
+                        h = T.axis.spatial(112, i1_0 * 16 + i1_1_1 * 16 + i1_2 * 8 + i1_3)
+                        w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1_1 + i2_2)
+                        co = T.axis.spatial(64, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 4 + i3_3)
+                        rh = T.axis.reduce(7, i4_1 + i4_0)
                         rw = T.axis.reduce(7, i5_0 + i5_1)
                         rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                         T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -287,11 +287,11 @@ def test_cpu_c2d():
                         PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
                 for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(4, 2, 1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4):
                     with T.block("conv2d_nhwc"):
-                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                        h = T.axis.spatial(112, ((i1_0 + i1_1) * 2 + i1_2) * 8 + i1_3)
-                        w = T.axis.spatial(112, i2_0 * 28 + i2_1 + i2_2 + i2_3)
-                        co = T.axis.spatial(64, (i3_0 * 8 + i3_1 + i3_2) * 4 + i3_3)
-                        rh = T.axis.reduce(7, i4_0 + i4_1)
+                        n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                        h = T.axis.spatial(112, i1_0 * 16 + i1_1 * 16 + i1_2 * 8 + i1_3)
+                        w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1 + i2_2)
+                        co = T.axis.spatial(64, i3_0 * 32 + i3_1 * 4 + i3_2 * 4 + i3_3)
+                        rh = T.axis.reduce(7, i4_1 + i4_0)
                         rw = T.axis.reduce(7, i5_0 + i5_1)
                         rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                         T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -378,15 +378,15 @@ def test_cpu_c3d():
                 for i0_1, i1_1, i2_1, i3_1, i4_1 in T.grid(1, 4, 4, 14, 1):
                     for i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
                         with T.block("conv3d_ndhwc"):
-                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                            d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
-                            h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
-                            w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
-                            co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+                            n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                            d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2)
+                            h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3)
+                            w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3)
+                            co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2)
                             rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
-                            rh = T.axis.reduce(7, i6_0 + i6_1)
+                            rh = T.axis.reduce(7, i6_1 + i6_0)
                             rw = T.axis.reduce(7, i7_0 + i7_1)
-                            rc = T.axis.reduce(3, i8_0 + i8_1)
+                            rc = T.axis.reduce(3, i8_1 + i8_0)
                             T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
                             T.writes(conv3d_ndhwc_global[n, d, h, w, co])
                             T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -428,15 +428,15 @@ def test_cpu_c3d():
                             PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32")
                     for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
                         with T.block("conv3d_ndhwc"):
-                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                            d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
-                            h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
-                            w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
-                            co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+                            n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                            d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2)
+                            h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3)
+                            w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3)
+                            co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2)
                             rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
-                            rh = T.axis.reduce(7, i6_0 + i6_1)
+                            rh = T.axis.reduce(7, i6_1 + i6_0)
                             rw = T.axis.reduce(7, i7_0 + i7_1)
-                            rc = T.axis.reduce(3, i8_0 + i8_1)
+                            rc = T.axis.reduce(3, i8_1 + i8_0)
                             T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
                             T.writes(conv3d_ndhwc_global[n, d, h, w, co])
                             T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -476,15 +476,15 @@ def test_cpu_c3d():
                         PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32")
                 for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
                     with T.block("conv3d_ndhwc"):
-                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                        d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
-                        h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
-                        w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
-                        co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+                        n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                        d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2)
+                        h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3)
+                        w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3)
+                        co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2)
                         rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
-                        rh = T.axis.reduce(7, i6_0 + i6_1)
+                        rh = T.axis.reduce(7, i6_1 + i6_0)
                         rw = T.axis.reduce(7, i7_0 + i7_1)
-                        rc = T.axis.reduce(3, i8_0 + i8_1)
+                        rc = T.axis.reduce(3, i8_1 + i8_0)
                         T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
                         T.writes(conv3d_ndhwc[n, d, h, w, co])
                         T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -574,14 +574,14 @@ def test_cpu_cap():
                 for i2_1, i3_1, i4_1, i5_1 in T.grid(4, 1, 4, 2):
                     for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
                         with T.block("conv2d_capsule_nhwijc"):
-                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+                            n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
                             h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
-                            w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
-                            cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
+                            w = T.axis.spatial(8, i2_0 * 8 + i2_1 * 2 + i2_2 + i2_3)
+                            cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1 * 4 + i3_2 * 4 + i3_3)
                             cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
-                            co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
+                            co = T.axis.spatial(32, i5_0 * 32 + i5_1 * 16 + i5_2 * 16 + i5_3)
                             rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
-                            rw = T.axis.reduce(3, i7_0 + i7_1)
+                            rw = T.axis.reduce(3, i7_1 + i7_0)
                             cap_k = T.axis.reduce(4, i8_0 + i8_1)
                             rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
                             T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
@@ -625,14 +625,14 @@ def test_cpu_cap():
                             PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32")
                     for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
                         with T.block("conv2d_capsule_nhwijc"):
-                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+                            n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
                             h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
-                            w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
-                            cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
+                            w = T.axis.spatial(8, i2_0 * 8 + i2_1 * 2 + i2_2 + i2_3)
+                            cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1 * 4 + i3_2 * 4 + i3_3)
                             cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
-                            co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
+                            co = T.axis.spatial(32, i5_0 * 32 + i5_1 * 16 + i5_2 * 16 + i5_3)
                             rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
-                            rw = T.axis.reduce(3, i7_0 + i7_1)
+                            rw = T.axis.reduce(3, i7_1 + i7_0)
                             cap_k = T.axis.reduce(4, i8_0 + i8_1)
                             rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
                             T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
@@ -667,14 +667,14 @@ def test_cpu_cap():
                     PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1], T.float32(0), dtype="float32")
             for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_1_1, i5_1_1, i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 4, 4, 1, 4, 2, 1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
                 with T.block("conv2d_capsule_nhwijc"):
-                    n = T.axis.spatial(1, i0_3 + i0_2 + i0_1_1 + i0_0)
+                    n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1)
                     h = T.axis.spatial(8, i1_0 * 4 + i1_1_1 + i1_2 + i1_3)
-                    w = T.axis.spatial(8, (i2_0 * 4 + i2_1_1) * 2 + i2_2 + i2_3)
-                    cap_i = T.axis.spatial(4, (i3_0 + i3_1_1 + i3_2) * 4 + i3_3)
+                    w = T.axis.spatial(8, i2_0 * 8 + i2_1_1 * 2 + i2_2 + i2_3)
+                    cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1_1 * 4 + i3_2 * 4 + i3_3)
                     cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1_1 + i4_2 + i4_3)
-                    co = T.axis.spatial(32, (i5_0 * 2 + i5_1_1 + i5_2) * 16 + i5_3)
+                    co = T.axis.spatial(32, i5_0 * 32 + i5_1_1 * 16 + i5_2 * 16 + i5_3)
                     rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
-                    rw = T.axis.reduce(3, i7_0 + i7_1)
+                    rw = T.axis.reduce(3, i7_1 + i7_0)
                     cap_k = T.axis.reduce(4, i8_0 + i8_1)
                     rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
                     T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
@@ -763,7 +763,7 @@ def test_cpu_dep():
             for i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1 in T.grid(1, 1, 1, 1, 1, 4, 4, 8):
                 for i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2):
                     with T.block("depth_conv2d_nhwc"):
-                        n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
+                        n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1)
                         h = T.axis.spatial(112, i1_0 * 112 + i1_1_1 * 28 + i1_2 * 14 + i1_3)
                         w = T.axis.spatial(112, i2_0 * 112 + i2_1_1 * 28 + i2_2 * 4 + i2_3)
                         c = T.axis.spatial(32, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 2 + i3_3)
@@ -804,7 +804,7 @@ def test_cpu_dep():
             for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 1, 1):
                 for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2):
                     with T.block("depth_conv2d_nhwc"):
-                        n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
+                        n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1)
                         h = T.axis.spatial(112, i1_0 * 112 + i1_1_1 * 28 + i1_2 * 14 + i1_3)
                         w = T.axis.spatial(112, i2_0 * 112 + i2_1_1 * 28 + i2_2 * 4 + i2_3)
                         c = T.axis.spatial(32, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 2 + i3_3)
@@ -843,7 +843,7 @@ def test_cpu_dep():
                         PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 113 and 1 <= i2 and i2 < 113, placeholder[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
                 for i2_1, i3_1, i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2):
                     with T.block("depth_conv2d_nhwc"):
-                        n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                        n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
                         h = T.axis.spatial(112, i1_0 * 112 + i1_1 * 28 + i1_2 * 14 + i1_3)
                         w = T.axis.spatial(112, i2_0 * 112 + i2_1 * 28 + i2_2 * 4 + i2_3)
                         c = T.axis.spatial(32, i3_0 * 32 + i3_1 * 4 + i3_2 * 2 + i3_3)
@@ -926,11 +926,11 @@ def test_cpu_dil():
                         PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
                 for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
                     with T.block("conv2d_nhwc"):
-                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                        h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
-                        w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
-                        co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
-                        rh = T.axis.reduce(7, i4_0 + i4_1)
+                        n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                        h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1)
+                        w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2)
+                        co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3)
+                        rh = T.axis.reduce(7, i4_1 + i4_0)
                         rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
                         rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                         T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -972,11 +972,11 @@ def test_cpu_dil():
                             PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
                     for i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
                         with T.block("conv2d_nhwc"):
-                            n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                            h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
-                            w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
-                            co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
-                            rh = T.axis.reduce(7, i4_0 + i4_1)
+                            n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                            h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1)
+                            w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2)
+                            co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3)
+                            rh = T.axis.reduce(7, i4_1 + i4_0)
                             rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
                             rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                             T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -1016,11 +1016,11 @@ def test_cpu_dil():
                         PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
                 for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 1, 1, 2, 7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
                     with T.block("conv2d_nhwc"):
-                        n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
-                        h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
-                        w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
-                        co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
-                        rh = T.axis.reduce(7, i4_0 + i4_1)
+                        n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                        h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1)
+                        w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2)
+                        co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3)
+                        rh = T.axis.reduce(7, i4_1 + i4_0)
                         rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
                         rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                         T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index b8723e286a..7323bc441f 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -47,7 +47,7 @@ def test_cuda_c1d():
                             for ax0_ax1_ax2_fused in T.serial(260):
                                 with T.block("PadInput_shared"):
                                     v0 = T.axis.spatial(1, 0)
-                                    v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
+                                    v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4)
                                     v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4)
                                     T.reads(inputs[v0, v1 - 1, v2])
                                     T.writes(PadInput_shared[v0, v1, v2])
@@ -64,11 +64,11 @@ def test_cuda_c1d():
                                     weight_shared[v0, v1, v2] = weight[v0, v1, v2]
                             for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
                                 with T.block("conv1d_nlc"):
-                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
-                                    l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4)
-                                    co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4)
-                                    rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2)
-                                    rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 2 + i4_2)
+                                    n = T.axis.spatial(1, i0_4 + i0_3)
+                                    l = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4)
+                                    co = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4)
+                                    rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 + i3_2)
+                                    rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 + i4_2)
                                     T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co])
                                     T.writes(conv1d_nlc_local[n, l, co])
                                     T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
@@ -130,7 +130,7 @@ def test_cuda_c2d():
                             for ax0_ax1_ax2_ax3_fused in T.serial(80379):
                                 with T.block("PadInput_shared"):
                                     v0 = T.axis.spatial(1, 0)
-                                    v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused % 80379 // 351)
+                                    v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused // 351)
                                     v2 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 8 * 112 + ax0_ax1_ax2_ax3_fused % 351 // 3)
                                     v3 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 3)
                                     T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
@@ -149,13 +149,13 @@ def test_cuda_c2d():
                                     weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
                             for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 7, 1, 1, 8, 4, 1, 7, 1, 3, 1, 1, 1, 2):
                                 with T.block("conv2d_nhwc"):
-                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
-                                    h = T.axis.spatial(112, ((0 + 0) * 14 + i0_2_i1_2_i2_2_i3_2_fused % 14) * 8 + i1_3 + i1_4)
-                                    w = T.axis.spatial(112, (i0_0_i1_0_i2_0_i3_0_fused % 16 // 8 * 14 + i0_1_i1_1_i2_1_i3_1_fused % 56 // 4 + 0) * 4 + i2_3 + i2_4)
-                                    co = T.axis.spatial(64, (i0_0_i1_0_i2_0_i3_0_fused % 8 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 4 + 0 + i3_3) * 2 + i3_4)
-                                    rh = T.axis.reduce(7, (i4_0 + i4_1) * 7 + i4_2)
-                                    rw = T.axis.reduce(7, i5_0 * 7 + i5_1 + i5_2)
-                                    rc = T.axis.reduce(3, (i6_0 + i6_1) * 3 + i6_2)
+                                    n = T.axis.spatial(1, i0_3 + i0_4)
+                                    h = T.axis.spatial(112, i1_4 + i0_2_i1_2_i2_2_i3_2_fused * 8 + i1_3)
+                                    w = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_fused // 8 * 56 + i0_1_i1_1_i2_1_i3_1_fused // 4 * 4 + i2_3 + i2_4)
+                                    co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 8 * 8 + i0_1_i1_1_i2_1_i3_1_fused % 4 * 2 + i3_3 * 2 + i3_4)
+                                    rh = T.axis.reduce(7, i4_0 * 7 + i4_1 * 7 + i4_2)
+                                    rw = T.axis.reduce(7, i5_2 + i5_0 * 7 + i5_1)
+                                    rc = T.axis.reduce(3, i6_0 * 3 + i6_1 * 3 + i6_2)
                                     T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight_shared[rh, rw, rc, co])
                                     T.writes(conv2d_nhwc_local[n, h, w, co])
                                     T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
@@ -219,7 +219,7 @@ def test_cuda_c3d():
                             for ax0_ax1_ax2_ax3_ax4_fused in T.serial(1687959):
                                 with T.block("PadInput_shared"):
                                     v0 = T.axis.spatial(1, 0)
-                                    v1 = T.axis.spatial(22, ax0_ax1_ax2_ax3_ax4_fused % 1687959 // 80379)
+                                    v1 = T.axis.spatial(22, ax0_ax1_ax2_ax3_ax4_fused // 80379)
                                     v2 = T.axis.spatial(230, ax0_ax1_ax2_ax3_ax4_fused % 80379 // 351)
                                     v3 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 112 + ax0_ax1_ax2_ax3_ax4_fused % 351 // 3)
                                     v4 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_fused % 3)
@@ -240,14 +240,14 @@ def test_cuda_c3d():
                                     weight_shared[v0, v1, v2, v3, v4] = weight[v0, v1, v2, v3, v4]
                             for i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_2, i6_2, i7_2, i8_2, i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(7, 7, 1, 3, 1, 2, 2, 1, 32, 1, 1, 7, 1, 1, 1, 2, 4, 1):
                                 with T.block("conv3d_ndhwc"):
-                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
-                                    d = T.axis.spatial(8, ((0 + 0) * 4 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 392 // 98) * 2 + i1_3 + i1_4)
-                                    h = T.axis.spatial(112, (((0 * 4 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 8 // 2) * 7 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14) * 2 + i2_3) * 2 + i2_4)
-                                    w = T.axis.spatial(112, ((i0_0_i1_0_i2_0_i3_0_i4_0_fused % 2 * 2 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2) * 7 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 + i3_3) * 4 + i3_4)
-                                    co = T.axis.spatial(64, ((0 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2) * 32 + i4_3 + i4_4)
-                                    rd = T.axis.reduce(7, i5_0 * 7 + i5_1 + i5_2)
+                                    n = T.axis.spatial(1, i0_4 + i0_3)
+                                    d = T.axis.spatial(8, i1_4 + i0_2_i1_2_i2_2_i3_2_i4_2_fused // 98 * 2 + i1_3)
+                                    h = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_i4_1_fused // 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14 * 4 + i2_3 * 2 + i2_4)
+                                    w = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 56 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 * 4 + i3_3 * 4 + i3_4)
+                                    co = T.axis.spatial(64, i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2 * 32 + i4_3 + i4_4)
+                                    rd = T.axis.reduce(7, i5_2 + i5_0 * 7 + i5_1)
                                     rh = T.axis.reduce(7, i6_0 * 7 + i6_1 + i6_2)
-                                    rw = T.axis.reduce(7, (i7_0 + i7_1) * 7 + i7_2)
+                                    rw = T.axis.reduce(7, i7_0 * 7 + i7_1 * 7 + i7_2)
                                     rc = T.axis.reduce(3, i8_0 * 3 + i8_1 + i8_2)
                                     T.reads(PadInput_shared[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight_shared[rd, rh, rw, rc, co])
                                     T.writes(conv3d_ndhwc_local[n, d, h, w, co])
@@ -338,15 +338,15 @@ def test_cuda_cap():
                                     weight_shared[v0, v1, v2, v3, v4, v5] = weight[v0, v1, v2, v3, v4, v5]
                             for i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3, i6_2, i7_2, i8_2, i9_2, i0_4, i1_4, i2_4, i3_4, i4_4, i5_4 in T.grid(1, 1, 1, 4, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 8):
                                 with T.block("conv2d_capsule_nhwijc"):
-                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
-                                    h = T.axis.spatial(8, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 256 // 64 + 0 + 0) * 2 + i1_3 + i1_4)
-                                    w = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 + 0 + 0 + i2_3 + i2_4)
-                                    cap_i = T.axis.spatial(4, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 4 // 2 + i3_3 + i3_4)
-                                    cap_j = T.axis.spatial(4, ((0 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 + i4_3) * 2 + i4_4)
-                                    co = T.axis.spatial(32, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 + 0 + 0 + i5_3) * 8 + i5_4)
-                                    rh = T.axis.reduce(3, i6_0 + i6_1 + i6_2)
+                                    n = T.axis.spatial(1, i0_4 + i0_3)
+                                    h = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 2 + i1_3 + i1_4)
+                                    w = T.axis.spatial(8, i2_3 + i2_4 + i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8)
+                                    cap_i = T.axis.spatial(4, i3_3 + i3_4 + i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused // 2)
+                                    cap_j = T.axis.spatial(4, i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 * 2 + i4_3 * 2 + i4_4)
+                                    co = T.axis.spatial(32, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 * 8 + i5_3 * 8 + i5_4)
+                                    rh = T.axis.reduce(3, i6_1 + i6_2 + i6_0)
                                     rw = T.axis.reduce(3, i7_0 + i7_1 + i7_2)
-                                    cap_k = T.axis.reduce(4, (i8_0 + i8_1) * 2 + i8_2)
+                                    cap_k = T.axis.reduce(4, i8_0 * 2 + i8_1 * 2 + i8_2)
                                     rc = T.axis.reduce(32, i9_0 * 4 + i9_1 + i9_2)
                                     T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight_shared[rh, rw, cap_k, cap_j, rc, co])
                                     T.writes(conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co])
@@ -436,12 +436,12 @@ def test_cuda_dep():
                                     placeholder_shared[v0, v1, v2, v3] = placeholder_1[v0, v1, v2, v3]
                             for i4_1, i5_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i0_4, i1_4, i2_4, i3_4 in T.grid(3, 1, 1, 4, 16, 8, 1, 3, 1, 7, 1, 1):
                                 with T.block("depth_conv2d_nhwc"):
-                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
-                                    h = T.axis.spatial(112, ((0 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 8 // 2 + 0) * 4 + i1_3) * 7 + i1_4)
-                                    w = T.axis.spatial(112, ((0 + 0) * 7 + i0_2_i1_2_i2_2_i3_2_fused % 14 // 2) * 16 + i2_3 + i2_4)
-                                    c = T.axis.spatial(32, ((0 * 2 + i0_1_i1_1_i2_1_i3_1_fused % 2) * 2 + i0_2_i1_2_i2_2_i3_2_fused % 2) * 8 + i3_3 + i3_4)
-                                    rh = T.axis.reduce(3, i4_0 * 3 + i4_1 + i4_2)
-                                    rw = T.axis.reduce(3, (i5_0 + i5_1) * 3 + i5_2)
+                                    n = T.axis.spatial(1, i0_4 + i0_3)
+                                    h = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 28 + i1_3 * 7 + i1_4)
+                                    w = T.axis.spatial(112, i2_4 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 16 + i2_3)
+                                    c = T.axis.spatial(32, i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 8 + i3_3 + i3_4)
+                                    rh = T.axis.reduce(3, i4_2 + i4_0 * 3 + i4_1)
+                                    rw = T.axis.reduce(3, i5_0 * 3 + i5_1 * 3 + i5_2)
                                     T.reads(PadInput_shared[n, h + rh, w + rw, c], placeholder_shared[0, rh, rw, c])
                                     T.writes(depth_conv2d_nhwc_local[n, h, w, c])
                                     T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
@@ -522,13 +522,13 @@ def test_cuda_dil():
                                     weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
                             for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 4):
                                 with T.block("conv2d_nhwc"):
-                                    n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
-                                    h = T.axis.spatial(109, i0_0_i1_0_i2_0_i3_0_fused % 218 // 2 + 0 + 0 + i1_3 + i1_4)
-                                    w = T.axis.spatial(109, 0 * 109 + i0_1_i1_1_i2_1_i3_1_fused % 109 + 0 + i2_3 + i2_4)
-                                    co = T.axis.spatial(64, ((i0_0_i1_0_i2_0_i3_0_fused % 2 + 0 + 0) * 8 + i3_3) * 4 + i3_4)
+                                    n = T.axis.spatial(1, i0_3 + i0_4)
+                                    h = T.axis.spatial(109, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 2 + i1_3)
+                                    w = T.axis.spatial(109, i0_1_i1_1_i2_1_i3_1_fused + i2_3 + i2_4)
+                                    co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_3 * 4 + i3_4)
                                     rh = T.axis.reduce(7, i4_0 + i4_1 + i4_2)
-                                    rw = T.axis.reduce(7, i5_0 + i5_1 + i5_2)
-                                    rc = T.axis.reduce(3, i6_0 + i6_1 + i6_2)
+                                    rw = T.axis.reduce(7, i5_2 + i5_0 + i5_1)
+                                    rc = T.axis.reduce(3, i6_1 + i6_2 + i6_0)
                                     T.reads(PadInput_shared[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight_shared[rh, rw, rc, co])
                                     T.writes(conv2d_nhwc_local[n, h, w, co])
                                     T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py
index 4351fe5b63..b859b655ef 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -214,9 +214,9 @@ def test_reorder_with_opaque_access():
     verify_trace_roundtrip(sch=sch, mod=opaque_access)
 
 
-def test_reorder_with_partial_affineness():
+def test_reorder_overlapped_access():
     @T.prim_func
-    def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+    def overlapped_access(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
         # example to write first axis multiple times
         for v0, v1, v2 in T.grid(6, 4, 4):
             with T.block("block"):
@@ -225,7 +225,7 @@ def test_reorder_with_partial_affineness():
                 B[i, j] = A[i, j] + 1.0
 
     @T.prim_func
-    def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+    def overlapped_access_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
         # example to write first axis multiple times
         for v0, v2, v1 in T.grid(6, 4, 4):
             with T.block("block"):
@@ -233,6 +233,30 @@ def test_reorder_with_partial_affineness():
                 j = T.axis.spatial(4, v2)
                 B[i, j] = A[i, j] + 1.0
 
+    sch = tir.Schedule(overlapped_access, debug_mask="all")
+    v0, v1, v2 = sch.get_loops(sch.get_block("block"))
+    sch.reorder(v0, v2, v1)
+    tvm.ir.assert_structural_equal(overlapped_access_reorder, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=overlapped_access)
+
+
+def test_reorder_with_partial_affineness():
+    @T.prim_func
+    def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+        for v0, v1, v2 in T.grid(6, 4, 4):
+            with T.block("block"):
+                i = T.axis.spatial(14, v0 * v0 + v1)
+                j = T.axis.spatial(4, v2)
+                B[i, j] = A[i, j] + 1.0
+
+    @T.prim_func
+    def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+        for v0, v2, v1 in T.grid(6, 4, 4):
+            with T.block("block"):
+                i = T.axis.spatial(14, v0 * v0 + v1)
+                j = T.axis.spatial(4, v2)
+                B[i, j] = A[i, j] + 1.0
+
     sch = tir.Schedule(non_affine_func, debug_mask="all")
     v0, v1, v2 = sch.get_loops(sch.get_block("block"))
     with pytest.raises(tvm.tir.ScheduleError):
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 9fd678174d..3ae88e0abb 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -177,7 +177,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None:
     B = T.match_buffer(b, [128, 128, 128])
     for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8):
         with T.block("B"):
-            vi = T.axis.S(128, (i1 + i2) * 64 + i3)
+            vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
             vj = T.axis.S(128, j1 * 32 + j2)
             vk = T.axis.S(128, k1 * 8 + k2)
             T.reads([A[vi, vj, vk]])
@@ -191,9 +191,9 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None:
     B = T.match_buffer(b, [128, 128, 128])
     for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64):
         with T.block("B"):
-            vi = T.axis.S(128, (i1 + i2) * 64 + i3)
-            vj = T.axis.S(128, (j1 + j2) * 64 + j3)
-            vk = T.axis.S(128, (k1 + k2) * 64 + k3)
+            vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
+            vj = T.axis.S(128, j1 * 64 + j2 * 64 + j3)
+            vk = T.axis.S(128, k1 * 64 + k2 * 64 + k3)
             T.reads([A[vi, vj, vk]])
             T.writes([B[vi, vj, vk]])
             B[vi, vj, vk] = A[vi, vj, vk] * 2.0
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index 1b4c34973f..bbeb8d8760 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -758,7 +758,7 @@ def test_non_perfect_tiling_cache():
     s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
     # pylint: disable=protected-access
     assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags(
-        affine_binding=False,
+        affine_binding=True,
         region_cover=True,
         stage_pipeline=True,
     )