You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/11/14 22:51:45 UTC
[tvm] branch main updated: [TIR][Bugfix] Fix AXIS_SEPARATORS in tir.Schedule.transform_layout (#13326)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b6fae9b35e [TIR][Bugfix] Fix AXIS_SEPARATORS in tir.Schedule.transform_layout (#13326)
b6fae9b35e is described below
commit b6fae9b35eff4ad1f7cc2e83d8d7da5d701d8e44
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Nov 14 16:51:38 2022 -0600
[TIR][Bugfix] Fix AXIS_SEPARATORS in tir.Schedule.transform_layout (#13326)
Previously, the block SREF reuse only included a single step of
changes, and would have an incorrect mapping if multiple sequential
changes to the TIR block occurred. This could happen if a
`BufferStore` was updated, followed by replacement of `Block` iter
vars/values. This commit tracks the Block replacements across each
usage, to ensure the SREF instances remain valid.
---
.../schedule/primitive/layout_transformation.cc | 462 +++++++++++++--------
src/tir/schedule/state.cc | 7 +-
.../unittest/test_tir_schedule_transform_layout.py | 48 ++-
3 files changed, 326 insertions(+), 191 deletions(-)
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index e4c91dac58..c0b4ddfb4a 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -73,7 +73,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
// Loops within the analyzed block that should be replaced
struct ReplacementPlan {
Map<For, Stmt> replacements;
- Map<Block, Block> block_sref_reuse;
+ Map<Block, Block> new_block_to_old;
};
// The block to be inserted, along with the location at which it
@@ -100,6 +100,25 @@ class TransformLayoutPlanner : private StmtExprVisitor {
}
private:
+ struct WriteInfo {
+ // The BufferStore object
+ BufferStore store;
+
+ // The block realize that contains the store, if any.
+ Optional<BlockRealize> innermost_block_realize;
+
+ // The nested loops whose values contribute to the indices used in
+ // the store. Not all loop variables in the loopnest need to
+ // contribute, but the first and last must.
+ std::vector<For> dependent_loopnest;
+
+ // Whether the padding could be represented as a tir::if_then_else
+ // node. This requires that the surrounding loop iterators
+ // iterate over all pre-transformation buffer axes, that there are
+ // no data dependencies between loop iterations, and that
+ bool contains_row_major_traversal{false};
+ };
+
explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}
void VisitStmt_(const ForNode* op) override {
@@ -197,33 +216,217 @@ class TransformLayoutPlanner : private StmtExprVisitor {
class BufferStoreReplacer : public StmtExprMutator {
public:
- BufferStoreReplacer(std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store,
- std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
- replace_block_realize)
- : replace_store_(replace_store), replace_block_realize_(replace_block_realize) {}
+ BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate,
+ const IndexMap& inverse, const Optional<IndexMap>& pad_value,
+ Map<Block, Block>* new_block_to_old)
+ : info(info),
+ new_buffer(new_buffer),
+ new_indices(inverse->initial_indices.Map([](const Var& var) -> PrimExpr { return var; })),
+ padding_predicate(padding_predicate),
+ inverse(inverse),
+ pad_value(pad_value),
+ new_block_to_old(*new_block_to_old) {
+ ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
+ for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
+ Var var = info.dependent_loopnest[i]->loop_var;
+ PrimExpr expr = inverse->final_indices[i];
+ var_remap.Set(var, expr);
+ }
+
+ DefineBlockUpdates();
+ }
+
+ bool is_all_stores_replaced() const { return all_stores_replaced; }
+
+ private:
+ void DefineBlockUpdates() {
+ if (!info.innermost_block_realize) {
+ return;
+ }
+
+ BlockRealize block_realize = info.innermost_block_realize.value();
+ const auto& block = block_realize->block;
+ const Array<PrimExpr>& old_indices = info.store->indices;
+ const auto& old_iter_vars = block->iter_vars;
+
+ this->new_iter_vars = old_iter_vars;
+ this->new_iter_values = block_realize->iter_values;
+
+ if (old_indices.empty()) {
+ return;
+ }
+
+ // Find the block iterators that are used to access the buffer. Must be in the same
+ // order as they appear in the indices.
+ if (block->iter_vars.size() < old_indices.size()) {
+ return;
+ }
+
+ size_t block_index_start = 0;
+ for (; block_index_start < old_iter_vars.size() - old_indices.size(); block_index_start++) {
+ if (old_indices[0].same_as(old_iter_vars[block_index_start]->var)) {
+ break;
+ }
+ }
+ if (block_index_start > old_iter_vars.size() - old_indices.size()) {
+ return;
+ }
+
+ for (size_t i = 0; i < old_indices.size(); i++) {
+ if (!old_indices[i].same_as(old_iter_vars[block_index_start + i]->var) ||
+ old_iter_vars[block_index_start + i]->iter_type != kDataPar) {
+ return;
+ }
+ }
+
+ // If we got to this point, all indices used to access the
+ // buffer are virtual indices defined in the innermost block.
+ // Therefore, generate new virtual indices for iterating over
+ // the post-transform buffer.
+
+ new_indices = inverse->initial_indices.Map([](Var var) -> PrimExpr {
+ std::stringstream ss;
+ ss << "v_" << var->name_hint;
+ return Var(ss.str(), var.dtype());
+ });
+
+ Map<Var, PrimExpr>
+ loop_var_to_virtual_var; // For updating padding_predicate in terms of the new indices
+ Array<PrimExpr> new_iter_values; // For BlockRealize
+ Array<IterVar> new_iter_vars; // For Block
+
+ for (size_t i = 0; i < block_index_start; i++) {
+ new_iter_vars.push_back(old_iter_vars[i]);
+ new_iter_values.push_back(block_realize->iter_values[i]);
+ }
+
+ ICHECK_EQ(new_indices.size(), new_buffer->shape.size());
+ for (size_t i = 0; i < new_indices.size(); i++) {
+ Var var = inverse->initial_indices[i];
+ Var virtual_var = Downcast<Var>(new_indices[i]);
+ PrimExpr dim = new_buffer->shape[i];
+ new_iter_values.push_back(var);
+ new_iter_vars.push_back(
+ IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar));
+ loop_var_to_virtual_var.Set(var, virtual_var);
+ }
+
+ for (size_t i = block_index_start + old_indices.size(); i < old_iter_vars.size(); i++) {
+ new_iter_vars.push_back(old_iter_vars[i]);
+ new_iter_values.push_back(block_realize->iter_values[i]);
+ }
+
+ ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
+ for (size_t i = 0; i < old_indices.size(); i++) {
+ Var var = Downcast<Var>(old_indices[i]);
+ PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
+ var_remap.Set(var, expr);
+ }
+
+ padding_predicate = Substitute(padding_predicate, loop_var_to_virtual_var);
+
+ this->new_iter_vars = new_iter_vars;
+ this->new_iter_values = new_iter_values;
+ }
Stmt VisitStmt_(const BufferStoreNode* op) final {
- if (auto replacement = replace_store_(op)) {
- auto store = Downcast<BufferStore>(replacement.value());
- return StmtExprMutator::VisitStmt_(store.get());
+ bool can_replace = [&]() -> bool {
+ if (!op->buffer.same_as(info.store->buffer)) {
+ return false;
+ }
+
+ const Array<PrimExpr>& old_indices = info.store->indices;
+
+ ICHECK_EQ(old_indices.size(), op->indices.size());
+ ExprDeepEqual expr_equal;
+ for (size_t i = 0; i < old_indices.size(); i++) {
+ if (!expr_equal(old_indices[i], op->indices[i])) {
+ return false;
+ }
+ }
+ return true;
+ }();
+
+ BufferStore store = GetRef<BufferStore>(op);
+ if (can_replace) {
+ PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0];
+ store =
+ BufferStore(new_buffer, if_then_else(padding_predicate, pad_value_at_index, op->value),
+ new_indices);
} else {
- return StmtExprMutator::VisitStmt_(op);
+ all_stores_replaced = false;
}
+ return StmtExprMutator::VisitStmt_(store.get());
}
Stmt VisitStmt_(const BlockRealizeNode* op) final {
- auto realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
- if (auto replacement = replace_block_realize_(op, realize)) {
- return replacement.value();
+ BlockRealize realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
+
+ if (op == info.innermost_block_realize.get()) {
+ Block block = realize->block;
+ if (!block->iter_vars.same_as(this->new_iter_vars)) {
+ block.CopyOnWrite()->iter_vars = this->new_iter_vars;
+ RecordReplacement(op->block, block);
+ }
+
+ if (!block.same_as(realize->block) ||
+ !realize->iter_values.same_as(this->new_iter_values)) {
+ auto write_ptr = realize.CopyOnWrite();
+ write_ptr->block = block;
+ write_ptr->iter_values = this->new_iter_values;
+ }
+ }
+
+ return std::move(realize);
+ }
+
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Block orig = GetRef<Block>(op);
+ Block mutated = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+
+ RecordReplacement(orig, mutated);
+ return std::move(mutated);
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
+ if (auto opt = var_remap.Get(var)) {
+ return opt.value();
} else {
- return std::move(realize);
+ return std::move(var);
}
}
- private:
- std::function<Optional<Stmt>(const BufferStoreNode*)> replace_store_;
- std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)>
- replace_block_realize_;
+ void RecordReplacement(Block before, Block after) {
+ if (before.same_as(after)) {
+ return;
+ }
+
+ ICHECK(!new_block_to_old.count(after));
+
+ while (true) {
+ if (auto opt = new_block_to_old.Get(before)) {
+ before = opt.value();
+ } else {
+ break;
+ }
+ }
+
+ new_block_to_old.Set(after, before);
+ }
+
+ const WriteInfo& info;
+ const Buffer& new_buffer;
+ Array<PrimExpr> new_indices;
+ Array<IterVar> new_iter_vars;
+ Array<PrimExpr> new_iter_values;
+ PrimExpr padding_predicate;
+ const IndexMap& inverse;
+ const Optional<IndexMap>& pad_value;
+ Map<Block, Block>& new_block_to_old;
+ bool all_stores_replaced{true};
+
+ Map<Var, PrimExpr> var_remap;
};
TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse,
@@ -296,159 +499,20 @@ class TransformLayoutPlanner : private StmtExprVisitor {
return std::nullopt;
}
+ Map<Block, Block> new_block_to_old;
auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional<Stmt> {
if (!info.contains_row_major_traversal || !pad_value.defined() ||
is_zero(padding_predicate)) {
return NullOpt;
}
- Array<PrimExpr> old_indices = info.store->indices;
- PrimExpr if_then_else_condition = padding_predicate;
- Array<PrimExpr> new_indices;
- for (const auto& var : inverse->initial_indices) {
- new_indices.push_back(var);
- }
-
- auto replace_block_realize =
- [&]() -> std::function<Optional<Stmt>(const BlockRealizeNode*, const BlockRealize&)> {
- auto no_change = [](const BlockRealizeNode*, const BlockRealize&) -> Optional<Stmt> {
- return NullOpt;
- };
- if (!info.innermost_block_realize) {
- return no_change;
- }
- if (old_indices.empty()) {
- return no_change;
- }
-
- BlockRealize block_realize = info.innermost_block_realize.value();
- const auto& block = block_realize->block;
-
- // Find the block iterators that are used to access the buffer. Must be in the same order
- // as they appear in the indices.
- if (block->iter_vars.size() < old_indices.size()) {
- return no_change;
- }
- const auto& iter_vars = block->iter_vars;
- size_t block_index_start = 0;
- for (; block_index_start < iter_vars.size() - old_indices.size(); block_index_start++) {
- if (old_indices[0].same_as(iter_vars[block_index_start]->var)) {
- break;
- }
- }
- if (block_index_start > iter_vars.size() - old_indices.size()) {
- return no_change;
- }
-
- for (size_t i = 0; i < old_indices.size(); i++) {
- if (!old_indices[i].same_as(iter_vars[block_index_start + i]->var) ||
- iter_vars[block_index_start + i]->iter_type != kDataPar) {
- return no_change;
- }
- }
-
- // If we got to this point, all indices used to access the
- // buffer are virtual indices defined in the innermost block.
- // Therefore, generate new virtual indices for iterating over
- // the post-transform buffer.
- Array<PrimExpr> new_iter_values; // For BlockRealize
- Array<IterVar> new_iter_vars; // For Block
- Array<PrimExpr> new_access_indices; // For BufferStore
- Map<Var, PrimExpr> loop_var_to_virtual_var; // For updating if_then_else_condition
-
- for (size_t i = 0; i < block_index_start; i++) {
- new_iter_vars.push_back(iter_vars[i]);
- new_iter_values.push_back(block_realize->iter_values[i]);
- }
-
- ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
- for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
- Var var = inverse->initial_indices[i];
- PrimExpr dim = new_buffer->shape[i];
- std::stringstream ss;
- ss << "v_" << var->name_hint;
- Var virtual_var(ss.str(), var.dtype());
- new_iter_values.push_back(var);
- new_iter_vars.push_back(
- IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar));
- new_access_indices.push_back(virtual_var);
- loop_var_to_virtual_var.Set(var, virtual_var);
- }
-
- for (size_t i = block_index_start + old_indices.size(); i < iter_vars.size(); i++) {
- new_iter_vars.push_back(iter_vars[i]);
- new_iter_values.push_back(block_realize->iter_values[i]);
- }
-
- Map<Var, PrimExpr> old_virtual_var_to_new_virtual_var;
- ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
- for (size_t i = 0; i < old_indices.size(); i++) {
- Var var = Downcast<Var>(old_indices[i]);
- PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
- old_virtual_var_to_new_virtual_var.Set(var, expr);
- }
-
- if_then_else_condition = Substitute(if_then_else_condition, loop_var_to_virtual_var);
- new_indices = new_access_indices;
-
- return [target_realize = info.innermost_block_realize, new_iter_vars, new_iter_values,
- old_virtual_var_to_new_virtual_var](const BlockRealizeNode* op,
- const BlockRealize& visited) -> Optional<Stmt> {
- if (op == target_realize.get()) {
- Block block = visited->block;
- block =
- Downcast<Block>(Substitute(std::move(block), old_virtual_var_to_new_virtual_var));
- block.CopyOnWrite()->iter_vars = new_iter_vars;
-
- BlockRealize realize = visited;
- {
- auto write_ptr = realize.CopyOnWrite();
- write_ptr->block = block;
- write_ptr->iter_values = new_iter_values;
- }
- return realize;
- } else {
- return NullOpt;
- }
- };
- }();
-
- bool all_stores_replaced = true;
- auto replace_store = [&](const BufferStoreNode* op) -> Optional<Stmt> {
- if (!op->buffer.same_as(info.store->buffer)) {
- all_stores_replaced = false;
- return NullOpt;
- }
- ICHECK_EQ(old_indices.size(), op->indices.size());
- ExprDeepEqual expr_equal;
- for (size_t i = 0; i < old_indices.size(); i++) {
- if (!expr_equal(old_indices[i], op->indices[i])) {
- all_stores_replaced = false;
- return NullOpt;
- }
- }
-
- PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0];
- return BufferStore(new_buffer,
- if_then_else(if_then_else_condition, pad_value_at_index, op->value),
- new_indices);
- };
-
- BufferStoreReplacer replacer(replace_store, replace_block_realize);
+ BufferStoreReplacer replacer(info, new_buffer, padding_predicate, inverse, pad_value,
+ &new_block_to_old);
Stmt stmt = replacer(info.dependent_loopnest.back()->body);
- if (!all_stores_replaced) {
+ if (!replacer.is_all_stores_replaced()) {
return NullOpt;
}
- std::unordered_map<const VarNode*, PrimExpr> var_remap;
- ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
- for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
- Var var = info.dependent_loopnest[i]->loop_var;
- PrimExpr expr = inverse->final_indices[i];
- var_remap[var.get()] = expr;
- }
- stmt = Substitute(std::move(stmt), var_remap);
-
ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
size_t i = (inverse->initial_indices.size() - 1) - rev_i;
@@ -471,7 +535,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
}
if (loop_replacements.size()) {
- return ReplacementPlan{std::move(loop_replacements)};
+ return ReplacementPlan{std::move(loop_replacements), std::move(new_block_to_old)};
} else {
return std::nullopt;
}
@@ -603,25 +667,6 @@ class TransformLayoutPlanner : private StmtExprVisitor {
std::vector<BindVariableDefinition> bound_vars_;
};
- struct WriteInfo {
- // The BufferStore object
- BufferStore store;
-
- // The block realize that contains the store, if any.
- Optional<BlockRealize> innermost_block_realize;
-
- // The nested loops whose values contribute to the indices used in
- // the store. Not all loop variables in the loopnest need to
- // contribute, but the first and last must.
- std::vector<For> dependent_loopnest;
-
- // Whether the padding could be represented as a tir::if_then_else
- // node. This requires that the surrounding loop iterators
- // iterate over all pre-transformation buffer axes, that there are
- // no data dependencies between loop iterations, and that
- bool contains_row_major_traversal{false};
- };
-
/*! \brief Collected information about each BufferStore */
std::vector<WriteInfo> write_info_;
@@ -683,7 +728,20 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
auto write_ptr = result.CopyOnWrite();
write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
}
- return {result, rewriter.block_sref_reuse_};
+
+ Map<Block, Block> block_sref_reuse;
+ for (auto [after, before] : rewriter.new_block_to_old_) {
+ while (auto opt = rewriter.new_block_to_old_.Get(before)) {
+ before = opt.value();
+ }
+ while (auto opt = block_sref_reuse.Get(after)) {
+ after = opt.value();
+ }
+
+ block_sref_reuse.Set(before, after);
+ }
+
+ return {result, block_sref_reuse};
}
private:
@@ -696,7 +754,11 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
new_buffer_(new_buffer),
index_map_(index_map),
plan_(plan),
- buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {}
+ buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {
+ if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ReplacementPlan>(&plan_)) {
+ new_block_to_old_ = plan_ptr->new_block_to_old;
+ }
+ }
void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
*buffer = new_buffer_;
@@ -765,7 +827,20 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
}
Stmt VisitStmt_(const BlockNode* op) final {
+ Block orig = [&]() {
+ Block block = GetRef<Block>(op);
+ while (true) {
+ if (auto it = new_block_to_old_.find(block); it != new_block_to_old_.end()) {
+ block = (*it).second;
+ } else {
+ break;
+ }
+ }
+ return block;
+ }();
+
Block block = Downcast<Block>(Parent::VisitStmt_(op));
+
auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto* n = block.CopyOnWrite();
RewriteAccessRegion(&n->reads, infered_access_regions[0]);
@@ -777,16 +852,35 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
return buffer;
}
});
- block_sref_reuse_.Set(GetRef<Block>(op), block);
+
+ RecordReplacement(orig, block);
return std::move(block);
}
+ void RecordReplacement(Block before, Block after) {
+ if (before.same_as(after)) {
+ return;
+ }
+
+ ICHECK(!new_block_to_old_.count(after));
+
+ while (true) {
+ if (auto opt = new_block_to_old_.Get(before)) {
+ before = opt.value();
+ } else {
+ break;
+ }
+ }
+
+ new_block_to_old_.Set(after, before);
+ }
+
const Buffer& old_buffer_;
const Buffer& new_buffer_;
const IndexMap& index_map_;
const TransformLayoutPlanner::TransformPlan& plan_;
Map<Var, Buffer> buffer_data_to_buffer_;
- Map<Block, Block> block_sref_reuse_;
+ Map<Block, Block> new_block_to_old_;
};
class BufferIsSubregionError : public ScheduleError {
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 27056124d9..a901eff6f2 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -662,10 +662,11 @@ class SRefTreePruner : public StmtVisitor {
<< GetRef<Block>(op);
StmtSRef& sref = it->second;
// Detect reuse
- auto reuse_it = reuse_info_.block_sref_reuse.find(op);
- if (reuse_it != reuse_info_.block_sref_reuse.end()) {
+ const auto& sref_reuse = reuse_info_.block_sref_reuse;
+ if (auto reuse_it = sref_reuse.find(op); reuse_it != sref_reuse.end()) {
+ const BlockNode* to_reuse = reuse_it->second;
// sref can be reused
- reused_srefs_.emplace(reuse_it->second, std::move(sref));
+ reused_srefs_.emplace(to_reuse, std::move(sref));
} else {
sref->Reset();
self_->block_info.erase(sref);
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py
index ca5ac12a97..e904789223 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -415,13 +415,13 @@ class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
transformed_buffer = tvm.testing.parameter("A")
+ index_map = tvm.testing.parameter(lambda i: [i // 4, i % 4])
+
@pytest.fixture
- def transform(self, pad_value, transformed_buffer):
+ def transform(self, pad_value, transformed_buffer, index_map):
def transform(mod):
sch = tir.Schedule(mod)
- sch.transform_layout(
- "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value
- )
+ sch.transform_layout("block", transformed_buffer, index_map, pad_value=pad_value)
return sch.mod
return transform
@@ -885,5 +885,45 @@ class TestTransformLayoutWithVar(tvm.testing.CompareBeforeAfter):
)
+class TestTransformWithAxisSeparators(BasePaddingCompare):
+ """Axis separators may be specified in a transform"""
+
+ index_map = tvm.testing.parameter(lambda i: [i // 4, tvm.tir.IndexMap.AXIS_SEPARATOR, i % 4])
+ pad_value = tvm.testing.parameter(0)
+
+ def before(a: T.handle):
+ A = T.match_buffer(a, [14], "int32")
+ for i in T.serial(14):
+ with T.block("block"):
+ vi = T.axis.remap("S", [i])
+ A[vi] = 42
+
+ def expected(a: T.handle):
+ A = T.match_buffer(a, [4, 4], "int32", axis_separators=[1])
+ for i, j in T.grid(4, 4):
+ with T.block("block"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ A[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 42, dtype="int32")
+
+
+class TestTransformWithAxisSeparatorsOpaqueBlock(BasePaddingCompare):
+ """Axis separators may be specified in a transform of opaque block"""
+
+ index_map = tvm.testing.parameter(lambda i: [i // 4, tvm.tir.IndexMap.AXIS_SEPARATOR, i % 4])
+ pad_value = tvm.testing.parameter(0)
+
+ def before(a: T.handle):
+ A = T.match_buffer(a, [14], "int32")
+ for i in T.serial(14):
+ with T.block("block"):
+ A[i] = 42
+
+ def expected(a: T.handle):
+ A = T.match_buffer(a, [4, 4], "int32", axis_separators=[1])
+ for i, j in T.grid(4, 4):
+ with T.block("block"):
+ A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32")
+
+
if __name__ == "__main__":
tvm.testing.main()