You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/09/08 16:36:06 UTC

[tvm] branch main updated: [TIR] Update region min/extent in ReplaceBufferMutator (#12725)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 299ca267e7 [TIR] Update region min/extent in ReplaceBufferMutator (#12725)
299ca267e7 is described below

commit 299ca267e7641b5fa6e78dd131d0574e310f9a13
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 8 09:35:58 2022 -0700

    [TIR] Update region min/extent in ReplaceBufferMutator (#12725)
    
    Prior to this commit, `ReplaceBufferMutator` only checked
    `BufferRegionNode::buffer` to determine whether a `BufferRegion`
    needed to be replaced, and did not check `BufferRegionNode::region`
    for buffers used inside its expressions.  As a result, updating
    `T.reads(A[B[i]])` would fail to replace `B`.
    
    This commit also visits the min and extent of every range in
    `BufferRegionNode::region`, so that buffers used there are replaced
    as well, which resolves this issue.
---
 src/tir/schedule/transform.cc                      | 27 +++++++++++++++++++---
 .../test_tir_schedule_set_axis_separator.py        | 24 +++++++++++++++++++
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 1ebaf202d4..c11fa656d6 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -138,9 +138,30 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) {
     return this->VisitMatchBufferRegion(match_buffer);
   };
   auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
-    auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
-    return it == buffer_var_map_.end() ? buffer_region
-                                       : BufferRegion(it->second, buffer_region->region);
+    auto region = MutateArray(buffer_region->region, [this](const Range& range) {
+      PrimExpr min = VisitExpr(range->min);
+      PrimExpr extent = VisitExpr(range->extent);
+      if (min.same_as(range->min) && extent.same_as(range->extent)) {
+        return range;
+      } else {
+        return Range::FromMinExtent(min, extent);
+      }
+    });
+
+    Buffer buf = [&]() {
+      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+      if (it == buffer_var_map_.end()) {
+        return buffer_region->buffer;
+      } else {
+        return it->second;
+      }
+    }();
+
+    if (buf.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
+      return buffer_region;
+    } else {
+      return BufferRegion(buf, region);
+    }
   };
   auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
     auto it = buffer_var_map_.find(buffer->data.get());
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
index 9502da1829..b432fbb610 100644
--- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -154,6 +154,30 @@ def test_set_axis_separator_subregion(use_sugared_transform):
     tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"])
     verify_trace_roundtrip(sch=s, mod=func)
 
+class TestIndexedLookup(tvm.testing.CompareBeforeAfter):
+    def transform(self):
+        def func(mod):
+            sch = tir.Schedule(mod)
+            sch.set_axis_separator('block', 'B', [1])
+            return sch.mod
+        return func
+
+    @T.prim_func
+    def before():
+        A = T.alloc_buffer([4,4], dtype="int32")
+        B = T.alloc_buffer([1,1], dtype="int32")
+        for j in T.serial(4):
+            with T.block('block'):
+                A[B[0,0],j] = 0
+
+    @T.prim_func
+    def expected():
+        A = T.alloc_buffer([4,4], dtype="int32")
+        B = T.alloc_buffer([1,1], dtype="int32", axis_separators=[1])
+        for j in T.serial(4):
+            with T.block('block'):
+                A[B[0,0],j] = 0
+
 
 if __name__ == "__main__":
     tvm.testing.main()