You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/06 17:31:46 UTC
[tvm] branch main updated: [TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory (#12655)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 744649e53b [TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory (#12655)
744649e53b is described below
commit 744649e53bd32b53eb53020a111479facff3b88a
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Sep 6 10:31:39 2022 -0700
[TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory (#12655)
* [TIR][StorageRewrite] Allow in-place buffer reuse of non-flat memory
Previously, buffer sharing was entirely disabled for non-flat
memory, since the existing checks for shared memory assume flat 1-d
spaces. This was enforced in `FindAlloc` and validated in
`PrepareNewAlloc`. However, the validation in `PrepareNewAlloc` could
still be triggered when buffer sharing resulted from an in-place
operation rather than from the `FindAlloc` function.
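In-place operations do not require N-d packing, nor do they introduce
ambiguity in how different code generators may interpret non-flat
physical indices. Therefore, this commit relaxes the validation in
`PrepareNewAlloc`, allowing buffer reuse of non-flat buffers for
in-place operations. Specifically, the original allocation is now
reused whenever every allocation merged into it has the same dtype
and extents, rather than only when exactly one allocation is present.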
* Update the new StorageRewrite tests to use the correct allocate/buffer_decl pattern
---
src/tir/transforms/storage_rewrite.cc | 20 +++-
.../unittest/test_tir_transform_storage_rewrite.py | 116 ++++++++++++++++++++-
2 files changed, 132 insertions(+), 4 deletions(-)
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 177017f9a2..67972ce672 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -655,7 +655,25 @@ class StoragePlanRewriter : public StmtExprMutator {
}
}
- if (e->allocs.size() == 1) {
+ bool all_allocs_identical = std::all_of(
+ e->allocs.begin() + 1, e->allocs.end(), [&](const AllocateNode* op) -> bool {
+ const AllocateNode* first = *e->allocs.begin();
+ if (op->dtype != first->dtype) {
+ return false;
+ }
+ if (op->extents.size() != first->extents.size()) {
+ return false;
+ }
+ ExprDeepEqual expr_equal;
+ for (size_t i = 0; i < op->extents.size(); i++) {
+ if (!expr_equal(op->extents[i], first->extents[i])) {
+ return false;
+ }
+ }
+ return true;
+ });
+
+ if (all_allocs_identical) {
// simply use the original allocation.
e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0));
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 581afef889..533a835e0f 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -673,7 +673,11 @@ def test_access_in_let_value():
tvm.ir.assert_structural_equal(mod["main"], func_rewritten)
-class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.tir.transform.StorageRewrite()
+
+
+class TestLetBufferRewrite(BaseCompare):
"""StorageRewrite replaces the bound var of backing allocations
If StorageRewrite replaces the backing variable of an array, such
@@ -684,8 +688,6 @@ class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
handled.
"""
- transform = tvm.tir.transform.StorageRewrite()
-
def before() -> None:
A_data: T.Ptr[T.int32] = T.call_extern("dummy_func", dtype="handle")
A = T.buffer_decl([8], "int32", data=A_data)
@@ -697,5 +699,113 @@ class TestLetBufferRewrite(tvm.testing.CompareBeforeAfter):
A[0] = T.broadcast(42, 8)
+class TestRewriteInPlaceUseOfNonFlatBuffer(BaseCompare):
+ """A non-flat buffer may be re-used for in-place operations"""
+
+ def before(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
+ B_data = T.allocate(
+ [16, 16],
+ dtype="float32",
+ scope="global",
+ )
+ B = T.buffer_decl(
+ [16, 16],
+ dtype="float32",
+ axis_separators=[1],
+ data=B_data,
+ )
+ C_data = T.allocate(
+ [16, 16],
+ dtype="float32",
+ scope="global",
+ )
+ C = T.buffer_decl(
+ [16, 16],
+ dtype="float32",
+ axis_separators=[1],
+ data=C_data,
+ )
+
+ for i, j in T.grid(16, 16):
+ B[i, j] = A[i, j]
+
+ for i, j in T.grid(16, 16):
+ C[i, j] = 2.0 * B[i, j]
+
+ for i, j in T.grid(16, 16):
+ D[i, j] = C[i, j]
+
+ def expected(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
+ B_data = T.allocate(
+ [16, 16],
+ dtype="float32",
+ scope="global",
+ )
+ B = T.buffer_decl([16, 16], dtype="float32", axis_separators=[1], data=B_data)
+ C = T.buffer_decl(
+ [16, 16],
+ dtype="float32",
+ axis_separators=[1],
+ data=B.data,
+ )
+
+ for i, j in T.grid(16, 16):
+ B[i, j] = A[i, j]
+
+ for i, j in T.grid(16, 16):
+ C[i, j] = 2.0 * B[i, j]
+
+ for i, j in T.grid(16, 16):
+ D[i, j] = C[i, j]
+
+
+class TestNoRewriteOfSharedNonFlatBuffer(BaseCompare):
+ """In general, sharing of non-flat buffer isn't supported
+
+ The current packing algorithms in StorageRewrite assume a flat
+ memory space, and do not support packing of N-d buffers. For
+ buffers with axis separators, normal buffer sharing should be
+ disabled.
+
+ Like TestRewriteInPlaceUseOfNonFlatBuffer, except that B and C do
+ not have matching shapes.
+ """
+
+ def before(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
+ B_data = T.allocate(
+ [16, 16],
+ dtype="float32",
+ scope="global",
+ )
+ B = T.buffer_decl(
+ [16, 16],
+ dtype="float32",
+ axis_separators=[1],
+ data=B_data,
+ )
+ C_data = T.allocate(
+ [20, 20],
+ dtype="float32",
+ scope="global",
+ )
+ C = T.buffer_decl(
+ [20, 20],
+ dtype="float32",
+ axis_separators=[1],
+ data=C_data,
+ )
+
+ for i, j in T.grid(16, 16):
+ B[i, j] = A[i, j]
+
+ for i, j in T.grid(16, 16):
+ C[i, j] = 2.0 * B[i, j]
+
+ for i, j in T.grid(16, 16):
+ D[i, j] = C[i, j]
+
+ expected = before
+
+
if __name__ == "__main__":
tvm.testing.main()