You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/01/18 06:46:52 UTC
[tvm] 01/01: Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch revert-9880-encode_conditional_accesses_in_read_write_annotations
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 2e14b8565515b3940f1624de61fc23d3446c4cbe
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Jan 17 22:46:07 2022 -0800
Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"
This reverts commit 6f6fc68f5a028a92607b2907b9e4144543686639.
---
src/tir/analysis/block_access_region_detector.cc | 29 ++--------
src/tir/transforms/compact_buffer_region.cc | 10 ++--
src/tir/transforms/ir_utils.cc | 62 +++++---------------
src/tir/transforms/ir_utils.h | 18 +++---
.../test_tir_analysis_get_block_access_region.py | 66 ----------------------
.../test_tir_transform_compact_buffer_region.py | 1 -
6 files changed, 30 insertions(+), 156 deletions(-)
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index 07dcace..776538a 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -56,8 +56,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
private:
/*! \brief Iteration range for loop_vars */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
- /*! \brief Extra iteration range hint for free vars */
- std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief The buffers that the current block reads */
std::vector<Buffer> read_buffers_;
/*! \brief The buffers that the current block writes */
@@ -98,9 +96,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
/*! \brief Helper function to update a opaque access. */
void UpdateOpaque(const Var& buffer_var);
- /*! \brief Helper function to relax the buffer indices */
- arith::IntSet RelaxAccessIndex(const PrimExpr& index);
-
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
@@ -145,22 +140,10 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
ExprVisitor::VisitExpr_(op);
}
-arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) {
- arith::IntSet relaxed = arith::EvalSet(index, dom_map_);
- if (!hint_map_.empty()) {
- // take non-relaxed var bound hints into considerations
- // eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope
- // then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf)
- arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_);
- relaxed = arith::Intersect({relaxed, hint_bound});
- }
- return relaxed;
-}
-
void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
- relaxed_region.push_back(RelaxAccessIndex(index));
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
ExprVisitor::VisitExpr_(op);
@@ -177,12 +160,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
VisitExpr(op->condition);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
+ With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
+ With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
}
}
@@ -192,12 +175,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
VisitExpr(op->args[0]);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
+ With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
+ With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
@@ -213,7 +196,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
- relaxed_region.push_back(RelaxAccessIndex(index));
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
StmtVisitor::VisitStmt_(op);
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 20ddd7f..07f9778 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->condition);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
+ With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
+ With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
}
}
@@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->args[0]);
{
// Visit then branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
+ With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
- With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
+ With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
@@ -282,8 +282,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
- /*! \brief Extra map from free vars to their iter range hints. */
- std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief The analyzer aware of loop domains. */
arith::Analyzer dom_analyzer_;
/*! \brief The map from Buffer to it's relaxed access set. */
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index bc2f7ad..2423b09 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -300,18 +300,11 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
Array<Var> vars = Array<Var>(var_set.begin(), var_set.end());
Map<Var, Range> ranges;
for (const Var& v : vars) {
- arith::IntSet dom;
- auto relax_it = relax_map_->find(v.get());
- if (relax_it != relax_map_->end()) {
- dom = relax_it->second;
- } else {
- auto hint_it = hint_map_->find(v.get());
- if (hint_it != hint_map_->end()) {
- dom = hint_it->second;
- }
- }
- if (dom.defined()) {
- ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
+ auto it = dom_map_->find(v.get());
+ if (it != dom_map_->end()) {
+ const auto& int_set = it->second;
+ ranges.Set(v, Range::FromMinExtent(int_set.min(),
+ analyzer.Simplify(int_set.max() - int_set.min() + 1)));
}
}
// solve constraints
@@ -321,53 +314,24 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
}
ConditionalBoundsContext::ConditionalBoundsContext(
- const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
- std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
- : condition_(condition),
- relax_map_(relax_map),
- hint_map_(hint_map),
- is_true_branch_(is_true_branch) {}
+ const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
+ bool is_true_branch)
+ : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {}
void ConditionalBoundsContext::EnterWithScope() {
for (const auto& p : GetVarBoundsFromCondition()) {
const auto* var = p.first.get();
- arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
- auto relax_it = relax_map_->find(var);
- if (relax_it != relax_map_->end()) {
- // this is a bound for relaxed var
- origin_map_.emplace(var, relax_it->second);
- relax_it->second = arith::Intersect({relax_it->second, new_dom});
- } else {
- // this is a bound for free var
- auto hint_it = hint_map_->find(var);
- if (hint_it != hint_map_->end()) {
- origin_map_.emplace(var, hint_it->second);
- hint_it->second = arith::Intersect({hint_it->second, new_dom});
- } else {
- origin_map_.emplace(var, arith::IntSet::Nothing());
- hint_map_->insert(hint_it, {var, new_dom});
- }
+ auto it = dom_map_->find(var);
+ if (it != dom_map_->end()) {
+ origin_map_.emplace(var, it->second);
+ it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)});
}
}
}
void ConditionalBoundsContext::ExitWithScope() {
for (const auto& p : origin_map_) {
- const auto* var = p.first;
- auto relax_it = relax_map_->find(var);
- if (relax_it != relax_map_->end()) {
- // recover bound for relaxed var
- relax_it->second = p.second;
- } else {
- // recover bound for free var
- auto hint_it = hint_map_->find(var);
- ICHECK(hint_it != hint_map_->end());
- if (p.second.IsNothing()) {
- hint_map_->erase(hint_it);
- } else {
- hint_it->second = p.second;
- }
- }
+ (*dom_map_)[p.first] = p.second;
}
}
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index da52a82..7b1d34c 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -231,9 +231,9 @@ Bool IsFromLegacyTESchedule(PrimFunc f);
*\brief Context helper to update domain map within conditional scope.
*
* Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is
- * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step
- *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition,
- *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20]
+*[0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &dom_map, true)` steps into a scope
+*where dom_map[i] is [0, 8], and `With<ConditionalBoundsContext> ctx(condition, &dom_map, false)`
+*steps into a scope where dom_map[i] is [9, 20].
*/
class ConditionalBoundsContext {
private:
@@ -241,13 +241,11 @@ class ConditionalBoundsContext {
/*!
* \brief Construct a condition bounds context.
* \param condition The condition holds on true branch.
- * \param relax_map The domain map for relaxed vars to update.
- * \param hint_map The domain map for free vars to update.
+ * \param dom_map The global domain map to be updated.
* \param is_true_branch Whether step into the branch where condition bounds holds.
*/
ConditionalBoundsContext(const PrimExpr& condition,
- std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
- std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
+ std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
bool is_true_branch);
void EnterWithScope();
void ExitWithScope();
@@ -257,10 +255,8 @@ class ConditionalBoundsContext {
/*! \brief the condition holds on true branch. */
const PrimExpr& condition_;
- /*! \brief domain map for relaxed vars to update */
- std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
- /*! \brief domain map for free vars to update */
- std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
+ /*! \brief global domain map to be updated */
+ std::unordered_map<const VarNode*, arith::IntSet>* dom_map_;
/*! \brief whether is on true branch */
bool is_true_branch_;
/*! \brief used to record and restore original var bounds */
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 5403754..e508fbb 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -130,41 +130,6 @@ def access_in_branch_func() -> None:
B[i] = A[i - 1]
-@T.prim_func
-def access_of_padding_pattern() -> None:
- X = T.alloc_buffer([28, 28])
- X_pad = T.alloc_buffer([32, 32])
- Y = T.alloc_buffer([28, 28])
- for i, j in T.grid(32, 32):
- with T.block("padding"):
- vi, vj = T.axis.remap("SS", [i, j])
- T.reads(
- [
- X[
- T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
- T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
- ]
- ]
- )
- T.writes([X_pad[vi, vj]])
- X_pad[vi, vj] = T.if_then_else(
- 2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
- )
- with T.block("padding_reverse"):
- vi, vj = T.axis.remap("SS", [i, j])
- T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]])
- T.writes(
- [
- Y[
- T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
- T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
- ]
- ]
- )
- if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
- Y[vi - 2, vj - 2] = X_pad[vi, vj]
-
-
def test_block_access_region_detector():
block = func.body.block.body.block
alloc_buffers = func.body.block.alloc_buffers
@@ -255,36 +220,6 @@ def test_access_in_branch_func():
tvm.ir.assert_structural_equal(ret0[1], ret1[1])
-def test_access_of_padding_pattern():
- s = tvm.tir.schedule.Schedule(access_of_padding_pattern)
- alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers
- buffer_var_map = {buf.data: buf for buf in alloc_buffers}
-
- def do_compare_buffer_region(region, expect):
- assert region.buffer == expect.buffer
- analyzer = tvm.arith.Analyzer()
- for k, rng in enumerate(region.region):
- tvm.ir.assert_structural_equal(
- analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min)
- )
- tvm.ir.assert_structural_equal(
- analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent)
- )
-
- def do_check_block(block_name):
- block = s.get_sref(s.get_block(block_name)).stmt
- expect_reads = block.reads
- expect_writes = block.writes
- ret = tir.analysis.get_block_access_region(block, buffer_var_map)
- for i, read in enumerate(ret[0]):
- do_compare_buffer_region(read, expect_reads[i])
- for i, write in enumerate(ret[1]):
- do_compare_buffer_region(write, expect_writes[i])
-
- do_check_block("padding")
- do_check_block("padding_reverse")
-
-
if __name__ == "__main__":
test_block_access_region_detector()
test_opaque_block()
@@ -292,4 +227,3 @@ if __name__ == "__main__":
test_match_buffer()
test_access_in_if_then_else_func()
test_access_in_branch_func()
- test_access_of_padding_pattern()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index 9b84485..57c87e5 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -24,7 +24,6 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.CompactBufferAllocation()(mod)
mod = tvm.tir.transform.Simplify()(mod)
- transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"]
tvm.ir.assert_structural_equal(mod["main"], transformed)