You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/13 15:57:09 UTC
[tvm] branch main updated: [TIR][Schedule] Refactor Tensorize (#12070)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7d9a07ccc7 [TIR][Schedule] Refactor Tensorize (#12070)
7d9a07ccc7 is described below
commit 7d9a07ccc70eef951bcfff0333c2f82cdc6a3b12
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jul 13 08:57:01 2022 -0700
[TIR][Schedule] Refactor Tensorize (#12070)
* Refactor blockize
* Refactor tensorize
* Address review comments
* typo
* rename variables according to review
---
src/tir/schedule/primitive/blockize_tensorize.cc | 853 ++++++++++-----------
.../python/unittest/test_tir_schedule_blockize.py | 322 ++++----
2 files changed, 580 insertions(+), 595 deletions(-)
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc
index 4ede2dd90d..9c3029ebf5 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -24,6 +24,20 @@
namespace tvm {
namespace tir {
+template <class T>
+bool UsesVar(const T& x, const Var& var) {
+ return UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; });
+}
+
+Range RangeFromExtent(const PrimExpr& extent) {
+ return Range::FromMinExtent(make_zero(extent->dtype), extent);
+}
+
+template <class T>
+T DeepCopy(const T& stmt) {
+ return Downcast<T>(LoadJSON(SaveJSON(stmt)));
+}
+
/*!
* \brief ScheduleError that the bindings of the inner block are not divisible by the subspace
* represented by the outer loops.
@@ -64,16 +78,16 @@ class SubspaceNotDivisibleError : public ScheduleError {
*
* \param iter_vars The input iterators
* \param bindings The values of iter_vars
- * \param outer_loops Iterators outside the subspace.
- * \param inner_loops Iterators of the subspace
* \param predicate The predicate constraint on the input iterators.
+ * \param outer_iters The iters of the outer space
+ * \param inner_iters The iters of the inner space
* \return The result of the subspace division.
*/
Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter_vars,
const Array<PrimExpr>& bindings,
+ const PrimExpr& predicate,
const Array<Var>& outer_iters,
- const Array<Var>& inner_iters,
- const PrimExpr& predicate) {
+ const Array<Var>& inner_iters) {
if (!is_one(predicate)) return {};
Array<Array<arith::IterMark>> res;
std::unordered_set<const VarNode*> outer_loop_vars;
@@ -95,7 +109,7 @@ Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter
auto use_inner_loop_vars = make_uses_var(inner_iters);
arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1);
- for (size_t i = 0; i < bindings.size(); ++i) {
+ for (int i = 0, n = bindings.size(); i < n; ++i) {
bool outer = use_outer_loop_vars(bindings[i]);
bool inner = use_inner_loop_vars(bindings[i]);
arith::IterMark iter_mark;
@@ -122,531 +136,462 @@ Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter
}
/*!
- * \brief Generate the blockized init block.
- * \param block The original block with init.
- * \param inner_block_realize The block realize of the inner block after blockize.
- * \param inner_loops The inner loops after blockize.
- * \return The subtree of the init block and its outer loops.
+ * \brief Subspace division. The space is divided into two subspaces:
+ * 1. The subspace represented by the outer loops above `loop_sref` (exclusive).
+ * 2. The subspace represented by the inner loops below `loop_sref` (inclusive).
+ * \param realize The block realize of the inner block
+ * \param block_sref The sref to the inner block
+ * \param loop_sref The loop that is the root of the second subspace.
+ * \param loops Output: the loops that represent the second subspace, innermost first.
+ * \param analyzer The arithmetic analyzer to use.
*/
-Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize,
- const std::vector<const ForNode*>& inner_loops) {
- Array<IterVar> init_block_iters;
- Array<PrimExpr> init_bindings;
- const Block& inner_block = inner_block_realize->block;
-
- // Step 1: Collect data-parallel block iters
- for (size_t i = 0; i < inner_block->iter_vars.size(); i++) {
- const IterVar& iter_var = inner_block->iter_vars[i];
- const PrimExpr& binding = inner_block_realize->iter_values[i];
- if (iter_var->iter_type == IterVarType::kDataPar &&
- UsesVar(block->init.value(),
- [tgt_var = iter_var->var.get()](const VarNode* var) { return var == tgt_var; })) {
- init_block_iters.push_back(iter_var);
- init_bindings.push_back(binding);
+Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
+ const StmtSRef& block_sref, //
+ const StmtSRef& loop_sref, //
+ std::vector<const ForNode*>* loops,
+ arith::Analyzer* analyzer) {
+ Array<Var> inner_vars;
+ Array<Var> outer_vars;
+ Map<Var, Range> loop_var_domain;
+ bool inner = true;
+ for (StmtSRefNode* sref = block_sref->parent; //
+ sref && sref->stmt->IsInstance<ForNode>(); //
+ sref = sref->parent) {
+ const ForNode* loop = static_cast<const ForNode*>(sref->stmt);
+ if (inner) {
+ loops->push_back(loop);
+ inner_vars.push_back(loop->loop_var);
+ } else {
+ outer_vars.push_back(loop->loop_var);
}
- }
-
- // Step 2: Collect loops related to iters of the init block
- std::vector<const ForNode*> init_loops;
- for (const ForNode* inner_loop : inner_loops) {
- for (const PrimExpr& init_binding : init_bindings) {
- if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const VarNode* var) {
- return var == tgt_var;
- })) {
- init_loops.push_back(inner_loop);
- break;
- }
+ loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+ if (sref == loop_sref.get()) {
+ inner = false;
}
}
-
- // Step 3: Create new block iters for the init block
- Map<Var, PrimExpr> subst_map;
- for (size_t i = 0; i < init_block_iters.size(); i++) {
- IterVar new_iter_var = init_block_iters[i];
- Var old_var = new_iter_var->var;
- Var new_var = old_var.copy_with_suffix("_init");
- new_iter_var.CopyOnWrite()->var = new_var;
- subst_map.Set(old_var, new_var);
- init_block_iters.Set(i, std::move(new_iter_var));
- }
-
- // Step 4: Generate loop nests and the init block
- Stmt new_init = BlockRealize(
- /*iter_values=*/init_bindings,
- /*predicate=*/inner_block_realize->predicate,
- /*block=*/
- Block{/*iter_vars=*/init_block_iters,
- /*reads=*/{},
- /*writes=*/block->writes,
- /*name_hint=*/block->name_hint + "_init",
- /*body=*/block->init.value(),
- /*init=*/NullOpt});
-
- // Step 5: Generate the parent loops for the init block
- for (const ForNode* init_loop : init_loops) {
- ObjectPtr<ForNode> new_loop = make_object<ForNode>(*init_loop);
- new_loop->loop_var = init_loop->loop_var.copy_with_suffix("");
- subst_map.Set(init_loop->loop_var, new_loop->loop_var);
- new_loop->body = std::move(new_init);
- new_init = For(new_loop);
+ Array<Array<arith::IterMark>> result =
+ arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate,
+ arith::IterMapLevel::Surjective, analyzer);
+ if (!result.empty()) {
+ return result;
}
-
- // Step 6: Substitute with new loop variables and block iters to prevent duplication of
- // variables in the outer block.
- new_init = Substitute(new_init, subst_map);
-
- return new_init;
+ return TrivialSubspaceDivision(realize->block->iter_vars,
+ realize->iter_values, //
+ realize->predicate, //
+ outer_vars, inner_vars);
}
/*!
- * \brief A helper to collect the parent loops of the block. The loops are divided into two groups,
- * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the
- * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its
- * successor loops. It is possible that 'outer_loops' is empty.
+ * \brief Derive the block bindings for both inner and outer block
+ * \param iter_vars The original block iterators to the inner block
+ * \param division The subspace division.
+ * \param outer_iter_vars The outer block iterators.
+ * \param outer_bindings The outer block bindings.
+ * \param inner_iter_vars The inner block iterators.
+ * \param inner_bindings The inner block bindings.
+ * \return A substitution plan to the iterators in the original inner block.
*/
-class LoopSubspaceCollector {
- public:
- /*!
- * \brief Collect the parent loops of the block and store the result in the corresponding fields.
- * \param block_sref The sref to the target block.
- * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop.
- */
- void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
- bool inner = true;
- for (StmtSRefNode* current_sref = block_sref->parent;
- current_sref && current_sref->stmt->IsInstance<ForNode>();
- current_sref = current_sref->parent) {
- const auto* current_loop = current_sref->StmtAs<ForNode>();
- ICHECK(current_loop);
- if (inner) {
- inner_loops.push_back(current_loop);
- inner_loop_vars.push_back(current_loop->loop_var);
- } else {
- outer_loops.push_back(current_loop);
- outer_loop_vars.push_back(current_loop->loop_var);
- }
- loop_var_domain.Set(current_loop->loop_var,
- Range::FromMinExtent(current_loop->min, current_loop->extent));
- if (current_sref == loop_sref.get()) inner = false;
+Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars, //
+ const Array<Array<arith::IterMark>>& division, //
+ Array<IterVar>* outer_iter_vars, //
+ Array<PrimExpr>* outer_bindings, //
+ Array<IterVar>* inner_iter_vars, //
+ Array<PrimExpr>* inner_bindings) {
+ using arith::IterMapExpr;
+ using arith::IterMapExprNode;
+ using arith::NormalizeIterMapToExpr;
+ Map<Var, PrimExpr> block_var_subst;
+ ICHECK_EQ(iter_vars.size() + 1, division.size());
+ for (int i = 0, n = iter_vars.size(); i < n; ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ arith::IterMark outer_mark = division[i][0];
+ arith::IterMark inner_mark = division[i][1];
+ IterMapExpr outer_binding = Downcast<IterMapExpr>(outer_mark->source);
+ IterMapExpr inner_binding = Downcast<IterMapExpr>(inner_mark->source);
+ // After computing the subspace division, bindings[i] can be written as
+ // outer_binding * inner_binding->extent + inner_binding
+ // The outer block will have binding: iter_outer -> outer_binding
+ // The inner block will have binding: iter_inner -> inner_binding
+ // The iter in the original block will be substituted with base + iter_inner where
+ // base == iter_outer * iter_inner_extent
+ if (is_one(inner_mark->extent)) { // IsOuter
+ // extract this iter var to outer block directly
+ outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+ outer_iter_vars->push_back(iter_var);
+ continue;
}
+ // create iter var for the outer block
+ IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
+ /*var=*/iter_var->var.copy_with_suffix("_o"),
+ /*iter_type=*/iter_var->iter_type);
+ outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+ outer_iter_vars->push_back(outer_iter);
+ // create iter var for the inner block
+ IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
+ /*var=*/iter_var->var.copy_with_suffix("_i"),
+ /*iter_type=*/iter_var->iter_type);
+ inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
+ inner_iter_vars->push_back(inner_iter);
+ // substitution
+ PrimExpr sub{nullptr};
+ if (is_one(outer_mark->extent)) {
+ sub = inner_iter->var;
+ } else {
+ sub = outer_iter * inner_mark->extent + inner_iter->var;
+ }
+ block_var_subst.Set(iter_var->var, sub);
}
- /*! \brief Outer loops which are ancestors of the separator. */
- std::vector<const ForNode*> outer_loops;
- /*! \brief Inner loops which are the separator itself or its successors. */
- std::vector<const ForNode*> inner_loops;
- /*! \brief Loop variables of the outer loops. */
- Array<Var> outer_loop_vars;
- /*! \brief Loop variables of the inner loops. */
- Array<Var> inner_loop_vars;
- /*! \brief Domain of the loop variables. */
- Map<Var, Range> loop_var_domain;
-};
+ return block_var_subst;
+}
/*!
- * \brief Check the bindings of the block iters can be divided by a subspace collected by the
- * collector.
- * \param mod The current IR module.
- * \param block_realize The block realize to be checked.
- * \param collector The collector which has collected the loops of the block.
- * \param analyzer The arithmetic analyzer.
- * \return The result of the subspace division.
- * \throws ScheduleError If the bindings are not divisible by the subspace.
+ * \brief Generate the inner block for blockization
+ * \param is_write_reduction Whether the write regions of the inner block are actually reduction.
+ * \param iter_vars IterVars used in the inner block.
+ * \param iter_values IterVar bindings used in the inner block.
+ * \param predicate The predicate of the inner block.
+ * \param block The inner block used as a template. Its `iter_vars` and `init` fields (and
+ * `reads`, when \p is_write_reduction is true) are updated via copy-on-write.
+ * \return The inner block created.
*/
-Array<Array<arith::IterMark>> CheckSubspaceDivisible(const IRModule& mod,
- const BlockRealize& block_realize,
- const LoopSubspaceCollector& collector,
- arith::Analyzer* analyzer) {
- const Block& block = block_realize->block;
-
- Array<Array<arith::IterMark>> division = arith::SubspaceDivide(
- block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars,
- block_realize->predicate, arith::IterMapLevel::Surjective, analyzer);
-
- if (division.empty()) {
- // If we can't do perfect subspace division, check if it is a trivial case of subspace division.
- // In this case, we can still blockize.
- division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values,
- collector.outer_loop_vars, collector.inner_loop_vars,
- block_realize->predicate);
- }
- if (division.empty()) {
- throw SubspaceNotDivisibleError(mod, GetRef<For>(collector.inner_loops.back()), block);
+BlockRealize GenerateInner(bool is_write_reduction,
+ const Array<IterVar>& iter_vars, //
+ const Array<PrimExpr>& iter_values, //
+ const PrimExpr& predicate, //
+ Block block) {
+ BlockNode* n = block.CopyOnWrite();
+ n->iter_vars = iter_vars;
+ n->init = NullOpt;
+ if (is_write_reduction) {
+ Array<BufferRegion> reads;
+ reads.reserve(block->writes.size() + block->reads.size());
+ reads.insert(reads.end(), block->writes.begin(), block->writes.end());
+ reads.insert(reads.end(), block->reads.begin(), block->reads.end());
+ n->reads = std::move(reads);
}
- return division;
+ return BlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate,
+ /*block=*/block);
}
/*!
- * \brief The binding extractor to compute the bindings of the outer and the inner blocks after
- * blockize.
+ * \brief Generate the init stmt (init block plus its loop nest) for the outer block
+ * \param block_init The init stmt of the original block.
+ * \param inner_realize The block realize of the inner block after blockize.
+ * \param loops The inner loops after blockize.
+ * \param block_name The name hint of the generated init block.
*/
-class BlockizedBindingExtractor {
- public:
- /*!
- * \brief Extract bindings for blockize.
- * \param iter_vars The iter vars of the original inner block.
- * \param division The result of the subspace division.
- */
- void ExtractBindings(const Array<IterVar>& iter_vars,
- const Array<Array<arith::IterMark>>& division, arith::Analyzer* analyzer) {
- ICHECK_EQ(iter_vars.size() + 1, division.size());
- for (size_t i = 0; i < iter_vars.size(); ++i) {
- const IterVar& iter_var = iter_vars[i];
- arith::IterMark outer_mark = division[i][0];
- arith::IterMark inner_mark = division[i][1];
- const auto* outer_binding =
- TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode);
- const auto* inner_binding =
- TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode);
-
- // After computing the subspace division, bindings[i] can be written as
- // outer_binding * inner_binding->extent + inner_binding
- // The outer block will have binding: iter_outer -> outer_binding
- // The inner block will have binding: iter_inner -> inner_binding
- // The iter in the original block will be substituted with base + iter_inner where
- // base == iter_outer * iter_inner_extent
-
- if (is_one(division[i][1]->extent)) { // IsOuter
- // extract this iter var to outer block directly
- outer_bindings.push_back(
- arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
- outer_iter_vars.push_back(iter_var);
- } else {
- // create iter var for the outer block
- const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, division[i][0]->extent),
- /*var=*/iter_var->var.copy_with_suffix("_o"),
- /*iter_type=*/iter_var->iter_type);
- outer_bindings.push_back(
- arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
- outer_iter_vars.push_back(outer_var);
- PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent;
- // create iter var for the inner block
- IterVar new_iter(Range::FromMinExtent(0, division[i][1]->extent), Var(iter_var->var),
- iter_var->iter_type, iter_var->thread_tag, iter_var->span);
- inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom));
- analyzer->Bind(new_iter->var, new_iter->dom);
- inner_iter_vars.push_back(new_iter);
- inner_bindings.push_back(
- arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(inner_binding)));
- inner_iter_subst_map.Set(iter_var->var, base + new_iter->var);
+Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize,
+ const std::vector<const ForNode*>& loops, String block_name) {
+ const Block& inner_block = inner_realize->block;
+ Map<Var, PrimExpr> subst_map;
+ // Step 1: Create new block vars for the block inside the init stmt of outer block
+ // An iter var is used in the block if
+ // 1) It is data parallel
+ // 2) It is used in the original init block
+ Array<IterVar> iter_vars;
+ Array<PrimExpr> iter_values;
+ ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size());
+ int n = inner_block->iter_vars.size();
+ iter_vars.reserve(n);
+ iter_values.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ const IterVar& old_iter_var = inner_block->iter_vars[i];
+ const PrimExpr& iter_value = inner_realize->iter_values[i];
+ if (old_iter_var->iter_type == IterVarType::kDataPar &&
+ UsesVar(block_init, old_iter_var->var)) {
+ ObjectPtr<IterVarNode> new_iter_var = make_object<IterVarNode>(*old_iter_var.get());
+ new_iter_var->var = new_iter_var->var.copy_with_suffix("_init");
+ subst_map.Set(old_iter_var->var, new_iter_var->var);
+ iter_vars.push_back(IterVar(new_iter_var));
+ iter_values.push_back(iter_value);
+ }
+ }
+ // Step 2: Generate the block inside init stmt of outer block
+ Stmt stmt = BlockRealize(
+ /*iter_values=*/iter_values,
+ /*predicate=*/inner_realize->predicate,
+ /*block=*/
+ Block(/*iter_vars=*/iter_vars,
+ /*reads=*/{},
+ /*writes=*/inner_block->writes,
+ /*name_hint=*/block_name,
+ /*body=*/block_init,
+ /*init=*/NullOpt));
+ // Step 3: Create the loop nest on top of the block
+ for (const ForNode* loop : loops) {
+ bool is_init_loop = false;
+ for (const PrimExpr& init_binding : iter_values) {
+ if (UsesVar(init_binding, loop->loop_var)) {
+ is_init_loop = true;
+ break;
}
}
+ if (is_init_loop) {
+ ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
+ new_loop->loop_var = loop->loop_var.copy_with_suffix("");
+ new_loop->body = std::move(stmt);
+ subst_map.Set(loop->loop_var, new_loop->loop_var);
+ stmt = For(new_loop);
+ }
}
- Map<Var, PrimExpr> inner_iter_subst_map;
- /*! \brief Iters of the outer block. */
- Array<IterVar> outer_iter_vars;
- /*! \brief Iters of the outer block. */
- Array<IterVar> inner_iter_vars;
- /*! \brief Binding values of the outer block. */
- Array<PrimExpr> outer_bindings;
- /*! \brief Binding values of the inner block. */
- Array<PrimExpr> inner_bindings;
- /*! \brief The domain of the inner block iters. */
- Map<Var, arith::IntSet> inner_iter_dom_map;
-};
+ // Step 4: Substitute the iter vars and loop vars
+ return Substitute(stmt, subst_map);
+}
/*!
- * \brief Replacer for the inner block after blockize. Inner block iters will be replaced with
- * base + inner_iter and the expressions after substituion will be simplified if possible.
+ * \brief Substitute variables in the stmt, do simplification and track block substitution
+ * \param stmt The stmt to be substituted.
+ * \param sub The substitution map.
+ * \param block_sref_reuse Output map from each replaced block to its substituted version.
+ * \param analyzer The analyzer for arithmetic simplification.
+ * \return The substituted stmt.
*/
-class InnerIterReplacer : public StmtExprMutator {
- public:
- /*!
- * \brief The constructor
- * \param subst_map The substitution map of the inner block iters.
- * \param analyzer The arithmetic analyzer.
- * \param block_sref_reuse The map to save the block reuse information.
- */
- InnerIterReplacer(Map<Var, PrimExpr> subst_map, arith::Analyzer* analyzer,
- Map<Block, Block>* block_sref_reuse)
- : subst_map_(std::move(subst_map)),
- analyzer_(analyzer),
- block_sref_reuse_(block_sref_reuse) {}
-
- PrimExpr VisitExpr_(const VarNode* op) final {
- auto it = subst_map_.find(GetRef<Var>(op));
- if (it != subst_map_.end()) {
- return (*it).second;
+Stmt Substitute(const Stmt& stmt, const Map<Var, PrimExpr>& sub,
+ Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
+ struct Replacer : public StmtExprMutator {
+ explicit Replacer(const Map<Var, PrimExpr>& sub, Map<Block, Block>* block_sref_reuse,
+ arith::Analyzer* analyzer)
+ : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {}
+
+ PrimExpr VisitExpr(const PrimExpr& op) final {
+ PrimExpr result = StmtExprMutator::VisitExpr(op);
+ if (!result.same_as(op)) {
+ return analyzer_->Simplify(result);
+ }
+ return result;
}
- return StmtExprMutator::VisitExpr_(op);
- }
- PrimExpr VisitExpr(const PrimExpr& op) final {
- PrimExpr result = StmtExprMutator::VisitExpr(op);
- if (!result.same_as(op)) {
- return analyzer_->Simplify(result);
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ if (Optional<PrimExpr> e = sub_.Get(GetRef<Var>(op))) {
+ return e.value();
+ }
+ return StmtExprMutator::VisitExpr_(op);
}
- return result;
- }
- Stmt VisitStmt_(const BlockNode* op) final {
- Stmt result = StmtExprMutator::VisitStmt_(op);
- if (!result.same_as(GetRef<Stmt>(op))) {
- block_sref_reuse_->Set(GetRef<Block>(op), Downcast<Block>(result));
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Block src = GetRef<Block>(op);
+ Block tgt = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+ if (!src.same_as(tgt)) {
+ block_sref_reuse_->Set(src, tgt);
+ }
+ return tgt;
}
- return result;
- }
- private:
- Map<Var, PrimExpr> subst_map_;
- arith::Analyzer* analyzer_;
- Map<Block, Block>* block_sref_reuse_;
-};
+ const Map<Var, PrimExpr>& sub_;
+ Map<Block, Block>* block_sref_reuse_;
+ arith::Analyzer* analyzer_;
+ };
+ return Replacer(sub, block_sref_reuse, analyzer)(stmt);
+}
/*!
- * \brief Compute the access region of the outer block by relaxing the inner loops.
- * \param buffer_region The original buffer region.
- * \param The range of the inner loops.
- * \return The new buffer region.
+ * \brief Relax the variables for the given regions
+ * \param regions The regions to be relaxed.
+ * \param dom_map The domains of the variables to be relaxed
+ * \return The relaxed regions
*/
-BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region,
- const Map<Var, arith::IntSet>& inner_iter_relaxed_range) {
- Array<Range> new_region;
- new_region.reserve(buffer_region->region.size());
- Array<arith::IntSet> relaxed_int_set =
- arith::EvalSet(buffer_region->region, inner_iter_relaxed_range);
- ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size());
- for (size_t i = 0; i < buffer_region->region.size(); i++) {
- Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]);
- new_region.push_back(relaxed_int_set[i].CoverRange(max_range));
+Array<BufferRegion> EvalSetRegions(const Array<BufferRegion>& regions,
+ const Map<Var, arith::IntSet>& dom_map) {
+ Array<BufferRegion> results;
+ results.reserve(regions.size());
+ for (const BufferRegion& buffer_region : regions) {
+ const Buffer& buffer = buffer_region->buffer;
+ Array<arith::IntSet> relaxed = arith::EvalSet(buffer_region->region, dom_map);
+ ICHECK_EQ(relaxed.size(), buffer->shape.size());
+ int ndim = buffer->shape.size();
+ Array<Range> new_region;
+ new_region.reserve(ndim);
+ for (int i = 0; i < ndim; ++i) {
+ new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i])));
+ }
+ results.push_back(BufferRegion(buffer, new_region));
}
- return BufferRegion(buffer_region->buffer, std::move(new_region));
+ return results;
}
/*!
- * \brief Generate the outer block after blockize.
- * \param extractor The binding extractor which has extracted the blockized bindings.
- * \param block The original inner block.
- * \param inner_block_realize The block realize of the inner block after blockize.
- * \param inner_loops The inner loops after blockize.
- * \param predicate The outer predicate of the subspace division.
- * \return The block realize of the outer block after blockize.
+ * \brief Create the loop nest on top of the given stmt.
+ * \param stmt The stmt to be wrapped.
+ * \param loops The loops to wrap around the stmt, innermost first
+ * \return The wrapped stmt.
*/
-BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor,
- const Block& block, BlockRealize inner_block_realize,
- const std::vector<const ForNode*>& inner_loops,
- PrimExpr predicate) {
- // Step 1: Generate the init block if needed
- Optional<Stmt> new_init = NullOpt;
- if (block->init.defined()) {
- new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops);
- }
-
- // Step 2: Compute the access regions of the outer block by relaxing the inner loops
- Array<BufferRegion> new_reads = block->reads;
- Array<BufferRegion> new_writes = block->writes;
-
- auto f_mutate = [&](const BufferRegion& buffer_region) {
- return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_dom_map);
- };
- new_reads.MutateByApply(f_mutate);
- new_writes.MutateByApply(f_mutate);
-
- // Step 3: Generate the body of the outer block. The body of the outer block is the inner block
- // realize and its surrounding loops.
- Stmt outer_block_body = inner_block_realize;
- for (const ForNode* loop : inner_loops) {
+Stmt MakeLoopNest(Stmt stmt, const std::vector<const ForNode*>& loops) {
+ for (const ForNode* loop : loops) {
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
- new_loop->body = std::move(outer_block_body);
- outer_block_body = For(new_loop);
+ new_loop->body = std::move(stmt);
+ stmt = For(new_loop);
}
-
- // Step 4: Generate the outer block and block realize.
- return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings),
- /*predicate=*/std::move(predicate),
- /*block=*/
- Block(/*iter_vars=*/std::move(extractor.outer_iter_vars), //
- /*reads=*/std::move(new_reads), //
- /*writes=*/std::move(new_writes), //
- /*name_hint=*/block->name_hint + "_o", //
- /*body=*/std::move(outer_block_body), //
- /*init=*/std::move(new_init)));
+ return stmt;
}
-StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
+ Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
- arith::Analyzer analyzer;
-
- // Step 1: Check the loop has a single child BlockRealize on the sref tree.
+ // Step 1: Check and get the only block under `loop`.
BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref);
Block block = block_realize->block;
StmtSRef block_sref = self->stmt2ref.at(block.get());
-
- // Step 2: Collect loops inside and outside loop_sref.
- LoopSubspaceCollector collector;
- collector.Collect(block_sref, loop_sref);
-
- // Step 3: Calculate subspace division for the inner loops.
+ // Step 2: Derive subspace division
+ std::vector<const ForNode*> loops;
Array<Array<arith::IterMark>> division =
- CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer);
-
- // Step 4: Generate bindings for the outer block and the inner block based on the result of
- // the subspace division.
- BlockizedBindingExtractor extractor;
- extractor.ExtractBindings(block->iter_vars, division, &analyzer);
- const PrimExpr& outer_pred = division.back()[0]->extent;
- const PrimExpr& inner_pred = division.back()[1]->extent;
-
- // Step 5: Substitute the iter vars in the original block with the inner iters after the subspace
- // division
- Map<Block, Block> block_sref_reuse;
- InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer,
- &block_sref_reuse);
- Block new_block = Downcast<Block>(replacer(block));
-
- // Step 6: Generate the inner block.
- bool outer_reduction = false; // whether there are outer reduction iter vars.
- for (const IterVar& iter_var : extractor.outer_iter_vars) {
- if (iter_var->iter_type == kCommReduce) {
- outer_reduction = true;
- }
+ SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer);
+ if (division.empty()) {
+ throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()), block);
}
- BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite();
- inner_block_realize->iter_values = extractor.inner_bindings;
- inner_block_realize->predicate = inner_pred;
- inner_block_realize->block = new_block;
- BlockNode* inner_block = inner_block_realize->block.CopyOnWrite();
- inner_block->iter_vars = extractor.inner_iter_vars;
- inner_block->init = NullOpt;
- /* Add write regions to read regions if
- * 1. there are outer reduction iter vars.
- * 2. the init block is defined for current block.
- */
- if (outer_reduction && block->init.defined()) {
- Array<BufferRegion> new_reads;
- for (const BufferRegion& write_access : inner_block->writes) {
- new_reads.push_back(write_access);
- }
- for (const BufferRegion& read_access : inner_block->reads) {
- new_reads.push_back(read_access);
+ PrimExpr outer_predicate = division.back()[0]->extent;
+ PrimExpr inner_predicate = division.back()[1]->extent;
+ // Step 3: Derive block bindings for both outer and inner block.
+ Array<IterVar> outer_iter_vars;
+ Array<IterVar> inner_iter_vars;
+ Array<PrimExpr> outer_bindings;
+ Array<PrimExpr> inner_bindings;
+ Map<Var, PrimExpr> block_var_subst = //
+ DeriveBlockBinding(block->iter_vars, division, //
+ &outer_iter_vars, &outer_bindings, //
+ &inner_iter_vars, &inner_bindings);
+ // Step 4: Do var substitution to adjust to the new block bindings
+ Map<Var, arith::IntSet> inner_iter_dom;
+ for (const IterVar& iter : inner_iter_vars) {
+ inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom));
+ analyzer->Bind(iter->var, iter->dom);
+ }
+ Block block_subst =
+ Downcast<Block>(Substitute(block, block_var_subst, block_sref_reuse, analyzer));
+ // Step 5: Generate the inner block. Its write regions are also prepended to its reads if
+ // 1. The original block has init stmt.
+ // 2. There are outer reduction iter vars.
+ bool has_outer_reduction = false;
+ if (block_subst->init.defined()) {
+ for (const IterVar& iter_var : outer_iter_vars) {
+ if (iter_var->iter_type == kCommReduce) {
+ has_outer_reduction = true;
+ break;
+ }
}
- inner_block->reads = std::move(new_reads);
}
- block_sref_reuse.Set(block, inner_block_realize->block);
-
+ BlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction,
+ /*iter_vars=*/inner_iter_vars,
+ /*iter_values*/ inner_bindings,
+ /*predicate=*/inner_predicate,
+ /*block=*/block_subst);
+ block_sref_reuse->Set(block, inner_realize->block);
// Step 6: Generate the outer block.
- BlockRealize outer_realize =
- GenerateBlockizedOuterBlock(extractor, new_block, GetRef<BlockRealize>(inner_block_realize),
- collector.inner_loops, outer_pred);
- // Step 7: Do the actual replacement
- self->Replace(loop_sref, outer_realize, block_sref_reuse);
-
- // Step 8: Update the cached flags
- StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get());
- StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false);
+ return BlockRealize(
+ /*iter_values=*/std::move(outer_bindings),
+ /*predicate=*/std::move(outer_predicate),
+ /*block=*/
+ Block(/*iter_vars=*/std::move(outer_iter_vars),
+ /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom),
+ /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom),
+ /*name_hint=*/block_subst->name_hint + "_o",
+ /*body=*/MakeLoopNest(inner_realize, loops),
+ /*init=*/
+ block_subst->init.defined() //
+ ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops,
+ block_subst->name_hint + "_init")
+ : Optional<Stmt>(NullOpt)));
+}
+
+StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+ arith::Analyzer analyzer;
+ Map<Block, Block> block_sref_reuse;
+ BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer);
+ self->Replace(loop_sref, blockized, block_sref_reuse);
+ StmtSRef result = self->stmt2ref.at(blockized->block.get());
+ StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
self->block_info[scope_root].affine_binding = scope_block_affine_binding;
- return outer_block_sref;
-}
-
-/*!
- * \brief Update the map from the buffers in the desc to the impl of the tensor
- * intrinsic.
- * \param intrinsic The tensor intrinsic.
- * \param buffer_map The map to be updated.
- */
-void RemapTensorIntrinBuffers(
- const TensorIntrin& intrinsic,
- std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_map) {
- ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size());
- for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) {
- const Var& lhs_var = intrinsic->desc->params[i];
- const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var];
- const Var& rhs_var = intrinsic->impl->params[i];
- const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var];
- (*buffer_map)[rhs_buffer] = lhs_buffer;
- }
+ return result;
}
-void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
- const TensorIntrin& intrinsic) {
- /*!
- * Check:
- * - Check buffer binding, including type, alignment, shape and etc.
- * - Check the sub AST is equal to the desc function.
- *
- * Mutate:
- * - Blockize the sub AST (please refer blockize for details)
- * - Bind buffers
- * - Mutate the impl of the tensor intrinsic by replacing its buffers with new
- * buffers created via match buffer region.
- * - Replace the sub tree with the mutated function.
- */
- const BlockRealize& desc_block_realize = Downcast<BlockRealize>(intrinsic->desc->body);
- const BlockRealize& impl_block_realize = Downcast<BlockRealize>(intrinsic->impl->body);
- Block impl_block = impl_block_realize->block;
-
+void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin) {
// Step 1: Blockize the subtree rooted at the given loop if needed
- StmtSRef block_sref{nullptr};
- if (block_or_loop_sref->StmtAs<ForNode>()) {
- block_sref = Blockize(self, block_or_loop_sref);
+ BlockRealize block_realize{nullptr};
+ Optional<Block> old_block = NullOpt;
+ if (sref->stmt->IsInstance<BlockNode>()) {
+ block_realize = GetBlockRealize(self, sref);
+ old_block = block_realize->block;
+ } else if (sref->stmt->IsInstance<ForNode>()) {
+ arith::Analyzer analyzer;
+ Map<Block, Block> block_sref_reuse;
+ block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer);
} else {
- ICHECK(block_or_loop_sref->StmtAs<BlockNode>());
- block_sref = block_or_loop_sref;
+ LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: "
+ << GetRef<Stmt>(sref->stmt);
+ throw;
}
- const BlockRealize& block_realize = GetBlockRealize(self, block_sref);
-
- // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence
- // between buffers in the block and the desc.
+ PrimFunc intrin_desc = intrin->desc;
+ PrimFunc intrin_impl = DeepCopy(intrin->impl);
+ // Step 2: Structural pattern matching
TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
- comparator.VisitStmt(block_realize, desc_block_realize);
-
- // Step 3: Find the correspondence between buffers in the current AST and the impl of
- // the tensor intrinsic
- // Step 3.1: Map from intrinsic func buffer to desc func buffer
- std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> intrin_buffer_map;
- RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map);
- // Step 3.2: Map form intrinsic func buffer to current AST buffer
- std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map;
- for (const auto& pair : intrin_buffer_map) {
- auto it = comparator.rhs_buffer_map_.find(pair.second);
- ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second;
- buffer_map[pair.first] = it->second;
+ comparator.VisitStmt(block_realize, intrin_desc->body);
+ // Step 3: Prepare necessary mapping
+ // 1) Buffer mapping from intrin impl buffers to intrin desc buffers.
+ // 2) Buffer mapping from intrin impl buffers to buffers in the current AST.
+ // 3) Mapping impl buffers to their accessed regions.
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2desc;
+ ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size());
+ for (int i = 0, n = intrin_desc->params.size(); i < n; ++i) {
+ const Buffer& desc = intrin_desc->buffer_map[intrin_desc->params[i]];
+ const Buffer& impl = intrin_impl->buffer_map[intrin_impl->params[i]];
+ impl2desc[impl] = desc;
}
-
- // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor
- // intrin to make them subregions of the buffer in the original IR.
- std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> buffer_region_map;
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2cur;
+ for (const auto& pair : impl2desc) {
+ const Buffer& impl = pair.first;
+ const Buffer& desc = pair.second;
+ ICHECK(comparator.rhs_buffer_map_.count(desc));
+ impl2cur[impl] = comparator.rhs_buffer_map_[desc];
+ }
+ std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> impl2region;
+ Block impl_block = Downcast<BlockRealize>(intrin_impl->body)->block;
for (const BufferRegion& read : impl_block->reads) {
- buffer_region_map.emplace(read->buffer, read->region);
+ impl2region.emplace(read->buffer, read->region);
}
for (const BufferRegion& write : impl_block->writes) {
- buffer_region_map.emplace(write->buffer, write->region);
+ impl2region.emplace(write->buffer, write->region);
}
+ // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor
+ // intrin to make them subregions of the buffer in the original IR.
Array<MatchBufferRegion> match_buffer_regions;
- match_buffer_regions.reserve(intrinsic->impl->params.size());
- for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) {
- const auto& param = intrinsic->impl->params[i];
- const auto& buffer = intrinsic->impl->buffer_map.at(param);
- const auto& source = buffer_map.at(buffer);
- // add the detected base indices to each buffer access region of the tensor intrinsic
- Region old_region = buffer_region_map.at(buffer);
- const auto& indices_base = comparator.buffer_indices_.at(source);
+ match_buffer_regions.reserve(intrin_impl->params.size());
+ for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) {
+ const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]);
+ const Buffer& cur = impl2cur.at(impl);
+ const Array<Range>& old_region = impl2region.at(impl);
+ const std::vector<PrimExpr>& indices_base = comparator.buffer_indices_.at(cur);
int offset = static_cast<int>(indices_base.size()) - static_cast<int>(old_region.size());
ICHECK(offset >= 0);
- Region new_region;
- new_region.reserve(source->shape.size());
+ Array<Range> new_region;
+ new_region.reserve(cur->shape.size());
for (int i = 0; i < offset; i++) {
- new_region.push_back(Range::FromMinExtent(indices_base[i], 1));
+ PrimExpr min = indices_base[i];
+ PrimExpr extent = make_const(min.dtype(), 1);
+ new_region.push_back(Range::FromMinExtent(min, extent));
}
for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
- new_region.push_back(Range::FromMinExtent(indices_base[i + offset], old_region[i]->extent));
+ PrimExpr min = indices_base[i + offset];
+ PrimExpr extent = old_region[i]->extent;
+ new_region.push_back(Range::FromMinExtent(min, extent));
}
- match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, new_region)));
+ match_buffer_regions.push_back(MatchBufferRegion(impl, BufferRegion(cur, new_region)));
}
-
// Step 5: Replace the subtree in the original IR with the tensor intrin impl.
- ObjectPtr<BlockNode> new_block_ptr = make_object<BlockNode>(*block_realize->block.get());
- new_block_ptr->body = impl_block->body;
- ICHECK(new_block_ptr->match_buffers.empty());
- new_block_ptr->match_buffers = std::move(match_buffer_regions);
- Block new_block(new_block_ptr);
-
- self->Replace(block_sref, new_block, {{block_realize->block, new_block}});
-
+ {
+ BlockNode* block = block_realize.CopyOnWrite()->block.CopyOnWrite();
+ block->body = impl_block->body;
+ block->match_buffers = std::move(match_buffer_regions);
+ }
+ if (old_block.defined()) {
+ self->Replace(sref, block_realize->block, {{old_block.value(), block_realize->block}});
+ } else {
+ self->Replace(sref, block_realize, {});
+ }
// Step 6: Update the cached flags.
- StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
- self->UpdateScopeBlockInfo(static_cast<const BlockNode*>(scope_root->stmt)->body);
+ StmtSRef result = self->stmt2ref.at(block_realize->block.get());
+ StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
+ self->UpdateScopeBlockInfo(scope_root->StmtAs<BlockNode>()->body);
}
/******** InstructionKind Registration ********/
diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py
index 481421cfdf..6d13281320 100644
--- a/tests/python/unittest/test_tir_schedule_blockize.py
+++ b/tests/python/unittest/test_tir_schedule_blockize.py
@@ -15,12 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
-import sys
-import pytest
import tvm
import tvm.testing
-from tvm.script import tir as T
from tvm import tir
+from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
# fmt: off
@@ -33,177 +31,219 @@ def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
-
-@T.prim_func
-def single_elementwise_blockized1(
- A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
-) -> None:
- with T.block("blockized_B"):
- vio = T.axis.spatial(1, 0)
- vjo = T.axis.spatial(1, 0)
- T.reads(A[0:128, 0:128])
- T.writes(B[0:128, 0:128])
- for i, j in T.grid(128, 128):
- with T.block("B"):
- vi, vj = T.axis.remap("SS", [i, j])
- T.reads(A[vi, vj])
- T.writes(B[vi, vj])
- B[vi, vj] = A[vi, vj] * T.float32(2)
+# fmt: on
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
-@T.prim_func
-def single_elementwise_blockized2(
- A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
-) -> None:
- for i in T.serial(128):
+def test_blockize_outer():
+ @T.prim_func
+ def after_blockize_outer(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ ) -> None:
with T.block("blockized_B"):
- vi = T.axis.spatial(128, i)
+ vio = T.axis.spatial(1, 0)
vjo = T.axis.spatial(1, 0)
- T.reads(A[vi, 0:128])
- T.writes(B[vi, 0:128])
- for j in T.serial(128):
- with T.block("B"):
- vj = T.axis.remap("S", [j])
- T.reads(A[vi, vj])
- T.writes(B[vi, vj])
- B[vi, vj] = A[vi, vj] * T.float32(2)
-
-
-@T.prim_func
-def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
- B = T.alloc_buffer([128, 128], dtype="float32")
- for i, j in T.grid(128, 128):
- with T.block("B"):
- vi, vj = T.axis.remap("SS", [i, j])
- T.reads(A[vi, vj])
- T.writes(B[vi, vj])
- B[vi, vj] = A[vi, vj] * T.float32(2)
- for i, j in T.grid(128, 128):
- with T.block("C"):
- vi, vj = T.axis.remap("SS", [i, j])
- T.reads(B[vi, vj])
- T.writes(C[vi, vj])
- C[vi, vj] = B[vi, vj] + T.float32(1)
-
-
-@T.prim_func
-def two_elementwise_blockized(
- A: T.Buffer[(128, 128), "float32"],
- C: T.Buffer[(128, 128), "float32"]
-) -> None:
- B = T.alloc_buffer([128, 128], dtype="float32")
- for i_0, j_0 in T.grid(8, 8):
- with T.block("blockized_B"):
- vio, vjo = T.axis.remap("SS", [i_0, j_0])
- T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
- T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
- for i_1, j_1 in T.grid(16, 16):
+ for i, j in T.grid(128, 128):
with T.block("B"):
- vi, vj = T.axis.remap("SS", [i_1, j_1])
- T.reads(A[vio * 16 + vi, vjo * 16 + vj])
- T.writes(B[vio * 16 + vi, vjo * 16 + vj])
- B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2)
- with T.block("blockized_C"):
- vio, vjo = T.axis.remap("SS", [i_0, j_0])
- T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
- T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
- for ax0, ax1 in T.grid(16, 16):
- with T.block("C"):
- vi, vj = T.axis.remap("SS", [ax0, ax1])
- T.reads(B[vio * 16 + vi, vjo * 16 + vj])
- T.writes(C[vio * 16 + vi, vjo * 16 + vj])
- C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 16 + vj] + T.float32(1)
-
-
-@T.prim_func
-def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
- for k, i in T.grid(128, 128):
- with T.block("B"):
- vk, vi = T.axis.remap("RS", [k, i])
- with T.init():
- B[vi] = 0.0
- B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
- with T.block("blockized_B"):
- vko = T.axis.R(1, 0)
- vio = T.axis.S(1, 0)
- with T.init():
- for i1 in T.serial(0, 128):
- with T.block("B_init"):
- vi_init = T.axis.S(128, i1)
- B[vi_init] = T.float32(0)
- for i0, i1_1 in T.grid(128, 128):
- with T.block("B"):
- vk, vi = T.axis.remap("RS", [i0, i1_1])
- B[vi] = B[vi] + A[vi, vk]
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
-
-# fmt: off
-# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
-
-def test_blockize_outer():
func = single_elementwise
- # schedule
s = tir.Schedule(func, debug_mask="all")
- B = s.get_block("B")
- x, y = s.get_loops(B)
+ x, _ = s.get_loops(s.get_block("B"))
s.blockize(x)
- print(s.mod['main'].script())
- tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1)
+ tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_outer)
verify_trace_roundtrip(sch=s, mod=func)
def test_blockize_inner():
+ @T.prim_func
+ def after_blockize_inner(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ for i in T.serial(128):
+ with T.block("blockized_B"):
+ vi = T.axis.spatial(128, i)
+ vjo = T.axis.spatial(1, 0)
+ for j in T.serial(128):
+ with T.block("B"):
+ vj = T.axis.remap("S", [j])
+ B[vi, vj] = A[vi, vj] * 2.0
+
func = single_elementwise
- # schedule
s = tir.Schedule(func, debug_mask="all")
- B = s.get_block("B")
- x, y = s.get_loops(B)
+ _, y = s.get_loops(s.get_block("B"))
s.blockize(y)
- tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized2)
+ tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_inner)
verify_trace_roundtrip(sch=s, mod=func)
def test_two_elementwise_blockize_reverse_compute_at():
- func = two_elementwise
+ @T.prim_func
+ def before_blockize_rca(
+ A: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ for i, j in T.grid(8, 8):
+ with T.block("B_o"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ for i_1, j_1 in T.grid(16, 16):
+ with T.block("B"):
+ vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+ T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
+ T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
+ B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
+ for ax0, ax1 in T.grid(16, 16):
+ with T.block("C"):
+ vi = T.axis.spatial(128, i * 16 + ax0)
+ vj = T.axis.spatial(128, j * 16 + ax1)
+ T.reads(B[vi, vj])
+ T.writes(C[vi, vj])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ @T.prim_func
+ def after_blockize_rca(
+ A: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ for i, j in T.grid(8, 8):
+ with T.block("B_o"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ for i_1, j_1 in T.grid(16, 16):
+ with T.block("B"):
+ vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+ T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
+ T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
+ B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
+ with T.block("C_o"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ for ax0, ax1 in T.grid(16, 16):
+ with T.block("C"):
+ vi_i, vj_i = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[vi * 16 + vi_i, vj * 16 + vj_i])
+ T.writes(C[vi * 16 + vi_i, vj * 16 + vj_i])
+ C[vi * 16 + vi_i, vj * 16 + vj_i] = B[vi * 16 + vi_i, vj * 16 + vj_i] + 1.0
+
+ func = before_blockize_rca
s = tir.Schedule(func, debug_mask="all")
- B = s.get_block("B")
- C = s.get_block("C")
- x, y = s.get_loops(B)
- xo, xi = s.split(x, factors=[None, 16])
- yo, yi = s.split(y, factors=[None, 16])
- s.reorder(xo, yo, xi, yi)
- s.blockize(xi)
- s.reverse_compute_at(C, yo)
- s.blockize(s.get_loops(C)[-2])
- tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized)
+ _, _, x, _ = s.get_loops(s.get_block("C"))
+ s.blockize(x)
+ tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_rca)
verify_trace_roundtrip(sch=s, mod=func)
def test_two_elementwise_blockize_compute_at():
- func = two_elementwise
+ @T.prim_func
+ def before_blockize_compute_at(
+ A: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ for i_0, j_0 in T.grid(8, 8):
+ for ax0, ax1 in T.grid(16, 16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i_0 * 16 + ax0)
+ vj = T.axis.spatial(128, j_0 * 16 + ax1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * 2.0
+ with T.block("C_o"):
+ vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
+ T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+ T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+ for i_1, j_1 in T.grid(16, 16):
+ with T.block("C"):
+ vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+ T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+ T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+ C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
+ B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + 1.0
+ )
+
+ @T.prim_func
+ def after_blockize_compute_at(
+ A: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ B = T.alloc_buffer([128, 128], dtype="float32")
+ for i_0, j_0 in T.grid(8, 8):
+ with T.block("B_o"):
+ vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
+ T.reads(A[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+ T.writes(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+ for ax0, ax1 in T.grid(16, 16):
+ with T.block("B"):
+ vi_i, vj_i = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+ T.writes(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+ B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
+ A[vi_o * 16 + vi_i, vj_o * 16 + vj_i] * 2.0
+ )
+ with T.block("C_o"):
+ vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
+ T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+ T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
+ for i_1, j_1 in T.grid(16, 16):
+ with T.block("C"):
+ vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
+ T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+ T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
+ C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
+ B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + 1.0
+ )
+
+ func = before_blockize_compute_at
s = tir.Schedule(func, debug_mask="all")
- B = s.get_block("B")
- C = s.get_block("C")
- x, y = s.get_loops(C)
- xo, xi = s.split(x, factors=[None, 16])
- yo, yi = s.split(y, factors=[None, 16])
- s.reorder(xo, yo, xi, yi)
- s.blockize(xi)
- s.compute_at(B, yo)
- s.blockize(s.get_loops(B)[-2])
- tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized)
+ _, _, x, _ = s.get_loops(s.get_block("B"))
+ s.blockize(x)
+ tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_compute_at)
verify_trace_roundtrip(sch=s, mod=func)
def test_blockize_init_loops():
+ @T.prim_func
+ def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
+ for k, i in T.grid(128, 128):
+ with T.block("B"):
+ vk, vi = T.axis.remap("RS", [k, i])
+ with T.init():
+ B[vi] = 0.0
+ B[vi] = B[vi] + A[vi, vk]
+
+ @T.prim_func
+ def after_rowsum_blockize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128,), "float32"],
+ ) -> None:
+ with T.block("blockized_B"):
+ vko = T.axis.R(1, 0)
+ vio = T.axis.S(1, 0)
+ with T.init():
+ for i1 in T.serial(0, 128):
+ with T.block("B_init"):
+ vi_init = T.axis.S(128, i1)
+ B[vi_init] = T.float32(0)
+ for i0, i1_1 in T.grid(128, 128):
+ with T.block("B"):
+ vk, vi = T.axis.remap("RS", [i0, i1_1])
+ B[vi] = B[vi] + A[vi, vk]
+
s = tir.Schedule(rowsum, debug_mask="all")
k, _ = s.get_loops(s.get_block("B"))
s.blockize(k)
- tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized)
+ tvm.ir.assert_structural_equal(s.mod["main"], after_rowsum_blockize)
verify_trace_roundtrip(sch=s, mod=rowsum)