You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/09/08 16:36:06 UTC
[tvm] branch main updated: [TIR] Update region min/extent in ReplaceBufferMutator (#12725)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 299ca267e7 [TIR] Update region min/extent in ReplaceBufferMutator (#12725)
299ca267e7 is described below
commit 299ca267e7641b5fa6e78dd131d0574e310f9a13
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 8 09:35:58 2022 -0700
[TIR] Update region min/extent in ReplaceBufferMutator (#12725)
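Prior to this commit, `ReplaceBufferMutator` only checked
`BufferRegionNode::buffer` to determine whether a `BufferRegion` needed to
be replaced, and did not inspect the index expressions in
`BufferRegionNode::region`. As a result, a buffer that appears only inside
a region's indices was never replaced: when updating `T.reads(A[B[i]])`,
`A` would be replaced if needed, but `B` would not.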
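This commit resolves the issue by also visiting the `min` and `extent`
expressions of every range in `BufferRegionNode::region`, so that buffer
usages inside those expressions are replaced as well. A new `BufferRegion`
is constructed only when either the buffer or the region actually changes.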
---
src/tir/schedule/transform.cc | 27 +++++++++++++++++++---
.../test_tir_schedule_set_axis_separator.py | 24 +++++++++++++++++++
2 files changed, 48 insertions(+), 3 deletions(-)
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 1ebaf202d4..c11fa656d6 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -138,9 +138,30 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) {
return this->VisitMatchBufferRegion(match_buffer);
};
auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
- auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
- return it == buffer_var_map_.end() ? buffer_region
- : BufferRegion(it->second, buffer_region->region);
+ auto region = MutateArray(buffer_region->region, [this](const Range& range) {
+ PrimExpr min = VisitExpr(range->min);
+ PrimExpr extent = VisitExpr(range->extent);
+ if (min.same_as(range->min) && extent.same_as(range->extent)) {
+ return range;
+ } else {
+ return Range::FromMinExtent(min, extent);
+ }
+ });
+
+ Buffer buf = [&]() {
+ auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+ if (it == buffer_var_map_.end()) {
+ return buffer_region->buffer;
+ } else {
+ return it->second;
+ }
+ }();
+
+ if (buf.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
+ return buffer_region;
+ } else {
+ return BufferRegion(buf, region);
+ }
};
auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
auto it = buffer_var_map_.find(buffer->data.get());
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
index 9502da1829..b432fbb610 100644
--- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -154,6 +154,30 @@ def test_set_axis_separator_subregion(use_sugared_transform):
tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
+class TestIndexedLookup(tvm.testing.CompareBeforeAfter):
+ def transform(self):
+ def func(mod):
+ sch = tir.Schedule(mod)
+ sch.set_axis_separator('block', 'B', [1])
+ return sch.mod
+ return func
+
+ @T.prim_func
+ def before():
+ A = T.alloc_buffer([4,4], dtype="int32")
+ B = T.alloc_buffer([1,1], dtype="int32")
+ for j in T.serial(4):
+ with T.block('block'):
+ A[B[0,0],j] = 0
+
+ @T.prim_func
+ def expected():
+ A = T.alloc_buffer([4,4], dtype="int32")
+ B = T.alloc_buffer([1,1], dtype="int32", axis_separators=[1])
+ for j in T.serial(4):
+ with T.block('block'):
+ A[B[0,0],j] = 0
+
if __name__ == "__main__":
tvm.testing.main()