You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/07/20 09:58:21 UTC
[tvm] branch main updated: Fix #12039‘s broken cases (#12143)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7abdce2660 Fix #12039‘s broken cases (#12143)
7abdce2660 is described below
commit 7abdce26606551776e55a458622e23182e8ae9d4
Author: wrongtest <wr...@gmail.com>
AuthorDate: Wed Jul 20 17:58:11 2022 +0800
Fix #12039‘s broken cases (#12143)
---
src/arith/iter_affine_map.cc | 91 ++++++++----
tests/python/unittest/test_arith_intset.py | 7 +-
.../python/unittest/test_arith_iter_affine_map.py | 58 +++++++-
.../unittest/test_meta_schedule_space_cpu.py | 164 ++++++++++-----------
.../unittest/test_meta_schedule_space_cuda.py | 84 +++++------
tests/python/unittest/test_tir_schedule_reorder.py | 30 +++-
.../unittest/test_tir_schedule_split_fuse.py | 8 +-
.../test_tir_schedule_state_cached_flags.py | 2 +-
8 files changed, 281 insertions(+), 163 deletions(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index d2aa16ded1..83e2821c98 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -177,8 +177,12 @@ class IterMapRewriter : public ExprMutator {
using Parent = ExprMutator;
explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
- bool simplify_trivial_iterators, Array<String>* errors)
- : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) {
+ IterMapLevel check_level, bool simplify_trivial_iterators,
+ Array<String>* errors)
+ : analyzer_(analyzer),
+ check_level_(check_level),
+ errors_(*errors),
+ padding_predicate_(const_false()) {
for (auto kv : input_iters) {
const Var& var = kv.first;
const Range& vrng = kv.second;
@@ -419,6 +423,8 @@ class IterMapRewriter : public ExprMutator {
// Internal analyzer
Analyzer* analyzer_;
+ // Iter map check level
+ IterMapLevel check_level_;
// Error messages for each unresolved expression.
Array<String>& errors_;
// The var map
@@ -651,7 +657,7 @@ class IterMapRewriter : public ExprMutator {
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
- Optional<IterSumExpr> opt = TryFuseIters(expr);
+ Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
@@ -702,7 +708,7 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
// We are normalizing a regular iter
if (expr->args.size() < 1) return expr;
- Optional<IterSumExpr> opt = TryFuseIters(expr);
+ Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
if (opt.defined()) {
return opt.value();
} else {
@@ -735,9 +741,10 @@ class IterMapRewriter : public ExprMutator {
* return a corresponding IterSumExpr with extra offset if needed.
* Try to normalize IterSum into a fused IterMark
* \param expr The input sum.
+ * \param check_level The check level of iter mapping.
* \return The sum with the fused IterMark and extra offset if succeed.
*/
- Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
+ Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -758,14 +765,42 @@ class IterMapRewriter : public ExprMutator {
}
// check if it can be remapped into a fused pattern.
PrimExpr expected_extra_base = 0;
+ PrimExpr tail_extent = 0;
PrimExpr expected_scale = base_scale.value();
for (size_t i = 0; i < expr->args.size();) {
- // find j such that expr->args[j] has expected scale
- size_t j = i == 0 ? base_index : 0;
- for (; j < expr->args.size(); ++j) {
- if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+ // find position such that expr->args[j] matches the expected scale
+ int j = i == 0 ? base_index : expr->args.size() - 1;
+
+ size_t matched_pos = expr->args.size();
+ PrimExpr matched_scale{nullptr};
+ bool is_exact_match{false};
+
+ for (; j >= 0; --j) {
+ if (visited[j]) {
+ continue;
+ }
+ const PrimExpr& cur_scale = expr->args[j]->scale;
+
+ // for bijective mapping, the matched scale must equal the expected scale
+ if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
+ matched_pos = j;
+ matched_scale = cur_scale;
+ is_exact_match = true;
+ break;
+ }
+ if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
+ // find the closest scale which is less than or equal to the expected scale
+ if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
+ analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
+ if (matched_pos == expr->args.size() ||
+ analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
+ matched_pos = j;
+ matched_scale = cur_scale;
+ }
+ }
+ }
}
- if (j == expr->args.size()) {
+ if (matched_pos == expr->args.size()) {
return NullOpt;
}
// look for the longest constrained iter started from expr->args[j]
@@ -775,8 +810,8 @@ class IterMapRewriter : public ExprMutator {
// otherwise we expect the scale of i to be 2*5=10
Optional<IterSumExpr> constraint_to_match;
for (const IterSumExpr& iter : constrained_iters_flattened_) {
- if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
- // find a predicate started from expr->args[j]
+ if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
+ // find a predicate starting from the matched position
if (!constraint_to_match ||
constraint_to_match.value()->args.size() < iter->args.size()) {
constraint_to_match = iter;
@@ -793,7 +828,7 @@ class IterMapRewriter : public ExprMutator {
size_t k = 0;
for (; k < expr->args.size(); ++k) {
if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
- if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale))
+ if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
break;
}
}
@@ -806,20 +841,25 @@ class IterMapRewriter : public ExprMutator {
auto iter = sum_fuse_map_.find(constraint_to_match.value());
ICHECK(iter != sum_fuse_map_.end());
const IterMarkWithOffset& iter_matched = iter->second;
- grouped_iters.emplace_back(iter_matched.mark, expected_scale);
- expected_extra_base += iter_matched.offset * expected_scale;
- expected_scale *= iter_matched.mark->extent;
+ grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
+ expected_extra_base += iter_matched.offset * matched_scale;
+ if (!is_exact_match) {
+ tail_extent += expected_scale - matched_scale;
+ }
+ expected_scale = matched_scale * iter_matched.mark->extent;
// move forward
i += constraint_to_match.value()->args.size();
} else {
// constraint_to_match not found, skip this iterator
- visited[j] = true;
- IterSplitExpr arg = expr->args[j];
- arg.CopyOnWrite()->scale =
- analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
+ visited[matched_pos] = true;
+ IterSplitExpr arg = expr->args[matched_pos];
+ arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
flattened_iters.push_back(arg);
grouped_iters.push_back(arg);
- expected_scale *= expr->args[j]->extent;
+ if (!is_exact_match) {
+ tail_extent += expected_scale - matched_scale;
+ }
+ expected_scale = matched_scale * expr->args[matched_pos]->extent;
++i;
}
}
@@ -843,7 +883,8 @@ class IterMapRewriter : public ExprMutator {
expr->base + expected_extra_base);
} else {
// new iter, form a new mark
- IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
+ IterMark mark =
+ IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
flattened_map_[structured_form] = flattened_form;
return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
@@ -1086,8 +1127,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
- IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators,
- &result->errors);
+ IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
+ simplify_trivial_iterators, &result->errors);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
@@ -1281,7 +1322,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
} else if (sum->args.size() == 1) {
return sum;
}
- auto opt_fused = TryFuseIters(sum);
+ auto opt_fused = TryFuseIters(sum, check_level_);
if (!opt_fused) {
ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend)
<< ", can't be written as a single fused IterSum";
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index ca9d1077fe..74b53442ec 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -323,10 +323,6 @@ def test_region_lower_bound_for_non_perfect_tile():
def test_region_lower_bound_unfusable():
- # This test is designed to trigger an error in DetectIterMap,
- # resulting from a numerator which required multiple input
- # variables. The bug resulted in an exception being thrown,
- # rather than a return value of None.
var_dom = {
tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
@@ -336,7 +332,8 @@ def test_region_lower_bound_unfusable():
tvm.ir.Range.from_min_extent((i + j) // 2, 1),
]
result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True)
- assert result is None
+ assert result[0].min_value == 0
+ assert result[0].max_value == 5
def test_union_lower_bound():
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 7bc5ead298..6a2fdbbb3f 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -61,7 +61,6 @@ def assert_iter_sum_pattern(
)
indices = res.indices
assert len(indices) == len(keys), res.errors
- print(indices)
for i, input_iter in enumerate(keys):
spec = expect_dict[input_iter]
(
@@ -446,6 +445,13 @@ def test_predicate():
predicate=xo * 129 + xi < 128,
)
+ # strided iteration predicate
+ assert_iter_sum_pattern(
+ {xo * 16 + xi * 4: (10, 0, 4)},
+ var_dom([(xo, 3), (xi, 4)]),
+ predicate=xo * 4 + xi < 10,
+ )
+
def convert_division(divisions):
if divisions is None or len(divisions) == 0:
@@ -1010,5 +1016,55 @@ def test_padding():
assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
+def test_overlapped_fuse():
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
+ a = tvm.tir.Var("x", "int32")
+ b = tvm.tir.Var("y", "int32")
+
+ # non-bijective fuse of two
+ assert_iter_sum_pattern(
+ {
+ x * 7 + y: (22, 0, 1),
+ },
+ var_dom([(x, 3), (y, 8)]),
+ check_level="surjective",
+ )
+ assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")
+
+ # non-bijective fuse of three
+ assert_iter_sum_pattern(
+ {
+ x * 18 + y * 7 + z: (40, 0, 1),
+ },
+ var_dom([(x, 2), (y, 3), (z, 8)]),
+ check_level="surjective",
+ )
+ assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective")
+
+ # negative scale fusion is not allowed
+ assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
+ assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
+
+ # with predicate
+ assert_iter_sum_pattern(
+ {
+ a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
+ },
+ var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
+ predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
+ check_level="surjective",
+ )
+
+ # stride=1 kernel
+ assert_iter_sum_pattern(
+ {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective"
+ )
+
+ # do not allow both strided and overlapped
+ assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 36f365e732..12aa150f57 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -48,11 +48,11 @@ def test_cpu_c1d():
for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 1, 8):
for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
- l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
- co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 + i2_2 + i2_3)
+ n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0)
+ l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + i1_2 * 2 + i1_3)
+ co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
- rc = T.axis.reduce(64, i4_0 + i4_1)
+ rc = T.axis.reduce(64, i4_1 + i4_0)
T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
T.writes(conv1d_nlc_global[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -89,11 +89,11 @@ def test_cpu_c1d():
PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32")
for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
- l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
- co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
+ n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+ l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
+ co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
- rc = T.axis.reduce(64, i4_0 + i4_1)
+ rc = T.axis.reduce(64, i4_1 + i4_0)
T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
T.writes(conv1d_nlc_global[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -107,7 +107,7 @@ def test_cpu_c1d():
T.reads(conv1d_nlc_global[v0, v1, v2])
T.writes(conv1d_nlc[v0, v1, v2])
conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2]
-
+
@T.prim_func
def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
# function attr dict
@@ -119,11 +119,11 @@ def test_cpu_c1d():
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
- l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
- co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
+ n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+ l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
+ co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
- rc = T.axis.reduce(64, i4_0 + i4_1)
+ rc = T.axis.reduce(64, i4_1 + i4_0)
T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], weight[rl, rc, co])
T.writes(conv1d_nlc[n, l, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -201,11 +201,11 @@ def test_cpu_c2d():
for i3_1 in T.serial(8):
for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- h = T.axis.spatial(112, ((i1_0 + i1_1) * 2 + i1_2) * 8 + i1_3)
- w = T.axis.spatial(112, i2_0 * 28 + i2_1 + i2_2 + i2_3)
- co = T.axis.spatial(64, (i3_0 * 8 + i3_1 + i3_2) * 4 + i3_3)
- rh = T.axis.reduce(7, i4_0 + i4_1)
+ n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+ h = T.axis.spatial(112, i1_0 * 16 + i1_1 * 16 + i1_2 * 8 + i1_3)
+ w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1 + i2_2)
+ co = T.axis.spatial(64, i3_0 * 32 + i3_1 * 4 + i3_2 * 4 + i3_3)
+ rh = T.axis.reduce(7, i4_1 + i4_0)
rw = T.axis.reduce(7, i5_0 + i5_1)
rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -243,11 +243,11 @@ def test_cpu_c2d():
for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 4, 2):
for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1_1 + i0_0)
- h = T.axis.spatial(112, ((i1_0 + i1_1_1) * 2 + i1_2) * 8 + i1_3)
- w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 + i2_2 + i2_3)
- co = T.axis.spatial(64, (i3_0 * 8 + i3_1_1 + i3_2) * 4 + i3_3)
- rh = T.axis.reduce(7, i4_0 + i4_1)
+ n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2)
+ h = T.axis.spatial(112, i1_0 * 16 + i1_1_1 * 16 + i1_2 * 8 + i1_3)
+ w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1_1 + i2_2)
+ co = T.axis.spatial(64, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 4 + i3_3)
+ rh = T.axis.reduce(7, i4_1 + i4_0)
rw = T.axis.reduce(7, i5_0 + i5_1)
rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -287,11 +287,11 @@ def test_cpu_c2d():
PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(4, 2, 1, 1, 28, 8, 7, 7, 1, 1, 2, 1, 1, 1, 1, 3, 1, 8, 1, 4):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- h = T.axis.spatial(112, ((i1_0 + i1_1) * 2 + i1_2) * 8 + i1_3)
- w = T.axis.spatial(112, i2_0 * 28 + i2_1 + i2_2 + i2_3)
- co = T.axis.spatial(64, (i3_0 * 8 + i3_1 + i3_2) * 4 + i3_3)
- rh = T.axis.reduce(7, i4_0 + i4_1)
+ n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+ h = T.axis.spatial(112, i1_0 * 16 + i1_1 * 16 + i1_2 * 8 + i1_3)
+ w = T.axis.spatial(112, i2_3 + i2_0 * 28 + i2_1 + i2_2)
+ co = T.axis.spatial(64, i3_0 * 32 + i3_1 * 4 + i3_2 * 4 + i3_3)
+ rh = T.axis.reduce(7, i4_1 + i4_0)
rw = T.axis.reduce(7, i5_0 + i5_1)
rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -378,15 +378,15 @@ def test_cpu_c3d():
for i0_1, i1_1, i2_1, i3_1, i4_1 in T.grid(1, 4, 4, 14, 1):
for i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
with T.block("conv3d_ndhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
- h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
- w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
- co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+ n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+ d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2)
+ h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3)
+ w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3)
+ co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2)
rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
- rh = T.axis.reduce(7, i6_0 + i6_1)
+ rh = T.axis.reduce(7, i6_1 + i6_0)
rw = T.axis.reduce(7, i7_0 + i7_1)
- rc = T.axis.reduce(3, i8_0 + i8_1)
+ rc = T.axis.reduce(3, i8_1 + i8_0)
T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
T.writes(conv3d_ndhwc_global[n, d, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -428,15 +428,15 @@ def test_cpu_c3d():
PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32")
for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
with T.block("conv3d_ndhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
- h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
- w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
- co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+ n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+ d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2)
+ h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3)
+ w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3)
+ co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2)
rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
- rh = T.axis.reduce(7, i6_0 + i6_1)
+ rh = T.axis.reduce(7, i6_1 + i6_0)
rw = T.axis.reduce(7, i7_0 + i7_1)
- rc = T.axis.reduce(3, i8_0 + i8_1)
+ rc = T.axis.reduce(3, i8_1 + i8_0)
T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
T.writes(conv3d_ndhwc_global[n, d, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -476,15 +476,15 @@ def test_cpu_c3d():
PadInput[i0, i1, i2, i3, i4] = T.if_then_else(3 <= i1 and i1 < 19 and 3 <= i2 and i2 < 227 and 3 <= i3 and i3 < 227, inputs[i0, i1 - 3, i2 - 3, i3 - 3, i4], T.float32(0), dtype="float32")
for i4_1, i5_0, i6_0, i7_0, i8_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3 in T.grid(1, 1, 7, 7, 3, 1, 1, 1, 1, 32, 7, 1, 1, 1, 1, 1, 7, 8, 1):
with T.block("conv3d_ndhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- d = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
- h = T.axis.spatial(112, (i2_0 * 4 + i2_1 + i2_2) * 7 + i2_3)
- w = T.axis.spatial(112, (i3_0 * 14 + i3_1 + i3_2) * 8 + i3_3)
- co = T.axis.spatial(64, (i4_0 + i4_1) * 32 + i4_2 + i4_3)
+ n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+ d = T.axis.spatial(8, i1_3 + i1_0 * 4 + i1_1 + i1_2)
+ h = T.axis.spatial(112, i2_0 * 28 + i2_1 * 7 + i2_2 * 7 + i2_3)
+ w = T.axis.spatial(112, i3_0 * 112 + i3_1 * 8 + i3_2 * 8 + i3_3)
+ co = T.axis.spatial(64, i4_3 + i4_0 * 32 + i4_1 * 32 + i4_2)
rd = T.axis.reduce(7, i5_0 * 7 + i5_1)
- rh = T.axis.reduce(7, i6_0 + i6_1)
+ rh = T.axis.reduce(7, i6_1 + i6_0)
rw = T.axis.reduce(7, i7_0 + i7_1)
- rc = T.axis.reduce(3, i8_0 + i8_1)
+ rc = T.axis.reduce(3, i8_1 + i8_0)
T.reads(PadInput[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rd, rh, rw, rc, co])
T.writes(conv3d_ndhwc[n, d, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -574,14 +574,14 @@ def test_cpu_cap():
for i2_1, i3_1, i4_1, i5_1 in T.grid(4, 1, 4, 2):
for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
with T.block("conv2d_capsule_nhwijc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+ n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
- w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
- cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
+ w = T.axis.spatial(8, i2_0 * 8 + i2_1 * 2 + i2_2 + i2_3)
+ cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1 * 4 + i3_2 * 4 + i3_3)
cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
- co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
+ co = T.axis.spatial(32, i5_0 * 32 + i5_1 * 16 + i5_2 * 16 + i5_3)
rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
- rw = T.axis.reduce(3, i7_0 + i7_1)
+ rw = T.axis.reduce(3, i7_1 + i7_0)
cap_k = T.axis.reduce(4, i8_0 + i8_1)
rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
@@ -625,14 +625,14 @@ def test_cpu_cap():
PadInput[i0, i1, i2, i3, i4, i5] = T.if_then_else(1 <= i1 and i1 < 17 and 1 <= i2 and i2 < 17, inputs[i0, i1 - 1, i2 - 1, i3, i4, i5], T.float32(0), dtype="float32")
for i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
with T.block("conv2d_capsule_nhwijc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
+ n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
h = T.axis.spatial(8, i1_0 * 4 + i1_1 + i1_2 + i1_3)
- w = T.axis.spatial(8, (i2_0 * 4 + i2_1) * 2 + i2_2 + i2_3)
- cap_i = T.axis.spatial(4, (i3_0 + i3_1 + i3_2) * 4 + i3_3)
+ w = T.axis.spatial(8, i2_0 * 8 + i2_1 * 2 + i2_2 + i2_3)
+ cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1 * 4 + i3_2 * 4 + i3_3)
cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1 + i4_2 + i4_3)
- co = T.axis.spatial(32, (i5_0 * 2 + i5_1 + i5_2) * 16 + i5_3)
+ co = T.axis.spatial(32, i5_0 * 32 + i5_1 * 16 + i5_2 * 16 + i5_3)
rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
- rw = T.axis.reduce(3, i7_0 + i7_1)
+ rw = T.axis.reduce(3, i7_1 + i7_0)
cap_k = T.axis.reduce(4, i8_0 + i8_1)
rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
@@ -667,14 +667,14 @@ def test_cpu_cap():
PadInput[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1, i4_1, i5_1], T.float32(0), dtype="float32")
for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_1_1, i5_1_1, i6_0, i7_0, i8_0, i9_0, i0_2, i1_2, i2_2, i3_2, i4_2, i5_2, i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 4, 4, 1, 4, 2, 1, 3, 4, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 32, 1, 1, 1, 4, 1, 16):
with T.block("conv2d_capsule_nhwijc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1_1 + i0_0)
+ n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1)
h = T.axis.spatial(8, i1_0 * 4 + i1_1_1 + i1_2 + i1_3)
- w = T.axis.spatial(8, (i2_0 * 4 + i2_1_1) * 2 + i2_2 + i2_3)
- cap_i = T.axis.spatial(4, (i3_0 + i3_1_1 + i3_2) * 4 + i3_3)
+ w = T.axis.spatial(8, i2_0 * 8 + i2_1_1 * 2 + i2_2 + i2_3)
+ cap_i = T.axis.spatial(4, i3_0 * 4 + i3_1_1 * 4 + i3_2 * 4 + i3_3)
cap_j = T.axis.spatial(4, i4_0 * 4 + i4_1_1 + i4_2 + i4_3)
- co = T.axis.spatial(32, (i5_0 * 2 + i5_1_1 + i5_2) * 16 + i5_3)
+ co = T.axis.spatial(32, i5_0 * 32 + i5_1_1 * 16 + i5_2 * 16 + i5_3)
rh = T.axis.reduce(3, i6_0 * 3 + i6_1)
- rw = T.axis.reduce(3, i7_0 + i7_1)
+ rw = T.axis.reduce(3, i7_1 + i7_0)
cap_k = T.axis.reduce(4, i8_0 + i8_1)
rc = T.axis.reduce(32, i9_0 * 32 + i9_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight[rh, rw, cap_k, cap_j, rc, co])
@@ -763,7 +763,7 @@ def test_cpu_dep():
for i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1 in T.grid(1, 1, 1, 1, 1, 4, 4, 8):
for i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2):
with T.block("depth_conv2d_nhwc"):
- n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
+ n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1)
h = T.axis.spatial(112, i1_0 * 112 + i1_1_1 * 28 + i1_2 * 14 + i1_3)
w = T.axis.spatial(112, i2_0 * 112 + i2_1_1 * 28 + i2_2 * 4 + i2_3)
c = T.axis.spatial(32, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 2 + i3_3)
@@ -804,7 +804,7 @@ def test_cpu_dep():
for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 1, 1):
for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2):
with T.block("depth_conv2d_nhwc"):
- n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
+ n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1_1)
h = T.axis.spatial(112, i1_0 * 112 + i1_1_1 * 28 + i1_2 * 14 + i1_3)
w = T.axis.spatial(112, i2_0 * 112 + i2_1_1 * 28 + i2_2 * 4 + i2_3)
c = T.axis.spatial(32, i3_0 * 32 + i3_1_1 * 4 + i3_2 * 2 + i3_3)
@@ -843,7 +843,7 @@ def test_cpu_dep():
PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 113 and 1 <= i2 and i2 < 113, placeholder[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
for i2_1, i3_1, i4_0, i5_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i0_3, i1_3, i2_3, i3_3 in T.grid(4, 8, 1, 1, 1, 2, 7, 2, 3, 3, 1, 14, 4, 2):
with T.block("depth_conv2d_nhwc"):
- n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+ n = T.axis.spatial(1, i0_2 + i0_3 + i0_0 + i0_1)
h = T.axis.spatial(112, i1_0 * 112 + i1_1 * 28 + i1_2 * 14 + i1_3)
w = T.axis.spatial(112, i2_0 * 112 + i2_1 * 28 + i2_2 * 4 + i2_3)
c = T.axis.spatial(32, i3_0 * 32 + i3_1 * 4 + i3_2 * 2 + i3_3)
@@ -926,11 +926,11 @@ def test_cpu_dil():
PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
- w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
- co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
- rh = T.axis.reduce(7, i4_0 + i4_1)
+ n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+ h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1)
+ w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2)
+ co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3)
+ rh = T.axis.reduce(7, i4_1 + i4_0)
rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -972,11 +972,11 @@ def test_cpu_dil():
PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
for i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
- w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
- co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
- rh = T.axis.reduce(7, i4_0 + i4_1)
+ n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+ h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1)
+ w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2)
+ co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3)
+ rh = T.axis.reduce(7, i4_1 + i4_0)
rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
@@ -1016,11 +1016,11 @@ def test_cpu_dil():
PadInput[i0, i1, i2, i3] = T.if_then_else(3 <= i1 and i1 < 227 and 3 <= i2 and i2 < 227, inputs[i0, i1 - 3, i2 - 3, i3], T.float32(0), dtype="float32")
for i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 1, 1, 2, 7, 1, 1, 1, 1, 109, 8, 1, 7, 3, 1, 1, 1, 1):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_3 + i0_2 + i0_1 + i0_0)
- h = T.axis.spatial(109, i1_0 + i1_1 + i1_2 + i1_3)
- w = T.axis.spatial(109, (i2_0 + i2_1) * 109 + i2_2 + i2_3)
- co = T.axis.spatial(64, (i3_0 * 2 + i3_1) * 8 + i3_2 + i3_3)
- rh = T.axis.reduce(7, i4_0 + i4_1)
+ n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+ h = T.axis.spatial(109, i1_2 + i1_3 + i1_0 + i1_1)
+ w = T.axis.spatial(109, i2_3 + i2_0 * 109 + i2_1 * 109 + i2_2)
+ co = T.axis.spatial(64, i3_0 * 16 + i3_1 * 8 + i3_2 + i3_3)
+ rh = T.axis.reduce(7, i4_1 + i4_0)
rw = T.axis.reduce(7, i5_0 * 7 + i5_1)
rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
T.reads(PadInput[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight[rh, rw, rc, co])
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index b8723e286a..7323bc441f 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -47,7 +47,7 @@ def test_cuda_c1d():
for ax0_ax1_ax2_fused in T.serial(260):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
+ v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4)
v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4)
T.reads(inputs[v0, v1 - 1, v2])
T.writes(PadInput_shared[v0, v1, v2])
@@ -64,11 +64,11 @@ def test_cuda_c1d():
weight_shared[v0, v1, v2] = weight[v0, v1, v2]
for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
with T.block("conv1d_nlc"):
- n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
- l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4)
- co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4)
- rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2)
- rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 2 + i4_2)
+ n = T.axis.spatial(1, i0_4 + i0_3)
+ l = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4)
+ co = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4)
+ rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 + i3_2)
+ rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 + i4_2)
T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co])
T.writes(conv1d_nlc_local[n, l, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
@@ -130,7 +130,7 @@ def test_cuda_c2d():
for ax0_ax1_ax2_ax3_fused in T.serial(80379):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused % 80379 // 351)
+ v1 = T.axis.spatial(230, ax0_ax1_ax2_ax3_fused // 351)
v2 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 8 * 112 + ax0_ax1_ax2_ax3_fused % 351 // 3)
v3 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 3)
T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
@@ -149,13 +149,13 @@ def test_cuda_c2d():
weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 7, 1, 1, 8, 4, 1, 7, 1, 3, 1, 1, 1, 2):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
- h = T.axis.spatial(112, ((0 + 0) * 14 + i0_2_i1_2_i2_2_i3_2_fused % 14) * 8 + i1_3 + i1_4)
- w = T.axis.spatial(112, (i0_0_i1_0_i2_0_i3_0_fused % 16 // 8 * 14 + i0_1_i1_1_i2_1_i3_1_fused % 56 // 4 + 0) * 4 + i2_3 + i2_4)
- co = T.axis.spatial(64, (i0_0_i1_0_i2_0_i3_0_fused % 8 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 4 + 0 + i3_3) * 2 + i3_4)
- rh = T.axis.reduce(7, (i4_0 + i4_1) * 7 + i4_2)
- rw = T.axis.reduce(7, i5_0 * 7 + i5_1 + i5_2)
- rc = T.axis.reduce(3, (i6_0 + i6_1) * 3 + i6_2)
+ n = T.axis.spatial(1, i0_3 + i0_4)
+ h = T.axis.spatial(112, i1_4 + i0_2_i1_2_i2_2_i3_2_fused * 8 + i1_3)
+ w = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_fused // 8 * 56 + i0_1_i1_1_i2_1_i3_1_fused // 4 * 4 + i2_3 + i2_4)
+ co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 8 * 8 + i0_1_i1_1_i2_1_i3_1_fused % 4 * 2 + i3_3 * 2 + i3_4)
+ rh = T.axis.reduce(7, i4_0 * 7 + i4_1 * 7 + i4_2)
+ rw = T.axis.reduce(7, i5_2 + i5_0 * 7 + i5_1)
+ rc = T.axis.reduce(3, i6_0 * 3 + i6_1 * 3 + i6_2)
T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight_shared[rh, rw, rc, co])
T.writes(conv2d_nhwc_local[n, h, w, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
@@ -219,7 +219,7 @@ def test_cuda_c3d():
for ax0_ax1_ax2_ax3_ax4_fused in T.serial(1687959):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(22, ax0_ax1_ax2_ax3_ax4_fused % 1687959 // 80379)
+ v1 = T.axis.spatial(22, ax0_ax1_ax2_ax3_ax4_fused // 80379)
v2 = T.axis.spatial(230, ax0_ax1_ax2_ax3_ax4_fused % 80379 // 351)
v3 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 112 + ax0_ax1_ax2_ax3_ax4_fused % 351 // 3)
v4 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_fused % 3)
@@ -240,14 +240,14 @@ def test_cuda_c3d():
weight_shared[v0, v1, v2, v3, v4] = weight[v0, v1, v2, v3, v4]
for i5_1, i6_1, i7_1, i8_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_2, i6_2, i7_2, i8_2, i0_4, i1_4, i2_4, i3_4, i4_4 in T.grid(7, 7, 1, 3, 1, 2, 2, 1, 32, 1, 1, 7, 1, 1, 1, 2, 4, 1):
with T.block("conv3d_ndhwc"):
- n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
- d = T.axis.spatial(8, ((0 + 0) * 4 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 392 // 98) * 2 + i1_3 + i1_4)
- h = T.axis.spatial(112, (((0 * 4 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 8 // 2) * 7 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14) * 2 + i2_3) * 2 + i2_4)
- w = T.axis.spatial(112, ((i0_0_i1_0_i2_0_i3_0_i4_0_fused % 2 * 2 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2) * 7 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 + i3_3) * 4 + i3_4)
- co = T.axis.spatial(64, ((0 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2) * 32 + i4_3 + i4_4)
- rd = T.axis.reduce(7, i5_0 * 7 + i5_1 + i5_2)
+ n = T.axis.spatial(1, i0_4 + i0_3)
+ d = T.axis.spatial(8, i1_4 + i0_2_i1_2_i2_2_i3_2_i4_2_fused // 98 * 2 + i1_3)
+ h = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_i4_1_fused // 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 98 // 14 * 4 + i2_3 * 2 + i2_4)
+ w = T.axis.spatial(112, i0_0_i1_0_i2_0_i3_0_i4_0_fused * 56 + i0_1_i1_1_i2_1_i3_1_i4_1_fused % 2 * 28 + i0_2_i1_2_i2_2_i3_2_i4_2_fused % 14 // 2 * 4 + i3_3 * 4 + i3_4)
+ co = T.axis.spatial(64, i0_2_i1_2_i2_2_i3_2_i4_2_fused % 2 * 32 + i4_3 + i4_4)
+ rd = T.axis.reduce(7, i5_2 + i5_0 * 7 + i5_1)
rh = T.axis.reduce(7, i6_0 * 7 + i6_1 + i6_2)
- rw = T.axis.reduce(7, (i7_0 + i7_1) * 7 + i7_2)
+ rw = T.axis.reduce(7, i7_0 * 7 + i7_1 * 7 + i7_2)
rc = T.axis.reduce(3, i8_0 * 3 + i8_1 + i8_2)
T.reads(PadInput_shared[n, d * 2 + rd, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight_shared[rd, rh, rw, rc, co])
T.writes(conv3d_ndhwc_local[n, d, h, w, co])
@@ -338,15 +338,15 @@ def test_cuda_cap():
weight_shared[v0, v1, v2, v3, v4, v5] = weight[v0, v1, v2, v3, v4, v5]
for i6_1, i7_1, i8_1, i9_1, i0_3, i1_3, i2_3, i3_3, i4_3, i5_3, i6_2, i7_2, i8_2, i9_2, i0_4, i1_4, i2_4, i3_4, i4_4, i5_4 in T.grid(1, 1, 1, 4, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 8):
with T.block("conv2d_capsule_nhwijc"):
- n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
- h = T.axis.spatial(8, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 256 // 64 + 0 + 0) * 2 + i1_3 + i1_4)
- w = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 + 0 + 0 + i2_3 + i2_4)
- cap_i = T.axis.spatial(4, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 4 // 2 + i3_3 + i3_4)
- cap_j = T.axis.spatial(4, ((0 + 0) * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 + i4_3) * 2 + i4_4)
- co = T.axis.spatial(32, (i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 + 0 + 0 + i5_3) * 8 + i5_4)
- rh = T.axis.reduce(3, i6_0 + i6_1 + i6_2)
+ n = T.axis.spatial(1, i0_4 + i0_3)
+ h = T.axis.spatial(8, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 2 + i1_3 + i1_4)
+ w = T.axis.spatial(8, i2_3 + i2_4 + i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8)
+ cap_i = T.axis.spatial(4, i3_3 + i3_4 + i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused // 2)
+ cap_j = T.axis.spatial(4, i0_2_i1_2_i2_2_i3_2_i4_2_i5_2_fused % 2 * 2 + i4_3 * 2 + i4_4)
+ co = T.axis.spatial(32, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 4 * 8 + i5_3 * 8 + i5_4)
+ rh = T.axis.reduce(3, i6_1 + i6_2 + i6_0)
rw = T.axis.reduce(3, i7_0 + i7_1 + i7_2)
- cap_k = T.axis.reduce(4, (i8_0 + i8_1) * 2 + i8_2)
+ cap_k = T.axis.reduce(4, i8_0 * 2 + i8_1 * 2 + i8_2)
rc = T.axis.reduce(32, i9_0 * 4 + i9_1 + i9_2)
T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, cap_i, cap_k, rc], weight_shared[rh, rw, cap_k, cap_j, rc, co])
T.writes(conv2d_capsule_nhwijc_local[n, h, w, cap_i, cap_j, co])
@@ -436,12 +436,12 @@ def test_cuda_dep():
placeholder_shared[v0, v1, v2, v3] = placeholder_1[v0, v1, v2, v3]
for i4_1, i5_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i0_4, i1_4, i2_4, i3_4 in T.grid(3, 1, 1, 4, 16, 8, 1, 3, 1, 7, 1, 1):
with T.block("depth_conv2d_nhwc"):
- n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
- h = T.axis.spatial(112, ((0 * 4 + i0_1_i1_1_i2_1_i3_1_fused % 8 // 2 + 0) * 4 + i1_3) * 7 + i1_4)
- w = T.axis.spatial(112, ((0 + 0) * 7 + i0_2_i1_2_i2_2_i3_2_fused % 14 // 2) * 16 + i2_3 + i2_4)
- c = T.axis.spatial(32, ((0 * 2 + i0_1_i1_1_i2_1_i3_1_fused % 2) * 2 + i0_2_i1_2_i2_2_i3_2_fused % 2) * 8 + i3_3 + i3_4)
- rh = T.axis.reduce(3, i4_0 * 3 + i4_1 + i4_2)
- rw = T.axis.reduce(3, (i5_0 + i5_1) * 3 + i5_2)
+ n = T.axis.spatial(1, i0_4 + i0_3)
+ h = T.axis.spatial(112, i0_1_i1_1_i2_1_i3_1_fused // 2 * 28 + i1_3 * 7 + i1_4)
+ w = T.axis.spatial(112, i2_4 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 16 + i2_3)
+ c = T.axis.spatial(32, i0_1_i1_1_i2_1_i3_1_fused % 2 * 16 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 8 + i3_3 + i3_4)
+ rh = T.axis.reduce(3, i4_2 + i4_0 * 3 + i4_1)
+ rw = T.axis.reduce(3, i5_0 * 3 + i5_1 * 3 + i5_2)
T.reads(PadInput_shared[n, h + rh, w + rw, c], placeholder_shared[0, rh, rw, c])
T.writes(depth_conv2d_nhwc_local[n, h, w, c])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
@@ -522,13 +522,13 @@ def test_cuda_dil():
weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 4):
with T.block("conv2d_nhwc"):
- n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
- h = T.axis.spatial(109, i0_0_i1_0_i2_0_i3_0_fused % 218 // 2 + 0 + 0 + i1_3 + i1_4)
- w = T.axis.spatial(109, 0 * 109 + i0_1_i1_1_i2_1_i3_1_fused % 109 + 0 + i2_3 + i2_4)
- co = T.axis.spatial(64, ((i0_0_i1_0_i2_0_i3_0_fused % 2 + 0 + 0) * 8 + i3_3) * 4 + i3_4)
+ n = T.axis.spatial(1, i0_3 + i0_4)
+ h = T.axis.spatial(109, i1_4 + i0_0_i1_0_i2_0_i3_0_fused // 2 + i1_3)
+ w = T.axis.spatial(109, i0_1_i1_1_i2_1_i3_1_fused + i2_3 + i2_4)
+ co = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_3 * 4 + i3_4)
rh = T.axis.reduce(7, i4_0 + i4_1 + i4_2)
- rw = T.axis.reduce(7, i5_0 + i5_1 + i5_2)
- rc = T.axis.reduce(3, i6_0 + i6_1 + i6_2)
+ rw = T.axis.reduce(7, i5_2 + i5_0 + i5_1)
+ rc = T.axis.reduce(3, i6_1 + i6_2 + i6_0)
T.reads(PadInput_shared[n, h * 2 + rh * 2, w * 2 + rw * 2, co // 64 * 3 + rc], weight_shared[rh, rw, rc, co])
T.writes(conv2d_nhwc_local[n, h, w, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py
index 4351fe5b63..b859b655ef 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -214,9 +214,9 @@ def test_reorder_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)
-def test_reorder_with_partial_affineness():
+def test_reorder_overlapped_access():
@T.prim_func
- def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+ def overlapped_access(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
# example to write first axis multiple times
for v0, v1, v2 in T.grid(6, 4, 4):
with T.block("block"):
@@ -225,7 +225,7 @@ def test_reorder_with_partial_affineness():
B[i, j] = A[i, j] + 1.0
@T.prim_func
- def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+ def overlapped_access_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
# example to write first axis multiple times
for v0, v2, v1 in T.grid(6, 4, 4):
with T.block("block"):
@@ -233,6 +233,30 @@ def test_reorder_with_partial_affineness():
j = T.axis.spatial(4, v2)
B[i, j] = A[i, j] + 1.0
+ sch = tir.Schedule(overlapped_access, debug_mask="all")
+ v0, v1, v2 = sch.get_loops(sch.get_block("block"))
+ sch.reorder(v0, v2, v1)
+ tvm.ir.assert_structural_equal(overlapped_access_reorder, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=overlapped_access)
+
+
+def test_reorder_with_partial_affineness():
+ @T.prim_func
+ def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+ for v0, v1, v2 in T.grid(6, 4, 4):
+ with T.block("block"):
+ i = T.axis.spatial(14, v0 * v0 + v1)
+ j = T.axis.spatial(4, v2)
+ B[i, j] = A[i, j] + 1.0
+
+ @T.prim_func
+ def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
+ for v0, v2, v1 in T.grid(6, 4, 4):
+ with T.block("block"):
+ i = T.axis.spatial(14, v0 * v0 + v1)
+ j = T.axis.spatial(4, v2)
+ B[i, j] = A[i, j] + 1.0
+
sch = tir.Schedule(non_affine_func, debug_mask="all")
v0, v1, v2 = sch.get_loops(sch.get_block("block"))
with pytest.raises(tvm.tir.ScheduleError):
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 9fd678174d..3ae88e0abb 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -177,7 +177,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [128, 128, 128])
for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8):
with T.block("B"):
- vi = T.axis.S(128, (i1 + i2) * 64 + i3)
+ vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
vj = T.axis.S(128, j1 * 32 + j2)
vk = T.axis.S(128, k1 * 8 + k2)
T.reads([A[vi, vj, vk]])
@@ -191,9 +191,9 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [128, 128, 128])
for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64):
with T.block("B"):
- vi = T.axis.S(128, (i1 + i2) * 64 + i3)
- vj = T.axis.S(128, (j1 + j2) * 64 + j3)
- vk = T.axis.S(128, (k1 + k2) * 64 + k3)
+ vi = T.axis.S(128, i1 * 64 + i2 * 64 + i3)
+ vj = T.axis.S(128, j1 * 64 + j2 * 64 + j3)
+ vk = T.axis.S(128, k1 * 64 + k2 * 64 + k3)
T.reads([A[vi, vj, vk]])
T.writes([B[vi, vj, vk]])
B[vi, vj, vk] = A[vi, vj, vk] * 2.0
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index 1b4c34973f..bbeb8d8760 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -758,7 +758,7 @@ def test_non_perfect_tiling_cache():
s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
# pylint: disable=protected-access
assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags(
- affine_binding=False,
+ affine_binding=True,
region_cover=True,
stage_pipeline=True,
)