You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/12/15 05:32:04 UTC
[tvm] branch main updated: [TIR][Schedule] Analysis functions to check if compute_inline and com… (#9743)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ecc2e56 [TIR][Schedule] Analysis functions to check if compute_inline and com… (#9743)
ecc2e56 is described below
commit ecc2e563df1a0b1d7e9d712bce90ee94948c3848
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Dec 15 00:31:45 2021 -0500
[TIR][Schedule] Analysis functions to check if compute_inline and com… (#9743)
* [TIR][Schedule] Analysis functions to check if compute_inline and compute_at are allowed
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
* Address comments
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
---
src/tir/schedule/analysis.h | 41 ++++++++++++++
src/tir/schedule/primitive/compute_at.cc | 46 ++++++++++++---
src/tir/schedule/primitive/compute_inline.cc | 66 +++++++++++++++++++---
.../unittest/test_tir_schedule_compute_inline.py | 29 ++++++++++
4 files changed, 168 insertions(+), 14 deletions(-)
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 42e0e00..82f4afa 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -393,6 +393,47 @@ std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters()
bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);
+/******** Misc ********/
+
+/*!
+ * \brief Checks if a block could be successfully computed inline into its consumer
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block could be successfully computed inline
+ */
+bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if a block could be successfully computed inline into its producer
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block could be successfully computed inline
+ */
+bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if a producer block could be successfully computed at the specific loop.
+ * \param self The schedule state
+ * \param block_sref The block to be moved
+ * \param loop_sref The loop where the block to be moved to
+ * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \return A boolean indicating whether the block could be successfully computed at the specific loop
+ */
+bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
+ bool preserve_unit_loops);
+
+/*!
+ * \brief Checks if a consumer block could be successfully computed at the specific loop.
+ * \param self The schedule state
+ * \param block_sref The block to be moved
+ * \param loop_sref The loop where the block to be moved to
+ * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \return A boolean indicating whether the block could be successfully reverse computed at the
+ * specific loop
+ */
+bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
+ const StmtSRef& loop_sref, bool preserve_unit_loops);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 0dae50a..00886e8 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -451,7 +451,8 @@ void CalculateProvidedRequiredRegions(
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
- const StmtSRef& loop_sref, bool preserve_unit_loops) {
+ const StmtSRef& loop_sref, bool preserve_unit_loops,
+ arith::Analyzer* analyzer, bool check_only = false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
@@ -463,11 +464,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
- arith::Analyzer analyzer;
// Check condition 3): `block` and `loop` are under the same scope,
// and `loop` is not the ancestor of `block`
NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
- &analyzer);
+ analyzer);
// Check condition 4): `block` is not an output block
if (is_compute_at) {
CheckNotOutputBlock(self, block_sref, scope_root_sref);
@@ -501,29 +501,61 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
/*provided_regions=*/std::move(provided_regions),
/*required_regions=*/std::move(required_regions),
- /*analyzer=*/&analyzer);
+ /*analyzer=*/analyzer);
// Step 6. Create the new scope according to the iteration domain
reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
/*preserve_unit_loops=*/preserve_unit_loops);
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
+
// Step 7. Do the actual replacement
+ if (check_only) {
+ return;
+ }
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
// Step 8. Update the cached flags
BlockInfo& block_info = self->block_info[block_sref];
block_info.affine_binding = IsAffineBinding(
/*realize=*/reconstructor.new_block_realize_,
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
- /*analyzer=*/&analyzer);
+ /*analyzer=*/analyzer);
}
void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
- ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops);
+ arith::Analyzer analyzer;
+ ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
+ &analyzer);
}
void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
- ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops);
+ arith::Analyzer analyzer;
+ ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
+ &analyzer);
+}
+
+bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
+ bool preserve_unit_loops) {
+ arith::Analyzer analyzer;
+ try {
+ ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
+ &analyzer, true);
+ } catch (const tvm::runtime::Error& e) {
+ return false;
+ }
+ return true;
+}
+
+bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
+ const StmtSRef& loop_sref, bool preserve_unit_loops) {
+ arith::Analyzer analyzer;
+ try {
+ ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
+ &analyzer, true);
+ } catch (const tvm::runtime::Error& e) {
+ return false;
+ }
+ return true;
}
/******** InstructionKind Registration ********/
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 12ae021..fe2c679 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -60,11 +60,27 @@ class NotSingleReadWriteBuffer : public ScheduleError {
bool is_read_;
Block block_;
- static Buffer GetSingleRead(const ScheduleState& self, const Block& block) {
- if (block->reads.size() != 1) {
+ static Buffer GetSingleRead(const ScheduleState& self, const Block& block,
+ const StmtSRef& scope_root_sref) {
+ const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>&
+ buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers;
+ const BufferNode* read_buffer = nullptr;
+ for (const BufferRegion& read_region : block->reads) {
+ const BufferNode* buffer = read_region->buffer.get();
+ if (buffer == read_buffer) {
+ continue;
+ }
+ if (buffer_writers.count(GetRef<Buffer>(buffer)) > 0) {
+ if (read_buffer != nullptr) {
+ throw NotSingleReadWriteBuffer(self->mod, true, block);
+ }
+ read_buffer = buffer;
+ }
+ }
+ if (read_buffer == nullptr) {
throw NotSingleReadWriteBuffer(self->mod, true, block);
}
- return block->reads[0]->buffer;
+ return GetRef<Buffer>(read_buffer);
}
static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) {
@@ -167,7 +183,7 @@ class OpaqueAccessError : public ScheduleError {
* \brief The base class of the inliner, which handles:
* 1) Substitute a subtree with the specific block being inlined
* 2) Update the block signature to reflect the changes of read/write/allocated buffers
- * 3) Maintain a list of index variables and their substition of the buffer being inlined
+ * 3) Maintain a list of index variables and their substitution of the buffer being inlined
*/
class BaseInliner : public StmtExprMutator {
protected:
@@ -526,7 +542,8 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr producer_rhs_{nullptr};
};
-void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
+void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
+ bool check_only = false) {
const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
Block producer_block = GetRef<Block>(_producer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
@@ -535,6 +552,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
// Step 2. Check completeness
+ CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
// Step 3. Analyze the block body
ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref);
@@ -550,17 +568,35 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
throw OpaqueAccessError(self->mod, scope_root_sref);
}
// Step 6. Do the real mutation on the AST and the sref tree in the schedule state
+ if (check_only) {
+ return;
+ }
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}
-void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
+void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
+ ComputeInlineImpl(self, producer_block_sref);
+}
+
+bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) {
+ try {
+ ComputeInlineImpl(self, producer_block_sref, true);
+ } catch (const tvm::runtime::Error& e) {
+ return false;
+ }
+ return true;
+}
+
+void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref,
+ bool check_only = false) {
const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
Block consumer_block = GetRef<Block>(_consumer_block);
- Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
+ Buffer inlined_buffer =
+ NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
// Step 2. Check completeness
CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
// Step 3. Check if the consumer has a single complete producer
@@ -579,9 +615,25 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
throw OpaqueAccessError(self->mod, scope_root_sref);
}
// Step 7. Do the real mutation on the AST and the sref tree in the schedule state
+ if (check_only) {
+ return;
+ }
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}
+bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
+ try {
+ ReverseComputeInlineImpl(self, block_sref, true);
+ } catch (const tvm::runtime::Error& e) {
+ return false;
+ }
+ return true;
+}
+
+void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
+ ReverseComputeInlineImpl(self, consumer_block_sref);
+}
+
/******** InstructionKind Registration ********/
struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index a078c0e..5cc36c0 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -329,6 +329,28 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
B[vi] = A_cache[vi] * 2.0 + 1.0
+@T.prim_func
+def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None:
+ A = T.match_buffer(var_A, [512, 512], dtype="float32")
+ B = T.match_buffer(var_B, [512, 512], dtype="float32")
+ compute = T.match_buffer(var_compute, [512, 512], dtype="float32")
+ C = T.alloc_buffer([512, 512], dtype="float32")
+ for i0, i1, i2 in T.grid(512, 512, 512):
+ with T.block("C"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads([C[i, j], A[i, k], B[k, j]])
+ T.writes([C[i, j]])
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+ for i0, i1 in T.grid(512, 512):
+ with T.block("compute"):
+ i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+ T.reads([C[i0_1, i1_1]])
+ T.writes([compute[i0_1, i1_1]])
+ compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+
+
# pylint: enable=no-member,invalid-name,unused-variable
@@ -458,6 +480,13 @@ def test_buffer_matched():
sch.compute_inline(block_b)
+def test_output_block():
+ sch = tir.Schedule(matmul_relu, debug_mask="all")
+ block = sch.get_block("compute")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.compute_inline(block)
+
+
def test_compute_inline_predicate():
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
block_b = sch.get_block("B")