You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/06 17:31:46 UTC

[tvm] branch main updated: [TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory (#12655)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 744649e53b [TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory (#12655)
744649e53b is described below

commit 744649e53bd32b53eb53020a111479facff3b88a
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Sep 6 10:31:39 2022 -0700

    [TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory (#12655)
    
    * [TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory
    
    Previously, shared buffer use was entirely disabled for non-flat
    memory, since the existing checks for shared memory assume flat 1-d
    spaces.  This was enforced in `FindAlloc` and validated in
    `PrepareNewAlloc`.  However, the validation in `PrepareNewAlloc`
    could still trigger when the buffer sharing was due to an in-place
    operation rather than to the `FindAlloc` function.
    
    In-place operations do not require N-d packing, nor do they introduce
    ambiguity in how different code generators may interpret non-flat
    physical indices.  Therefore, this commit relaxes the validation in
    `PrepareNewAlloc`, allowing buffer reuse of non-flat buffers for
    in-place operations.
    
    * Update new StorageRewrite with correct allocate/buffer_decl usage
---
 src/tir/transforms/storage_rewrite.cc              |  20 +++-
 .../unittest/test_tir_transform_storage_rewrite.py | 116 ++++++++++++++++++++-
 2 files changed, 132 insertions(+), 4 deletions(-)

diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 177017f9a2..67972ce672 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -655,7 +655,25 @@ class StoragePlanRewriter : public StmtExprMutator {
           }
         }
 
-        if (e->allocs.size() == 1) {
+        bool all_allocs_identical = std::all_of(
+            e->allocs.begin() + 1, e->allocs.end(), [&](const AllocateNode* op) -> bool {
+              const AllocateNode* first = *e->allocs.begin();
+              if (op->dtype != first->dtype) {
+                return false;
+              }
+              if (op->extents.size() != first->extents.size()) {
+                return false;
+              }
+              ExprDeepEqual expr_equal;
+              for (size_t i = 0; i < op->extents.size(); i++) {
+                if (!expr_equal(op->extents[i], first->extents[i])) {
+                  return false;
+                }
+              }
+              return true;
+            });
+
+        if (all_allocs_identical) {
           // simply use the original allocation.
           e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
                                   e->allocs[0]->condition, Evaluate(0));
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 581afef889..533a835e0f 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -673,7 +673,11 @@ def test_access_in_let_value():
     tvm.ir.assert_structural_equal(mod["main"], func_rewritten)
 
 
-class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.StorageRewrite()
+
+
+class TestLetBufferRewrite(BaseCompare):
     """StorageRewrite replaces the bound var of backing allocations
 
     If StorageRewrite replaces the backing variable of an array, such
@@ -684,8 +688,6 @@ class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
     handled.
     """
 
-    transform = tvm.tir.transform.StorageRewrite()
-
     def before() -> None:
         A_data: T.Ptr[T.int32] = T.call_extern("dummy_func", dtype="handle")
         A = T.buffer_decl([8], "int32", data=A_data)
@@ -697,5 +699,113 @@ class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
         A[0] = T.broadcast(42, 8)
 
 
+class TestRewriteInPlaceUseOfNonFlatBuffer(BaseCompare):
+    """A non-flat buffer may be re-used for in-place operations"""
+
+    def before(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
+        B_data = T.allocate(
+            [16, 16],
+            dtype="float32",
+            scope="global",
+        )
+        B = T.buffer_decl(
+            [16, 16],
+            dtype="float32",
+            axis_separators=[1],
+            data=B_data,
+        )
+        C_data = T.allocate(
+            [16, 16],
+            dtype="float32",
+            scope="global",
+        )
+        C = T.buffer_decl(
+            [16, 16],
+            dtype="float32",
+            axis_separators=[1],
+            data=C_data,
+        )
+
+        for i, j in T.grid(16, 16):
+            B[i, j] = A[i, j]
+
+        for i, j in T.grid(16, 16):
+            C[i, j] = 2.0 * B[i, j]
+
+        for i, j in T.grid(16, 16):
+            D[i, j] = C[i, j]
+
+    def expected(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
+        B_data = T.allocate(
+            [16, 16],
+            dtype="float32",
+            scope="global",
+        )
+        B = T.buffer_decl([16, 16], dtype="float32", axis_separators=[1], data=B_data)
+        C = T.buffer_decl(
+            [16, 16],
+            dtype="float32",
+            axis_separators=[1],
+            data=B.data,
+        )
+
+        for i, j in T.grid(16, 16):
+            B[i, j] = A[i, j]
+
+        for i, j in T.grid(16, 16):
+            C[i, j] = 2.0 * B[i, j]
+
+        for i, j in T.grid(16, 16):
+            D[i, j] = C[i, j]
+
+
+class TestNoRewriteOfSharedNonFlatBuffer(BaseCompare):
+    """In general, sharing of non-flat buffers isn't supported
+
+    The current packing algorithms in StorageRewrite assume a flat
+    memory space, and do not support packing of N-d buffers.  For
+    buffers with axis separators, normal buffer sharing should be
+    disabled.
+
+    Like TestRewriteInPlaceUseOfNonFlatBuffer, except that B and C do
+    not have matching shapes.
+    """
+
+    def before(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
+        B_data = T.allocate(
+            [16, 16],
+            dtype="float32",
+            scope="global",
+        )
+        B = T.buffer_decl(
+            [16, 16],
+            dtype="float32",
+            axis_separators=[1],
+            data=B_data,
+        )
+        C_data = T.allocate(
+            [20, 20],
+            dtype="float32",
+            scope="global",
+        )
+        C = T.buffer_decl(
+            [20, 20],
+            dtype="float32",
+            axis_separators=[1],
+            data=C_data,
+        )
+
+        for i, j in T.grid(16, 16):
+            B[i, j] = A[i, j]
+
+        for i, j in T.grid(16, 16):
+            C[i, j] = 2.0 * B[i, j]
+
+        for i, j in T.grid(16, 16):
+            D[i, j] = C[i, j]
+
+    expected = before
+
+
 if __name__ == "__main__":
     tvm.testing.main()