You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/10/02 22:35:10 UTC
[tvm] branch main updated: [TensorIR][M2a] Decompose-Reduction
(#9041)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6b3fe95 [TensorIR][M2a] Decompose-Reduction (#9041)
6b3fe95 is described below
commit 6b3fe95f3c5ff75f65d15d6fc2797f584c456a5a
Author: Bohan Hou <32...@users.noreply.github.com>
AuthorDate: Sat Oct 2 18:34:41 2021 -0400
[TensorIR][M2a] Decompose-Reduction (#9041)
This PR is part of the TensorIR upstreaming effort (#7527),
which adds the `decompose-reduction` scheduling primitive.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
include/tvm/tir/schedule/schedule.h | 16 +
include/tvm/tir/schedule/state.h | 6 +
python/tvm/tir/schedule/schedule.py | 76 +++
src/tir/schedule/concrete_schedule.cc | 9 +
src/tir/schedule/concrete_schedule.h | 1 +
src/tir/schedule/primitive.h | 17 +
src/tir/schedule/primitive/reduction.cc | 302 +++++++++
src/tir/schedule/schedule.cc | 2 +
src/tir/schedule/state.cc | 147 +++--
src/tir/schedule/traced_schedule.cc | 10 +
src/tir/schedule/traced_schedule.h | 1 +
.../python/unittest/test_tir_schedule_reduction.py | 679 ++++-----------------
...e_reduction.py => test_tir_schedule_rfactor.py} | 0
13 files changed, 664 insertions(+), 602 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 9f48d9a..c4aa1c9 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -366,6 +366,22 @@ class ScheduleNode : public runtime::Object {
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/******** Schedule: Reduction ********/
/*!
+ * \brief Decompose a reduction block into two separate blocks.
+ * a) The init block, which is translated from the init statement of the reduction block;
+ * b) The update block, which is the original block without init statement.
+ *
+ * The init block is inserted right before the given loop.
+ *
+ * The schedule primitive requires:
+ * 1) The input block is a reduction block.
+ * 2) The input loop is an ancestor of the block.
+ * 3) The input loop is not lower than all the loops related to reduce block var.
+ * \param block_rv The reduction block to be decomposed
+ * \param loop_rv The loop right before which the init block is inserted
+ * \return The init block
+ */
+ virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
+ /*!
* \brief Factorize an associative reduction block by the specified loop.
* \details An associative reduction cannot be parallelized directly,
* because it leads to potential race condition during accumulation.
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 7cd1b00..201d78f 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -143,6 +143,12 @@ class ScheduleStateNode : public Object {
/*! \brief Returns the BlockInfo corresponding to the block sref */
TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
/*!
+ * \brief Recalculate the BlockInfo recursively under stmt.
+ * If stmt is a Block itself, we will not reset its affine binding flag unless it doesn't
+ * have block vars, since the affine flag depends on the outer scope of stmt.
+ * \param stmt The statement under which the BlockInfo is recalculated
+ */
+ TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt);
+ /*!
* \brief Get the BlockScope corresponding to the sref of scope root block
* \param scope_root The block sref to be retrieved
* \return The corresponding BlockScope
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 6e27015..09a52d2 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1224,6 +1224,82 @@ class Schedule(Object):
########## Schedule: Reduction ##########
+ def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
+ """Decompose a reduction block into two separate blocks.
+
+ a) The init block, which is translated from the init statement of the reduction block;
+
+ b) The update block, which is the original block without init statement.
+
+ The init block is inserted right before the given loop.
+
+ The schedule primitive requires:
+
+ 1) The input block is a reduction block.
+
+ 2) The input loop is an ancestor of the block.
+
+ 3) The input loop is not lower than all the loops related to reduce block var.
+
+ Parameters
+ ----------
+ block : BlockRV
+ The reduction block to be decomposed
+ loop : LoopRV
+ The loop right before which the init block is inserted.
+
+ Returns
+ -------
+ init_block : BlockRV
+ The init block
+
+ Examples
+ --------
+ Before decompose-reduction, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+ for i, j, k in tir.grid(128, 128, 128):
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+ with tir.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ Create the schedule and do decompose-reduction with specified loop:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_decompose)
+ C = sch.get_block("C")
+ i, j, k = sch.get_loops(C)
+ sch.decompose_reduction(C, i)
+ print(tvm.script.asscript(sch.mod["main"]))
+
+ After applying decompose-reduction, the IR becomes (the init block and its loops
+ get an ``_init`` suffix, and the original block gets an ``_update`` suffix):
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+ for i_init in tir.serial(128):
+ for j_init in tir.serial(128):
+ with tir.block([128, 128], "C_init") as [vi, vj]:
+ C[vi, vj] = 0.0
+ for i, j, k in tir.grid(128, 128, 128):
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "C_update") as [vi, vj, vk]:
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ """
+ return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member
+
def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
"""Factorize an associative reduction block by the specified loop.
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 93eba52..4283907 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -501,6 +501,15 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde
/******** Schedule: Reduction ********/
+// Decompose the reduction block `block_rv` at loop `loop_rv` by resolving both random variables
+// to srefs and delegating to tir::DecomposeReduction on the schedule state.
+// Returns a newly created BlockRV that refers to the resulting init block.
+BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
+ StmtSRef result{nullptr};
+ // NOTE(review): the BEGIN/END macro pair presumably wraps the call so that a ScheduleError is
+ // rendered according to `error_render_level_` -- confirm against the macro definitions.
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
+ TVM_TIR_SCHEDULE_END("decompose-reduction", this->error_render_level_);
+ this->state_->DebugVerify();
+ return CreateRV<BlockRV>(result);
+}
+
BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index c9a9402..1f9aeec 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -115,6 +115,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
+ BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 8d8acd2..057e845 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -234,6 +234,23 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
/******** Schedule: Reduction ********/
/*!
+ * \brief Decompose a reduction block into two separate blocks.
+ * a) The init block, which is translated from the init statement of the reduction block;
+ * b) The update block, which is the original block without init statement.
+ *
+ * The init block is inserted right before the given loop.
+ *
+ * The schedule primitive requires:
+ * 1) The input block is a reduction block.
+ * 2) The input loop is an ancestor of the block.
+ * 3) The input loop is not lower than all the loops related to reduce block var.
+ * \param self The state of the schedule
+ * \param block_sref The reduction block to be decomposed
+ * \param loop_sref The loop right before which the init block is inserted
+ * \return The init block
+ */
+TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+ const StmtSRef& loop_sref);
+/*!
* \brief Factor a reduction block by the specified loop
* \details See python/tvm/tir/schedule/schedule.py
* \param self The state of the schedule
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 677b643..0653f6e 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -21,6 +21,282 @@
namespace tvm {
namespace tir {
+/*!
+ * \brief A mutator that rebuilds a scope root so that the decomposed init body is placed right
+ * before the target loop and the old reduction block is replaced by its init-free copy.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The entry point of the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer mutator(std::move(target_loop), std::move(decomposed_body),
+                                            std::move(old_reduction_block));
+    // Run the mutation first; `new_reduction_block_` is filled in while visiting
+    Block new_scope_root = Downcast<Block>(mutator(std::move(old_scope_root)));
+    return std::make_pair(std::move(new_scope_root), mutator.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  /*! \brief Mutate the loop; the target loop additionally gets the init body placed before it */
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt rewritten = StmtMutator::VisitStmt_(loop);
+    if (loop != target_loop_.get()) {
+      return rewritten;
+    }
+    return SeqStmt({decomposed_body_, rewritten});
+  }
+
+  /*! \brief Replace the old reduction block by a copy named `*_update` whose init is removed */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block != old_reduction_block_.get()) {
+      return StmtMutator::VisitStmt_(block);
+    }
+    ObjectPtr<BlockNode> n = CopyOnWrite(block);
+    n->name_hint = n->name_hint + "_update";
+    n->init = NullOpt;
+    new_reduction_block_ = Block(n);
+    return new_reduction_block_;
+  }
+
+  /*! \brief Visit every child of the sequence and flatten the mutated children */
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> visited;
+    visited.reserve(seq->seq.size());
+    for (const Stmt& child : seq->seq) {
+      visited.push_back(VisitStmt(child));
+    }
+    return SeqStmt::Flatten(visited);
+  }
+
+ private:
+  /*! \brief The loop before which the init body is inserted */
+  For target_loop_;
+  /*! \brief The decomposed init body */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The replacement of the reduction block, set while visiting */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief The error raised by decompose_reduction when the given loop is not an ancestor of the
+ * given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  /*!
+   * \brief Constructor
+   * \param mod The IRModule reported by mod()
+   * \param loop The loop passed to decompose_reduction
+   * \param block The block passed to decompose_reduction
+   */
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    // NOTE(review): {0} and {1} presumably index into LocationsOfInterest(), i.e. the loop and
+    // the block -- confirm against the ScheduleError renderer.
+    return "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+           "ancestor of block {1}.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief The error raised by decompose_reduction when a loop above the given loop is used in the
+ * binding of a reduction block var.
+ */
+class LoopHeightError : public ScheduleError {
+ public:
+  /*!
+   * \brief Throw LoopHeightError if any loop that precedes `loop_sref` in `loops` appears in the
+   * binding of a block var of type kCommReduce.
+   * \param mod The IRModule attached to the error
+   * \param block The reduction block
+   * \param realize The BlockRealize whose iter_values are the bindings of the block vars
+   * \param loops The loops to scan; scanning stops once `loop_sref` is reached
+   * \param loop_sref The loop given to decompose_reduction
+   */
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    const int num_iters = static_cast<int>(block->iter_vars.size());
+    for (int idx = 0; idx < num_iters; ++idx) {
+      const IterVar& block_var = block->iter_vars[idx];
+      const PrimExpr& block_var_binding = realize->iter_values[idx];
+      // Only the bindings of reduction block vars are inspected
+      if (block_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& outer_loop : loops) {
+        // The target loop itself and everything after it are not checked
+        if (outer_loop.same_as(loop_sref)) {
+          break;
+        }
+        // A reduction binding must not use the loop var of a loop preceding the target loop
+        const VarNode* outer_var = outer_loop->StmtAs<ForNode>()->loop_var.get();
+        auto is_outer_var = [outer_var](const VarNode* var) { return var == outer_var; };
+        if (!UsesVar(block_var_binding, is_outer_var)) {
+          continue;
+        }
+        const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+        throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    return "ScheduleError: decompose_reduction expect the loop {0} to be higher than all the loops "
+           "related to reduce block var of block {1}";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+/*!
+ * \brief Rebuild a block predicate, keeping only the `lhs < rhs` terms whose lhs does not use
+ * any of the discarded loop variables.
+ * \param pred The original predicate. Unless is_one(pred) holds, it must be a chain of `<`
+ * comparisons joined by `&&`; any other form triggers an ICHECK failure.
+ * \param discarded_loops The loop variables that are discarded
+ * \return The conjunction of the kept terms, starting from Bool(true)
+ */
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+ if (is_one(pred)) return Bool(true);
+ PrimExpr new_pred = Bool(true);
+ // Tells whether a variable is one of the discarded loop variables
+ auto f = [&](const VarNode* var) { return discarded_loops.count(var); };
+ arith::PVar<PrimExpr> lhs, rhs, rest;
+ for (;;) {
+ // Each iteration peels the right-most `lhs < rhs` term off the conjunction
+ if ((rest && (lhs < rhs)).Match(pred)) {
+ if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+ pred = rest.Eval();
+ } else if ((lhs < rhs).Match(pred)) {
+ // Only a single comparison is left: handle it and stop
+ if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+ break;
+ } else {
+ ICHECK(false) << "Unexpected predicate for reduction block";
+ }
+ }
+ return new_pred;
+}
+
+/*!
+ * \brief Decompose a reduction block into an init block and an update block.
+ * \param self The state of the schedule
+ * \param block_sref The reduction block to be decomposed
+ * \param loop_sref The loop right before which the init block is inserted
+ * \return The sref of the newly created init block
+ * \throws LoopPositionError if loop_sref is not among the loops returned by GetLoops(block_sref)
+ */
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+ const StmtSRef& loop_sref) {
+ /*!
+ * Check
+ * - block is a reduction block
+ * - loop is not lower than all the loops related to reduce block var
+ * Mutate
+ * - generate loops related to data par block vars
+ * - generate corresponding init block and update block
+ */
+ // Condition Checks and Information Collection
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+ // Get the outer loops from high to low
+ Array<StmtSRef> loops = GetLoops(block_sref);
+ const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+ // Cond 0. Check loop_sref is an ancestor of block_sref
+ if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
+ throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+ }
+ // Cond 1. Check block is reduction
+ StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+ /*require_stage_pipeline=*/false,
+ /*require_subtree_compact_dataflow=*/false);
+ CheckReductionBlock(self, block_sref, scope_root_sref);
+ // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
+ LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+ // IR Manipulation
+ ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+ ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+ init_block->name_hint = block->name_hint + "_init";
+ init_realize->iter_values = {};
+ init_realize->block = Block(init_block);
+ // Step 1. Create new block vars and their bindings
+ // Maps an old block var to the new corresponding block var
+ std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+ block_var_map.reserve(block->iter_vars.size());
+ for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+ const IterVar& iter_var = block->iter_vars[i];
+ const PrimExpr& binding = realize->iter_values[i];
+ // Only process data parallel block vars
+ if (iter_var->iter_type != IterVarType::kDataPar) {
+ continue;
+ }
+ // Create a new block var
+ IterVar new_iter_var(/*dom=*/iter_var->dom,
+ /*var=*/iter_var->var.copy_with_suffix(""),
+ /*iter_type=*/iter_var->iter_type,
+ /*thread_tag=*/iter_var->thread_tag);
+ // Add a block var and its binding
+ init_block->iter_vars.push_back(new_iter_var);
+ init_realize->iter_values.push_back(binding);
+ // Add a mapping from old block vars to new block vars
+ block_var_map[iter_var->var] = new_iter_var->var;
+ }
+ // Step 2. After copying block vars, substitute them in init block
+ init_block->body = Substitute(block->init.value(), block_var_map);
+ for (const BufferRegion& write : block->writes) {
+ init_block->writes.push_back(
+ BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+ }
+ // Step 3. Scan loops not higher than the specified loop above the reduction block.
+ // If the loop is used in the init block binding, then it is chosen.
+ // Otherwise, it is discarded.
+ std::unordered_set<const VarNode*> discarded_loops;
+ std::vector<int> chosen_loops;
+ // The scan goes from the last loop in `loops` (the lowest one) up to the given loop
+ for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+ const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+ bool discarded = true;
+ for (const PrimExpr& expr : init_realize->iter_values) {
+ if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+ continue;
+ }
+ // The loop is related to init block bindings;
+ chosen_loops.push_back(i);
+ discarded = false;
+ break;
+ }
+ if (discarded) discarded_loops.insert(loop_var);
+ // Only scan loops not higher than the given loop
+ if (loops[i].same_as(loop_sref)) {
+ break;
+ }
+ }
+ // Step 4. After scanning loops, make a new predicate in the init block realize
+ // We discard predicate that is related to discarded loops
+ init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops);
+ // Step 5. Create new loops above init block
+ // `chosen_loops` lists lower loops first, so wrapping in this order keeps the original nesting
+ std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+ Stmt body = BlockRealize(init_realize);
+ for (int i : chosen_loops) {
+ const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+ // Create a new equivalent to the chosen loop
+ Var old_loop_var = old_loop->loop_var;
+ Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+ loop_var_map[old_loop_var] = new_loop_var;
+ body = For(/*loop_var=*/new_loop_var,
+ /*min=*/old_loop->min,
+ /*extent=*/old_loop->extent,
+ /*kind=*/ForKind::kSerial,
+ /*body=*/body);
+ }
+ // Replace the old loop vars used inside the new loop nest by the `_init` loop vars
+ body = Substitute(body, loop_var_map);
+ // Step 6. Mutate IR
+ const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+ Block new_scope_root{nullptr};
+ Block new_reduction_block{nullptr};
+ std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace(
+ GetRef<Block>(old_scope_root), GetRef<For>(loop), body, GetRef<Block>(block));
+ self->Replace(scope_root_sref, new_scope_root,
+ {{GetRef<Block>(old_scope_root), new_scope_root},
+ {GetRef<Block>(block), new_reduction_block}});
+ self->UpdateScopeBlockInfo(new_scope_root);
+ // The sref of the init block is looked up in `stmt2ref` after the replacement
+ return self->stmt2ref.at(init_block.get());
+}
+
/******** Commutative Reducer ********/
/*!
@@ -958,6 +1234,31 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
/******** InstructionKind Registration ********/
+/*!
+ * \brief The instruction traits of the "DecomposeReduction" instruction: two inputs (the block
+ * and the loop), no attributes and no decisions.
+ */
+struct DecomposeReductionTraits : public UnpackedInstTraits<DecomposeReductionTraits> {
+ static constexpr const char* kName = "DecomposeReduction";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 2;
+ static constexpr size_t kNumAttrs = 0;
+ static constexpr size_t kNumDecisions = 0;
+
+ // Apply the instruction by calling DecomposeReduction on the schedule
+ static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) {
+ return sch->DecomposeReduction(block_rv, loop_rv);
+ }
+
+ // Build the python API call `decompose_reduction` with inputs `block` and `loop`
+ static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv) {
+ PythonAPICall py("decompose_reduction");
+ py.Input("block", block_rv);
+ py.Input("loop", loop_rv);
+ py.SingleOutput(outputs);
+ return py.Str();
+ }
+
+ // The base class template needs access to the private members above
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> {
static constexpr const char* kName = "RFactor";
static constexpr bool kIsPure = false;
@@ -984,6 +1285,7 @@ struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> {
};
TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits);
+TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits);
/******** FFI ********/
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 4262a09..84a37c3 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -155,6 +155,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
.set_body_method<Schedule>(&ScheduleNode::ReverseComputeInline);
/******** (FFI) Reduction ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction")
+ .set_body_method<Schedule>(&ScheduleNode::DecomposeReduction);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
.set_body_method<Schedule>(&ScheduleNode::RFactor);
/******** (FFI) Block annotation ********/
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 4604add..faeb0b9 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -169,34 +169,16 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new
}
/**************** Creation ****************/
-
-/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
-class StateCreator : private StmtVisitor {
+/*! \brief A helper class to update BlockInfo for a ScheduleStateNode */
+class BlockInfoCollector : private StmtVisitor {
public:
- /*!
- * \brief The entry function
- * \param self The schedule state to be completed
- */
- static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
- ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
- ScheduleStateNode* self = n.get();
- // Set `n->mod`
- n->mod = std::move(mod);
- // Set `n->debug_mask`
- n->debug_mask = debug_mask;
- // Set `n->stmt2ref` and `n->block_info`
- StateCreator creator(self);
- for (const auto& kv : n->mod->functions) {
- const BaseFunc& base_func = kv.second;
- if (const auto* func = base_func.as<PrimFuncNode>()) {
- creator.VisitStmt(func->body);
- }
- }
- return n;
+ static void Collect(ScheduleStateNode* self, const Stmt& stmt) {
+ BlockInfoCollector collector(self);
+ collector.VisitStmt(stmt);
}
private:
- explicit StateCreator(ScheduleStateNode* self)
+ explicit BlockInfoCollector(ScheduleStateNode* self)
: self_(self), srefs_{}, block2realize_{}, block_frames_{} {
block_frames_.emplace({});
}
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
* \param stmt A for-loop statement or a block statement
* \return A sref to the stmt
*/
- StmtSRef PushSRef(const StmtNode* stmt) {
- if (srefs_.empty()) {
- srefs_.push_back(
- StmtSRef(stmt,
- /*parent=*/nullptr,
- /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex
- } else {
- StmtSRefNode* parent = srefs_.back().get();
- srefs_.push_back(
- StmtSRef(stmt, parent,
- /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex
- }
- return srefs_.back();
- }
+ void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
- /*! \brief Pop the top of the scope and record it in stmt2ref map */
- StmtSRef PopAndRecordSRef() {
- StmtSRef sref = std::move(srefs_.back());
- self_->stmt2ref[sref->stmt] = sref;
+ /*! \brief Pop the top of the scope */
+ StmtSRef PopSRef() {
+ StmtSRef sref = srefs_.back();
srefs_.pop_back();
return sref;
}
@@ -238,7 +206,10 @@ class StateCreator : private StmtVisitor {
.first->second;
// Set `affine_binding`
if (is_root_block) {
- info.affine_binding = true;
+ // If the block doesn't have outer loops and BlockRealize,
+ // then we set the affine binding flag as true only if the block has no block vars
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root);
+ if (block->iter_vars.empty()) info.affine_binding = true;
} else {
info.affine_binding =
IsAffineBinding(/*realize=*/block2realize_.at(scope_root->stmt),
@@ -385,7 +356,7 @@ class StateCreator : private StmtVisitor {
analyzer_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
PushSRef(loop);
VisitStmt(loop->body);
- PopAndRecordSRef();
+ PopSRef();
}
void VisitStmt_(const BlockRealizeNode* realize) final {
@@ -395,7 +366,7 @@ class StateCreator : private StmtVisitor {
// Recursive visit
PushSRef(block);
VisitStmt(block->body); // `block->init` is not visited
- StmtSRef sref = PopAndRecordSRef();
+ StmtSRef sref = PopSRef();
// Create BlockInfo for the block
MakeBlockInfo(sref);
// Update parent scope
@@ -409,7 +380,7 @@ class StateCreator : private StmtVisitor {
SetSeqIndexInChildren(self_, seq_stmt);
}
- /*! \brief The result ScheduleStateNode */
+ /*! \brief The ScheduleStateNode we are operating on */
ScheduleStateNode* self_;
/*! \brief The stack frame used to indicate the current scope */
std::vector<StmtSRef> srefs_;
@@ -421,6 +392,86 @@ class StateCreator : private StmtVisitor {
arith::Analyzer analyzer_;
};
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param mod The IRModule the schedule state is created from
+   * \param debug_mask The debug mask stored into the schedule state
+   * \return The newly created ScheduleStateNode
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
+    ObjectPtr<ScheduleStateNode> state = make_object<ScheduleStateNode>();
+    ScheduleStateNode* raw_state = state.get();
+    state->mod = std::move(mod);
+    state->debug_mask = debug_mask;
+    // For each PrimFunc: first record the srefs in `stmt2ref`, then collect the BlockInfo
+    StateCreator sref_builder(raw_state);
+    for (const auto& kv : state->mod->functions) {
+      const auto* prim_func = kv.second.as<PrimFuncNode>();
+      if (prim_func == nullptr) {
+        continue;
+      }
+      sref_builder.VisitStmt(prim_func->body);
+      BlockInfoCollector::Collect(raw_state, prim_func->body);
+    }
+    return state;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self) : self_(self) {}
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   */
+  void PushSRef(const StmtNode* stmt) {
+    // The parent is the current top of the stack, or nullptr when the stack is empty
+    StmtSRefNode* parent = nullptr;
+    if (!srefs_.empty()) {
+      parent = srefs_.back().get();
+    }
+    // `seq_index` will be set properly in SetSeqIndex
+    srefs_.push_back(StmtSRef(stmt, parent, /*seq_index=*/-1));
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  void PopAndRecordSRef() {
+    StmtSRef top = std::move(srefs_.back());
+    srefs_.pop_back();
+    self_->stmt2ref[top->stmt] = top;
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    // The sref is created for the block, not for the BlockRealize
+    const BlockNode* block = realize->block.get();
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Visit the children first, then set their `seq_index`
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+};
+
/**************** Constructor ****************/
ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
@@ -1034,6 +1085,10 @@ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
return it->second;
}
+void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) {
+ BlockInfoCollector::Collect(this, stmt);
+}
+
TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) {
const BlockInfo& info = self->GetBlockInfo(block_sref);
return {Bool(info.affine_binding), //
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 6f67959..cc48f2b 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -236,6 +236,16 @@ void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
/******** Schedule: Reduction ********/
+// Run the concrete primitive, then append a "DecomposeReduction" instruction to the trace
+// with the block and loop as inputs and the returned init block as the only output.
+BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
+  BlockRV init_block_rv = ConcreteScheduleNode::DecomposeReduction(block_rv, loop_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("DecomposeReduction");
+  Instruction inst(/*kind=*/kind,
+                   /*inputs=*/{block_rv, loop_rv},
+                   /*attrs=*/{},
+                   /*outputs=*/{init_block_rv});
+  trace_->Append(/*inst=*/inst);
+  return init_block_rv;
+}
+
BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis);
static const InstructionKind& kind = InstructionKind::Get("RFactor");
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index fb89783..fae5ca8 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -82,6 +82,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void ComputeInline(const BlockRV& block_rv) final;
void ReverseComputeInline(const BlockRV& block_rv) final;
/******** Schedule: Reduction ********/
+ BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final;
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py
index d79338a..8460b5c 100644
--- a/tests/python/unittest/test_tir_schedule_reduction.py
+++ b/tests/python/unittest/test_tir_schedule_reduction.py
@@ -28,607 +28,174 @@ from tvm.tir.schedule.testing import verify_trace_roundtrip
@T.prim_func
-def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, [128, 128])
- B = T.match_buffer(b, [128, 128])
- C = T.match_buffer(c, [128, 128])
-
- for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
- with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
- T.bind(vi, i0)
- T.bind(vj, i1)
- T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner))
- T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
- T.writes([C[vi, vj]])
+def rowsum_blockized(a: T.handle, b: T.handle) -> None:
+ B = T.match_buffer(b, [32, 4])
+ A = T.match_buffer(a, [32, 4, 128])
+ for i0, i2_0 in T.grid(32, 16):
+ with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]:
+ T.bind(io, i0)
+ T.bind(ko, i2_0)
with T.init():
- C[vi, vj] = 0.0
- C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
+ for i1 in T.serial(0, 4):
+ with T.block([4], "B_init") as [ii_init]:
+ T.bind(ii_init, i1)
+ B[io, ii_init] = 0.0
+ for i1_1, i2_1 in T.grid(4, 8):
+ with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
+ T.bind(ii, i1_1)
+ T.bind(k, ko * 8 + i2_1)
+ B[io, ii] = B[io, ii] + A[io, ii, k]
@T.prim_func
-def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None:
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
- C_rf = T.alloc_buffer([4, 128, 128])
-
- for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
- with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [
- vi2_inner_inner,
- vi,
- vj,
- vi2_outer,
- vi2_inner_outer,
- ]:
- T.bind(vi2_inner_inner, i2_inner_inner)
- T.bind(vi, i0)
- T.bind(vj, i1)
- T.bind(vi2_outer, i2_outer)
- T.bind(vi2_inner_outer, i2_inner_outer)
- with T.init():
- C_rf[vi2_inner_inner, vi, vj] = 0.0
- C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
- A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)]
- * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)]
- )
-
- for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4):
- with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [
- vi2_inner_inner_1,
- vi_1,
- vj_1,
- ]:
- T.bind(vi2_inner_inner_1, i2_inner_inner_1)
- T.bind(vi_1, i0_1)
- T.bind(vj_1, i1_1)
- with T.init():
- C[vi_1, vj_1] = 0.0
- C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
-
-@T.prim_func
-def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None:
- A = T.match_buffer(a, [256, 256])
- B = T.match_buffer(b, [256, 256])
- D = T.match_buffer(d, [256, 256])
- C = T.alloc_buffer([256, 256])
-
- with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+ with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
with T.init():
C[vi, vj] = 0.0
- C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
-
- with T.block([256, 256], "D") as [vi, vj]:
- D[vi, vj] = C[vi, vj]
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
-def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128, 128))
- C = T.match_buffer(c, (128, 128))
-
- with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
- with T.init():
- C[vi, vj] = 0.0
- C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj]
-
-
-@T.prim_func
-def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
+def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
- D = T.match_buffer(d, [128, 128])
-
- for k, i, j in T.grid(128, 128, 128):
- with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]:
- T.bind(ck, k)
- T.bind(ci, i)
- T.bind(cj, j)
- with T.init():
- C[ci, cj] = 0.0
- C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
- with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]:
- T.bind(dk, k)
- T.bind(di, i)
- T.bind(dj, j)
- with T.init():
- D[di, dj] = 0.0
- D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
+ with T.block([128, 128], "init") as [vi, vj]:
+ C[vi, vj] = 0.0
-@T.prim_func
-def square_sum(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, [16, 256, 256])
- C = T.match_buffer(c, [16])
-
- with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
- with T.init():
- C[b] = 0.0
- C[b] = C[b] + A[b, i, j] * A[b, i, j]
+ with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
-def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, [16, 256, 256])
- C = T.match_buffer(c, [16])
- C_rf = T.alloc_buffer([16, 256])
-
- for i0, i1, i2 in T.grid(16, 256, 256):
- with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
- T.bind(vi2, i2)
- T.bind(b, i0)
- T.bind(i, i1)
- with T.init():
- C_rf[b, vi2] = 0.0
- C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
+def matmul_decompose1(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [32, 4, 128], elem_offset=0, align=128, offset_factor=1)
+ B = T.match_buffer(b, [32, 4], elem_offset=0, align=128, offset_factor=1)
- for i0_1, i2_1 in T.grid(16, 256):
- with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
- T.bind(vi2_1, i2_1)
- T.bind(b_1, i0_1)
- with T.init():
- C[b_1] = 0.0
- C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
-
-
-@T.prim_func
-def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
- A = T.match_buffer(a, [16, 256, 256])
- D = T.match_buffer(d, [16])
- C = T.alloc_buffer([16])
-
- for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
- with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
- T.bind(b, i0)
- T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
- T.bind(j, T.floormod(i1_i2_fused_outer, 256))
- T.reads([C[b], A[b, i, j]])
- T.writes([C[b]])
- with T.init():
- C[b] = 0.0
- C[b] = C[b] + (A[b, i, j] * A[b, i, j])
- for i0_1 in T.serial(0, 16):
- with T.block([16], "D") as [b_1]:
- T.bind(b_1, i0_1)
- T.reads([C[b_1]])
- T.writes([D[b_1]])
- D[b_1] = T.sqrt(C[b_1], dtype="float32")
+ for i0 in T.serial(0, 32):
+ with T.block([32], "blockized_B_init") as [io]:
+ for i1 in T.serial(0, 4):
+ with T.block([4], "B_init") as [ii]:
+ B[io, ii] = T.float32(0)
+ for i0, i2_o in T.grid(32, 16):
+ with T.block([32, T.reduce_axis(0, 16)], "blockized_B_update") as [io, ko]:
+ for i1, i2_i in T.grid(4, 8):
+ with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
+ T.bind(ii, i1)
+ T.bind(k, ((ko * 8) + i2_i))
+ B[io, ii] = B[io, ii] + A[io, ii, k]
@T.prim_func
-def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
- A = T.match_buffer(a, [16, 256, 256])
- D = T.match_buffer(d, [16])
- C = T.alloc_buffer([16])
- C_rf = T.alloc_buffer([1, 16])
-
- for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
- with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [
- vi1_i2_fused_inner,
- b,
- i,
- j,
- ]:
- T.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
- T.bind(b, i0)
- T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
- T.bind(j, T.floormod(i1_i2_fused_outer, 256))
- with T.init():
- C_rf[vi1_i2_fused_inner, b] = 0.0
- C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])
-
- for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
- with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]:
- T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
- T.bind(b_1, i0_1)
- with T.init():
- C[b_1] = 0.0
- C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
+def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None:
+ C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+ B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
+ A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
- for i0_2 in T.serial(0, 16):
- with T.block([16], "D") as [b_2]:
- T.bind(b_2, i0_2)
- D[b_2] = T.sqrt(C[b_2], dtype="float32")
+ for i0, i1 in T.grid(128, 128):
+ with T.block([128, 128], "update_init") as [vi_init, vj_init]:
+ C[vi_init, vj_init] = T.float32(0)
+ for i2 in T.serial(0, 128):
+ with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [vi, vj, vk]:
+ C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
@T.prim_func
-def element_wise(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128, 128))
-
- with T.block([128, 128], "B") as [vi, vj]:
- B[vi, vj] = A[vi, vj] * 2.0
-
-
-@T.prim_func
-def rowsum(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128,))
-
- with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
- with T.init():
- B[vi] = 0.0
- B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128,))
-
- for i, k in T.grid(128, 16):
- with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
- T.bind(vi, i)
- T.bind(vk, T.floordiv(k * k, 2))
- with T.init():
- B[vi] = 0.0
- B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_dominant(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128, 128))
-
- with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
- with T.init():
- B[vi, vk] = 0.0
- B[vi, vk] = B[vi, vk] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128,))
-
- for i in T.serial(0, 128):
- for k in T.parallel(0, 128):
- with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
- T.bind(vi, i)
- T.bind(vk, k)
- with T.init():
- B[vi] = 0.0
- B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128,))
-
- with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
- with T.init():
- B[vi] = 1.0
- B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128,))
-
- with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
- with T.init():
- B[vi] = 0.0
- B[vi] = B[vi] - A[vi, vk]
-
-
-@T.prim_func
-def rowsum_transformed(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, (128, 128))
- B = T.match_buffer(b, (128,))
-
- for io, ii_ko_fused, ki in T.grid(32, 128, 4):
- with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
- T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32))
- T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki)
- with T.init():
- B[vi] = 0.0
- B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, [128])
- B = T.match_buffer(b, [])
-
- with T.block([T.reduce_axis(0, 128)], "B") as [k]:
- with T.init():
- B[()] = 0.0
- B[()] = B[()] + A[k]
-
-
-@T.prim_func
-def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
- A = T.match_buffer(a, [128])
- B = T.match_buffer(b, [])
- B_rf = T.alloc_buffer([128])
-
- with T.block([128], "B_rf") as [vi0]:
- with T.init():
- B_rf[vi0] = 0.0
- B_rf[vi0] = B_rf[vi0] + A[vi0]
-
- with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]:
- with T.init():
- B[()] = 0.0
- B[()] = B[()] + B_rf[vi0_1]
-
-
-@T.prim_func
-def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
- A = T.match_buffer(a, (16, 16, 16))
- C = T.alloc_buffer((16, 16))
- D = T.alloc_buffer((16, 16))
- E = T.alloc_buffer((16, 16))
- F = T.match_buffer(f, (16, 16))
-
- for i in T.serial(0, 16):
- for j1 in T.serial(0, 16):
- for k1o, k1i in T.grid(4, 4):
- with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
- T.bind(ci, i)
- T.bind(cj, j1)
- T.bind(ck, k1o * 4 + k1i)
- with T.init():
- C[ci, cj] = 0.0
- C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
- for k2o, k2i in T.grid(4, 4):
- with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
- T.bind(di, i)
- T.bind(dj, j1)
- T.bind(dk, k2o * 4 + k2i)
- with T.init():
- D[di, dj] = 0.0
- D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
- for j2 in T.serial(0, 16):
- for k3o, k3i in T.grid(4, 4):
- with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
- T.bind(ei, i)
- T.bind(ej, j2)
- T.bind(ek, k3o * 4 + k3i)
- with T.init():
- E[ei, ej] = 0.0
- E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
- for k4o, k4i in T.grid(4, 4):
- with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
- T.bind(fi, i)
- T.bind(fj, j2)
- T.bind(fk, k4o * 4 + k4i)
- with T.init():
- F[fi, fj] = 0.0
- F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
-
+def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128])
+ B = T.match_buffer(b, [128, 128])
+ C = T.match_buffer(c, [128, 128])
-@T.prim_func
-def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
- A = T.match_buffer(a, [16, 16, 16])
- C = T.alloc_buffer([16, 16])
- D = T.alloc_buffer([16, 16])
- E = T.alloc_buffer([16, 16])
- F = T.match_buffer(f, [16, 16])
- C_rf = T.alloc_buffer([16, 16, 4])
-
- for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
- with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
- T.bind(vk1o, k1o)
- T.bind(ci, i)
- T.bind(cj, j1)
- T.bind(vk1i, k1i)
+ for i, k, j in T.grid(128, 128, 128):
+ with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
with T.init():
- C_rf[ci, cj, vk1o] = 0.0
- C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
- for i_1 in T.serial(0, 16):
- for j1_1 in T.serial(0, 16):
- for k1o_1 in T.serial(0, 4):
- with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
- T.bind(vk1o_1, k1o_1)
- T.bind(ci_1, i_1)
- T.bind(cj_1, j1_1)
- with T.init():
- C[ci_1, cj_1] = 0.0
- C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
- for k2o, k2i in T.grid(4, 4):
- with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
- T.bind(di, i_1)
- T.bind(dj, j1_1)
- T.bind(dk, (k2o * 4) + k2i)
- with T.init():
- D[di, dj] = 0.0
- D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
- for j2 in T.serial(0, 16):
- for k3o, k3i in T.grid(4, 4):
- with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
- T.bind(ei, i_1)
- T.bind(ej, j2)
- T.bind(ek, (k3o * 4) + k3i)
- with T.init():
- E[ei, ej] = 0.0
- E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
- for k4o, k4i in T.grid(4, 4):
- with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
- T.bind(fi, i_1)
- T.bind(fj, j2)
- T.bind(fk, (k4o * 4) + k4i)
- with T.init():
- F[fi, fj] = 0.0
- F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
+ C[vi, vj] = 0.0
+ T.bind(vi, i)
+ T.bind(vj, j)
+ T.bind(vk, k)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@T.prim_func
+def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
+ C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+ B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
+ A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+ # body
+ with T.block([], "root"):
+ T.reads([])
+ T.writes([])
+ for i0_0 in T.serial(0, 16):
+ for i0_1_init, i1_init in T.grid(8, 128):
+ with T.block([128, 128], "update_init") as [vi_init, vj_init]:
+ T.bind(vi_init, ((i0_0 * 8) + i0_1_init))
+ T.bind(vj_init, i1_init)
+ C[vi_init, vj_init] = T.float32(0)
+ for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
+ with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [
+ vi,
+ vj,
+ vk,
+ ]:
+ T.where((((i2_0 * 7) + i2_1) < 128))
+ T.bind(vi, ((i0_0 * 8) + i0_1))
+ T.bind(vj, i1)
+ T.bind(vk, ((i2_0 * 7) + i2_1))
+ C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
-def test_reduction_rfactor_matmul():
- s = tir.Schedule(transformed_matmul, debug_mask="all")
- update = s.get_block("update")
- _, _, _, _, kii = s.get_loops(update)
- rf_block = s.rfactor(kii, 0)
- tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
- assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
- assert s.get(update).same_as(s.get(s.get_block("update")))
- verify_trace_roundtrip(s, mod=transformed_matmul)
-
-
-def test_reduction_rfactor_square_sum():
- s = tir.Schedule(square_sum, debug_mask="all")
- C = s.get_block("C")
- _, _, j = s.get_loops(C)
- rf_block = s.rfactor(j, 1)
- tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor)
- assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
- assert s.get(C).same_as(s.get(s.get_block("C")))
- verify_trace_roundtrip(s, mod=square_sum)
-
-
-def test_reduction_rfactor_square_sum_square_root():
- s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all")
- C = s.get_block("C")
- _, _, f_i = s.get_loops(C)
- rf_block = s.rfactor(f_i, 0)
- tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor)
- assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
- assert s.get(C).same_as(s.get(s.get_block("C")))
- verify_trace_roundtrip(s, mod=transformed_square_sum_square_root)
-
-
-def test_reduction_rfactor_loop_multiple_children():
- s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all")
- k, _, _ = s.get_loops(s.get_block("C"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_stage_pipeline():
- s = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all")
- _, _, k = s.get_loops(s.get_block("C"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_reduction_block1():
- s = tir.Schedule(element_wise, debug_mask="all")
- i, _ = s.get_loops(s.get_block("B"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(i, 0)
-
-
-def test_reduction_rfactor_not_reduction_block2():
- s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all")
- _, k = s.get_loops(s.get_block("B"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_reduction_block3():
- s = tir.Schedule(rowsum_not_dominant, debug_mask="all")
- _, k = s.get_loops(s.get_block("B"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_serial_loop():
- s = tir.Schedule(rowsum_not_serial, debug_mask="all")
- _, k = s.get_loops(s.get_block("B"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_same_buffer_access():
- s = tir.Schedule(matmul_not_same_buffer_access, debug_mask="all")
- _, _, k = s.get_loops(s.get_block("C"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_factor_axis_range_fail():
- s = tir.Schedule(transformed_matmul, debug_mask="all")
- _, _, _, _, kii = s.get_loops(s.get_block("update"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(kii, 3)
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(kii, -4)
-
-
-def test_reduction_rfactor_factor_axis_range():
- s = tir.Schedule(transformed_matmul, debug_mask="all")
- update = s.get_block("update")
- _, _, _, _, kii = s.get_loops(update)
- rf_block = s.rfactor(kii, -3)
- tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
- assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
- assert s.get(update).same_as(s.get(s.get_block("update")))
- verify_trace_roundtrip(s, mod=transformed_matmul)
-
-
-def test_reduction_rfactor_wrong_reduce_pattern1():
- s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mask="all")
- _, k = s.get_loops(s.get_block("B"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
+def test_reduction_decompose0():
+ s = tir.Schedule(matmul, debug_mask="all")
+ C = s.get_block("update")
+ i, j, k = s.get_loops(C)
+ s.decompose_reduction(C, i)
+ tvm.ir.assert_structural_equal(matmul_decompose0, s.mod["main"])
+ verify_trace_roundtrip(s, mod=matmul)
-def test_reduction_rfactor_wrong_reduce_pattern2():
- s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mask="all")
- _, k = s.get_loops(s.get_block("B"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k, 0)
+def test_reduction_decompose1():
+ s = tir.Schedule(rowsum_blockized, debug_mask="all")
+ blockized_B = s.get_block("blockized_B")
+ io, ko = s.get_loops(blockized_B)
+ s.decompose_reduction(blockized_B, io)
+ tvm.ir.assert_structural_equal(matmul_decompose1, s.mod["main"])
+ verify_trace_roundtrip(s, mod=rowsum_blockized)
-def test_reduction_rfactor_wrong_loops1():
- s = tir.Schedule(rowsum, debug_mask="all")
- i, _ = s.get_loops(s.get_block("B"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(i, 0)
+def test_reduction_decompose2():
+ s = tir.Schedule(matmul, debug_mask="all")
+ C = s.get_block("update")
+ i, j, k = s.get_loops(C)
+ s.decompose_reduction(C, k)
+ tvm.ir.assert_structural_equal(matmul_decompose2, s.mod["main"])
+ verify_trace_roundtrip(s, mod=matmul)
-def test_reduction_rfactor_wrong_loops2():
- s = tir.Schedule(rowsum_transformed, debug_mask="all")
- _, _, k_i = s.get_loops(s.get_block("B"))
+def test_reduction_decompose3():
+ s = tir.Schedule(matmul_decompose_fail3, debug_mask="all")
+ C = s.get_block("update")
+ i, j, k = s.get_loops(C)
with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k_i, 0)
+ s.decompose_reduction(C, k)
-def test_reduction_rfactor_zero_dim():
- s = tir.Schedule(rowsum_zero_dim, debug_mask="all")
- B = s.get_block("B")
- (k,) = s.get_loops(B)
- rf_block = s.rfactor(k, 0)
- tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor)
- assert s.get(rf_block).same_as(s.get(s.get_block("B_rf")))
- assert s.get(B).same_as(s.get(s.get_block("B")))
- verify_trace_roundtrip(s, mod=rowsum_zero_dim)
-
-
-def test_reduction_rfactor_outermost_loop_multiple_children_fail(): # pylint: disable=invalid-name
- s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
- _, _, k2o, k2i = s.get_loops(s.get_block("D"))
- _, _, k3o, k3i = s.get_loops(s.get_block("E"))
- _, _, k4o, k4i = s.get_loops(s.get_block("F"))
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k2o, 0)
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k2i, 0)
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k3o, 0)
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k3i, 0)
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k4o, 0)
- with pytest.raises(tvm.tir.ScheduleError):
- s.rfactor(k4i, 0)
-
-
-def test_reduction_rfactor_outermost_loop_multiple_children(): # pylint: disable=invalid-name
- s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
- C = s.get_block("C")
- _, _, k1o, _ = s.get_loops(C)
- rf_block = s.rfactor(k1o, 2)
- tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor)
- assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
- assert s.get(C).same_as(s.get(s.get_block("C")))
- verify_trace_roundtrip(s, mod=multiple_reduction_blocks)
+def test_reduction_decompose4():
+ s = tir.Schedule(matmul, debug_mask="all")
+ C = s.get_block("update")
+ i, j, k = s.get_loops(C)
+ io, ii = s.split(i, factors=[16, 8])
+ ko, ki = s.split(k, factors=[19, 7])
+ s.decompose_reduction(C, ii)
+ tvm.ir.assert_structural_equal(matmul_decompose4, s.mod["main"])
+ verify_trace_roundtrip(s, mod=matmul)
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_rfactor.py
similarity index 100%
copy from tests/python/unittest/test_tir_schedule_reduction.py
copy to tests/python/unittest/test_tir_schedule_rfactor.py