Posted to commits@tvm.apache.org by wr...@apache.org on 2022/09/08 15:03:11 UTC
[tvm] branch main updated: [TIR] Handle axis_separators during FlattenBuffer (#12652)
This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b2bd434ef9 [TIR] Handle axis_separators during FlattenBuffer (#12652)
b2bd434ef9 is described below
commit b2bd434ef944315a6f241803ac03c59c9aaa9847
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 8 08:02:42 2022 -0700
[TIR] Handle axis_separators during FlattenBuffer (#12652)
* [TIR] Moved tir.FlattenBuffer to occur before tir.LowerOpaqueBlock
For buffers with more than one physical axis, the `axis_separators`
are required in order to know which groups of logical axes to fuse
into each physical axis. The implementation in `tir.FlattenBuffer`
assumed that all buffers were being flattened to a single physical
axis. Because `tir.LowerOpaqueBlock` replaces the
`BlockNode::alloc_buffers` with `Allocate` nodes, `tir.FlattenBuffer`
no longer has access to the axis separators and performs inconsistent
flattening for `Allocate` as opposed to `BufferLoad`/`BufferStore`.
This was introduced in https://github.com/apache/tvm/pull/12172, which
decoupled the lowering/flattening steps.
The commit reorders `tir.FlattenBuffer` to occur before
`tir.LowerOpaqueBlock`, so that the axis separators are still
available when flattening. Any
`Allocate` nodes that exist at that point (e.g. from hand-written
schedules) are still flattened to 1-d physical buffers, but the
`BlockNode::alloc_buffers` are flattened according to the axis
separators.
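As an illustration of that grouping rule (not part of the commit, and the
helper name below is hypothetical), here is a plain-Python sketch using the
shapes from the new unit tests, where a [2, 3, 5, 7, 11, 13] buffer with
axis_separators=[3] keeps two physical axes:

    def flatten_extents(extents, axis_separators):
        # Multiply each group of logical extents between consecutive
        # axis separators to get one physical extent per group.
        groups = []
        start = 0
        for stop in list(axis_separators) + [len(extents)]:
            prod = 1
            for extent in extents[start:stop]:
                prod *= extent
            groups.append(prod)
            start = stop
        return groups

    assert flatten_extents([2, 3, 5, 7, 11, 13], [3]) == [30, 1001]
    assert flatten_extents([16, 16], []) == [256]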
* Add unit test to validate non-flat memory after tvm.lower
* Explicitly write T.reads for test on BufferRegion updates
* Update incorrect docstring for test
* Use DeclBuffer information in FlattenBuffer
The DeclBuffer node can be inserted during LowerOpaqueBlock, and then
provides the missing Buffer information required to flatten the
allocation.
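For reference, a minimal TVMScript sketch (modeled on the new unit tests,
not taken verbatim from the commit) of an allocation whose DeclBuffer
carries the axis separators that FlattenBuffer consumes:

    from tvm.script import tir as T

    @T.prim_func
    def with_decl_buffer():
        # T.decl_buffer lowers to an Allocate wrapping a DeclBuffer, so
        # the axis_separators are still visible to FlattenBuffer.
        A = T.decl_buffer([2, 3, 5, 7, 11, 13], "float32", axis_separators=[3])
        for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
            T.evaluate(A[i0, i1, i2, i3, i4, i5])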
* Use T.allocate in unit tests
With the insertion of `DeclBuffer` nodes, `FlattenBuffer` no longer
needs to run before `LowerOpaqueBlock`, and has been moved back to
its original position. Reverting the tests to use `T.allocate`
instead of `T.alloc_buffer` makes them more closely represent the
functions as they are being lowered.
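As a rough sketch of the difference (shapes borrowed from the existing
tests; both functions are illustrative only), the two spellings look like:

    from tvm.script import tir as T

    @T.prim_func
    def block_style():
        # Block-level allocation, as the scheduling primitives produce it.
        B = T.alloc_buffer([1, 16], "float32")
        for j in T.serial(16):
            T.evaluate(B[0, j])

    @T.prim_func
    def lowered_style():
        # Explicit allocation plus a buffer view over it, closer to the
        # form the function takes while being lowered.
        B_data = T.allocate([1, 16], "float32", "global")
        B = T.buffer_decl([1, 16], "float32", data=B_data)
        for j in T.serial(16):
            T.evaluate(B[0, j])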
* Fix usage of T.decl_buffer in updated tests
* Update LowerOpaqueBlock to expect the DeclBuffer nodes
* Strip DeclBuffer annotation in FlattenBuffer
The DeclBuffer annotations aren't yet supported in all passes. This
restricts them to being introduced in LowerOpaqueBlock, then
immediately removed in FlattenBuffer.
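A hedged usage sketch of the resulting ordering (the IRModule `mod` is
assumed to hold a scheduled PrimFunc; the passes named are the existing
tir.transform passes):

    import tvm

    def flatten(mod):
        # LowerOpaqueBlock introduces the DeclBuffer nodes that
        # FlattenBuffer consumes and, for now, strips back out.
        seq = tvm.transform.Sequential(
            [
                tvm.tir.transform.LowerOpaqueBlock(),
                tvm.tir.transform.FlattenBuffer(),
                tvm.tir.transform.Simplify(),
            ]
        )
        return seq(mod)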
* Strip out all DeclBuffer nodes in FlattenBuffer
* Update unit tests to remove expectation of DeclBuffer nodes
---
src/tir/transforms/flatten_buffer.cc | 123 ++++-
src/tir/transforms/lower_opaque_block.cc | 1 +
.../unittest/test_tir_transform_flatten_buffer.py | 502 ++++++++++++---------
.../test_tir_transform_lower_opaque_block.py | 22 +-
4 files changed, 417 insertions(+), 231 deletions(-)
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index 22aef136bc..5441120491 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -21,6 +21,7 @@
* \file flatten_buffer.cc
*/
+#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
@@ -53,6 +54,34 @@ class BufferFlattener : public StmtExprMutator {
}
}
+ Stmt VisitStmt_(const BlockNode* op) final {
+ ICHECK_EQ(op->match_buffers.size(), 0)
+ << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. "
+ << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer.";
+
+ Block block = GetRef<Block>(op);
+
+ Array<Buffer> alloc_buffers = op->alloc_buffers;
+ alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); });
+ if (!alloc_buffers.same_as(op->alloc_buffers)) {
+ block.CopyOnWrite()->alloc_buffers = alloc_buffers;
+ }
+
+ Array<BufferRegion> reads = op->reads;
+ reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
+ if (!reads.same_as(op->reads)) {
+ block.CopyOnWrite()->reads = reads;
+ }
+
+ Array<BufferRegion> writes = op->writes;
+ writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
+ if (!writes.same_as(op->writes)) {
+ block.CopyOnWrite()->writes = writes;
+ }
+
+ return StmtExprMutator::VisitStmt_(block.get());
+ }
+
Stmt VisitStmt_(const AllocateNode* op) final {
Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
// TODO(Lunderberg): Move the handling of boolean into a
@@ -61,18 +90,70 @@ class BufferFlattener : public StmtExprMutator {
auto writer = alloc.CopyOnWrite();
writer->dtype = DataType::Int(8);
}
- // Handle multi-dimension allocations
+
if (alloc->extents.size() == 1) {
- return std::move(alloc);
- } else {
- Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
- for (size_t i = 0; i < alloc->extents.size(); i++) {
- flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
+ // No flattening required for buffers that are already flat
+
+ // TODO(rfc-70): Keep the DeclBuffer node as-is. Stripping it
+ // out in the current implementation as not all lowering passes
+ // support DeclBuffer.
+ if (auto* decl_buffer = alloc->body.as<DeclBufferNode>()) {
+ alloc.CopyOnWrite()->body = std::move(decl_buffer->body);
}
- auto n = alloc.CopyOnWrite();
- n->extents = flat_extent;
+
return std::move(alloc);
}
+
+ if (auto* decl_buffer = alloc->body.as<DeclBufferNode>();
+ decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) {
+ // N-d buffer, use the DeclBuffer inside to determine how it
+ // should be flattened.
+ auto& buffer = decl_buffer->buffer;
+ bool matching_buffer = [&]() {
+ if (alloc->dtype != buffer->dtype) {
+ return false;
+ }
+ if (alloc->extents.size() != buffer->shape.size()) {
+ return false;
+ }
+ ExprDeepEqual expr_equal;
+ for (size_t i = 0; i < alloc->extents.size(); i++) {
+ if (!expr_equal(alloc->extents[i], buffer->shape[i])) {
+ return false;
+ }
+ }
+ return true;
+ }();
+
+ if (matching_buffer) {
+ Buffer flattened = GetFlattenedBuffer(buffer);
+
+ auto n = alloc.CopyOnWrite();
+ // TODO(rfc-70): Update the DeclBuffer node instead of
+ // stripping it out. Stripping it out in the current
+ // implementation as not all lowering passes support
+ // DeclBuffer.
+ //
+ // n->body = DeclBuffer(flattened, std::move(decl_buffer->body));
+ n->body = std::move(decl_buffer->body);
+ n->extents = flattened->shape;
+ return std::move(alloc);
+ } else {
+ ICHECK(decl_buffer->buffer->axis_separators.empty())
+ << "DeclBuffer node doesn't match Allocate extents, but also shouldn't be "
+ "flattened to 1-d physical memory";
+ }
+ }
+
+ // Fallback, this is an allocation without a matching DeclBuffer
+ PrimExpr flat_extent = 1;
+ for (const auto& dim : alloc->extents) {
+ flat_extent *= dim;
+ }
+
+ auto n = alloc.CopyOnWrite();
+ n->extents = {flat_extent};
+ return std::move(alloc);
}
Buffer GetFlattenedBuffer(Buffer buf) {
@@ -141,6 +222,32 @@ class BufferFlattener : public StmtExprMutator {
return node;
}
+ BufferRegion MutateBufferRegion(BufferRegion region) {
+ Buffer orig_buf = region->buffer;
+ Buffer flattened_buf = GetFlattenedBuffer(orig_buf);
+ if (flattened_buf.same_as(orig_buf)) {
+ return region;
+ }
+
+ Array<PrimExpr> min_values;
+ Array<PrimExpr> max_values;
+ for (const auto& range : region->region) {
+ min_values.push_back(range->min);
+ max_values.push_back(range->min + range->extent - 1);
+ }
+
+ Array<PrimExpr> flattened_min = orig_buf->ElemOffset(min_values);
+ Array<PrimExpr> flattened_max = orig_buf->ElemOffset(max_values);
+
+ Array<Range> flattened_ranges;
+ ICHECK_EQ(flattened_min.size(), flattened_max.size());
+ for (size_t i = 0; i < flattened_min.size(); i++) {
+ flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1));
+ }
+
+ return BufferRegion(flattened_buf, flattened_ranges);
+ }
+
/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc
index a4655ebbae..ce74fdc4c1 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -57,6 +57,7 @@ class OpaqueBlockLower : public StmtExprMutator {
new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
}
}
+ body = DeclBuffer(buffer, std::move(body));
body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body));
}
// Step 4. Handle annotations, block annotations are not preserved by default.
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index 4cdf71889e..870208499e 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -20,223 +20,307 @@ from tvm import te
from tvm.script import tir as T
-def _check(original, transformed):
- func = original
- mod = tvm.IRModule.from_expr(func)
- mod = tvm.tir.transform.FlattenBuffer()(mod)
- mod = tvm.tir.transform.Simplify()(mod)
- tvm.ir.assert_structural_equal(mod["main"], transformed, True)
-
-
-@T.prim_func
-def elementwise_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i in T.serial(0, 16):
- B_new_data = T.allocate([1, 16], "float32", "global")
- B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data)
- for j in T.serial(0, 16):
- B_new[0, j] = A[i, j] + 1.0
- for j in T.serial(0, 16):
- C[i, j] = B_new[0, j] * 2.0
-
-
-@T.prim_func
-def flattened_elementwise_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, 256, "float32")
- C = T.match_buffer(c, 256, "float32")
- T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
- T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
- for i in T.serial(0, 16):
- B_new_data = T.allocate([16], "float32", "global")
- B_new = T.buffer_decl(shape=[16], dtype="float32", data=B_new_data)
- for j in T.serial(0, 16):
- B_new[j] = A[((i * 16) + j)] + 1.0
- for j in T.serial(0, 16):
- C[((i * 16) + j)] = B_new[j] * 2.0
-
-
-@T.prim_func
-def gpu_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
-
- i0 = T.env_thread("blockIdx.x")
- i1 = T.env_thread("threadIdx.x")
- i2 = T.env_thread("vthread")
-
- T.launch_thread(i0, 4)
- T.launch_thread(i1, 2)
- T.launch_thread(i2, 2)
- B_data = T.allocate([1, 16], "float32", "local")
- B = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_data, scope="local")
- for j in range(0, 16):
- B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
- for j in range(0, 16):
- C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
-
-
-@T.prim_func
-def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, 256, "float32")
- C = T.match_buffer(c, 256, "float32")
- T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
- T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
-
- i0 = T.env_thread("blockIdx.x")
- i1 = T.env_thread("threadIdx.x")
- i2 = T.env_thread("vthread")
-
- T.launch_thread(i0, 4)
- T.launch_thread(i1, 2)
- T.launch_thread(i2, 2)
- B_data = T.allocate([16], "float32", "local")
- B = T.buffer_decl(shape=[16], dtype="float32", data=B_data, scope="local")
- for j in range(0, 16):
- B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
- for j in range(0, 16):
- C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
-
-
-@T.prim_func
-def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
- A = T.match_buffer(a, (n, m), "float32")
- C = T.match_buffer(c, (n, m), "float32")
-
- for i in range(0, n):
- B_data = T.allocate([m], "float32", "global")
- B = T.buffer_decl(shape=[m], dtype="float32", data=B_data)
- for j in range(0, m):
- B[j] = A[i, j] + 1.0
- for j in range(0, m):
- C[i, j] = B[j] * 2.0
-
-
-@T.prim_func
-def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
- A = T.match_buffer(a, n * m, "float32")
- C = T.match_buffer(c, n * m, "float32")
- T.preflattened_buffer(A, (n, m), "float32", data=A.data)
- T.preflattened_buffer(C, (n, m), "float32", data=C.data)
-
- for i in range(0, n):
- B_data = T.allocate([m], "float32", "global")
- B = T.buffer_decl(shape=[m], dtype="float32", data=B_data)
- for j in range(0, m):
- B[j] = A[i * m + j] + 1.0
- for j in range(0, m):
- C[i * m + j] = B[j] * 2.0
-
-
-@T.prim_func
-def multi_alloc_func(a: T.handle, d: T.handle) -> None:
- A = T.match_buffer(a, (4, 32), "float32")
- D = T.match_buffer(d, (4, 32), "float32")
-
- for i, j in T.grid(4, 32):
- B_data = T.allocate((4, 32), "float32", scope="global")
- B = T.buffer_decl(shape=(4, 32), dtype="float32", data=B_data)
- C_data = T.allocate((4, 32), "float32", scope="global")
- C = T.buffer_decl(shape=(4, 32), dtype="float32", data=C_data)
- B[i, j] = A[i, j] + 1.0
- C[i, j] = A[i, j] + B[i, j]
- D[i, j] = C[i, j] * 2.0
-
-
-@T.prim_func
-def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
- A = T.match_buffer(a, 128, "float32")
- D = T.match_buffer(d, 128, "float32")
- T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
- T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
-
- for i, j in T.grid(4, 32):
- B_data = T.allocate([128], "float32", "global")
- B = T.buffer_decl(shape=[128], dtype="float32", data=B_data)
- C_data = T.allocate([128], "float32", "global")
- C = T.buffer_decl(shape=[128], dtype="float32", data=C_data)
- B[i * 32 + j] = A[i * 32 + j] + 1.0
- C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
- D[i * 32 + j] = C[i * 32 + j] * 2.0
-
-
-@T.prim_func
-def strided_buffer_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- C = T.match_buffer(c, (16, 16), "float32")
- for i0 in T.serial(4):
- B_data = T.allocate([4, 17], "float32", "global")
- B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data)
- B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
- for i1, j in T.grid(4, 16):
- B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
- for i1, j in T.grid(4, 16):
- C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
-
-
-@T.prim_func
-def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (256,), "float32")
- C = T.match_buffer(c, (256,), "float32")
- T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
- T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
- for i0 in T.serial(0, 4):
- B_new_data = T.allocate([68], "float32", "global")
- B_new = T.buffer_decl(shape=[68], dtype="float32", data=B_new_data)
- for i1 in T.serial(0, 4):
- for j in T.serial(0, 16):
- B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
- for i1 in T.serial(0, 4):
- for j in T.serial(0, 16):
- C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
-
-
-@T.prim_func
-def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
- for i0 in T.serial(10):
- b[i0] = a[i0]
-
-
-@T.prim_func
-def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
- T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
- T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
- # body
- for i0 in T.serial(10):
- b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
-
-
-def test_elementwise():
- _check(elementwise_func, flattened_elementwise_func)
-
-
-def test_gpu_workload():
- _check(gpu_func, flattened_gpu_func)
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.transform.Sequential(
+ [
+ tvm.tir.transform.FlattenBuffer(),
+ tvm.tir.transform.Simplify(),
+ ]
+ )
-def test_symbolic_shape():
- _check(symbolic_func, flattened_symbolic_func)
-
-
-def test_multi_alloc():
- _check(multi_alloc_func, flattened_multi_alloc_func)
+class TestElementwise(BaseCompare):
+ """2-d buffers are flattened to 1-d"""
+ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+ for i in T.serial(0, 16):
+ B_new = T.decl_buffer([1, 16], "float32")
+ for j in T.serial(0, 16):
+ B_new[0, j] = A[i, j] + 1.0
+ for j in T.serial(0, 16):
+ C[i, j] = B_new[0, j] * 2.0
+
+ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+ T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+ T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+ for i in T.serial(0, 16):
+ B_new_data = T.allocate([16], "float32", scope="global")
+ B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data)
+ for j in T.serial(0, 16):
+ B_new[j] = A[((i * 16) + j)] + 1.0
+ for j in T.serial(0, 16):
+ C[((i * 16) + j)] = B_new[j] * 2.0
-def test_strided_buffer():
- _check(strided_buffer_func, flattened_strided_buffer_func)
+class TestElementwiseWithoutDeclBuffer(BaseCompare):
+ """2-d buffers are flattened to 1-d
-def test_lower_te():
- x = te.placeholder((1,))
- y = te.compute((1,), lambda i: x[i] + 2)
- s = te.create_schedule(y.op)
- orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
- mod = tvm.tir.transform.FlattenBuffer()(orig_mod)
- tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE
+ Like TestElementwise, but the TIR doesn't have the DeclBuffer
+ node. The T.buffer_decl declaration applies only while parsing
+ the TVMScript, and doesn't occur in the TIR itself. In
+ this case, the allocation should be assumed to be targeting flat
+ memory, and should be flattened to a 1-d allocation.
+ """
+ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+ for i in T.serial(0, 16):
+ B_new_data = T.allocate([1, 16], "float32", "global")
+ B_new = T.buffer_decl([1, 16], "float32", data=B_new_data)
+ for j in T.serial(0, 16):
+ B_new[0, j] = A[i, j] + 1.0
+ for j in T.serial(0, 16):
+ C[i, j] = B_new[0, j] * 2.0
+
+ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+ T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+ T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+ for i in T.serial(0, 16):
+ B_new_data = T.allocate([16], "float32", "global")
+ B_new = T.buffer_decl(16, "float32", data=B_new_data)
+ for j in T.serial(0, 16):
+ B_new[j] = A[((i * 16) + j)] + 1.0
+ for j in T.serial(0, 16):
+ C[((i * 16) + j)] = B_new[j] * 2.0
+
+
+class TestGPU(BaseCompare):
+ """Buffer flattening may have indices based on GPU thread vars"""
+
+ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+ i0 = T.env_thread("blockIdx.x")
+ i1 = T.env_thread("threadIdx.x")
+ i2 = T.env_thread("vthread")
+
+ T.launch_thread(i0, 4)
+ T.launch_thread(i1, 2)
+ T.launch_thread(i2, 2)
+ B = T.decl_buffer([1, 16], "float32", scope="local")
+ for j in range(0, 16):
+ B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
+ for j in range(0, 16):
+ C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
+
+ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+ T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+ T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+
+ i0 = T.env_thread("blockIdx.x")
+ i1 = T.env_thread("threadIdx.x")
+ i2 = T.env_thread("vthread")
+
+ T.launch_thread(i0, 4)
+ T.launch_thread(i1, 2)
+ T.launch_thread(i2, 2)
+ B_data = T.allocate([16], "float32", scope="local")
+ B = T.buffer_decl([16], "float32", scope="local", data=B_data)
+ for j in range(0, 16):
+ B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
+ for j in range(0, 16):
+ C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
+
+
+class TestSymbolic(BaseCompare):
+ """Dynamically-sized arrrays are flattened"""
+
+ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+ A = T.match_buffer(a, (n, m), "float32")
+ C = T.match_buffer(c, (n, m), "float32")
+
+ for i in range(0, n):
+ B = T.decl_buffer([m], "float32")
+ for j in range(0, m):
+ B[j] = A[i, j] + 1.0
+ for j in range(0, m):
+ C[i, j] = B[j] * 2.0
+
+ def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+ A = T.match_buffer(a, n * m, "float32")
+ C = T.match_buffer(c, n * m, "float32")
+ T.preflattened_buffer(A, (n, m), "float32", data=A.data)
+ T.preflattened_buffer(C, (n, m), "float32", data=C.data)
+
+ for i in range(0, n):
+ B_data = T.allocate([m], "float32", scope="global")
+ B = T.buffer_decl([m], "float32", scope="global", data=B_data)
+ for j in range(0, m):
+ B[j] = A[i * m + j] + 1.0
+ for j in range(0, m):
+ C[i * m + j] = B[j] * 2.0
+
+
+class TestMultiAlloc(BaseCompare):
+ """If multiple allocations occur, all are flattened."""
+
+ def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), "float32"]):
+ for i, j in T.grid(4, 32):
+ B = T.decl_buffer((4, 32), "float32", scope="global")
+ C = T.decl_buffer((4, 32), "float32", scope="global")
+ B[i, j] = A[i, j] + 1.0
+ C[i, j] = A[i, j] + B[i, j]
+ D[i, j] = C[i, j] * 2.0
+
+ def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]):
+ T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
+ T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
+
+ for i, j in T.grid(4, 32):
+ B_data = T.allocate([128], "float32", scope="global")
+ B = T.buffer_decl([128], "float32", scope="global", data=B_data)
+ C_data = T.allocate([128], "float32", scope="global")
+ C = T.buffer_decl([128], "float32", scope="global", data=C_data)
+ B[i * 32 + j] = A[i * 32 + j] + 1.0
+ C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
+ D[i * 32 + j] = C[i * 32 + j] * 2.0
+
+
+class TestStrided(BaseCompare):
+ """Indices for flattened buffers use the specified striding."""
+
+ def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+ for i0 in T.serial(4):
+ B = T.decl_buffer([4, 17], "float32")
+ B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
+ for i1, j in T.grid(4, 16):
+ B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
+ for i1, j in T.grid(4, 16):
+ C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
+
+ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+ T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
+ T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
+ for i0 in T.serial(0, 4):
+ B_new_data = T.allocate([68], "float32", scope="global")
+ B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data)
+ for i1 in T.serial(0, 4):
+ for j in T.serial(0, 16):
+ B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
+ for i1 in T.serial(0, 4):
+ for j in T.serial(0, 16):
+ C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
+
+
+class TestBoolean(BaseCompare):
+ """Boolean buffers should be replaced by a backing int8 array"""
+
+ def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None:
+ for i0 in T.serial(10):
+ B[i0] = A[i0]
+
+ def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None:
+ T.preflattened_buffer(A, [10], dtype="bool", data=A.data)
+ T.preflattened_buffer(B, [10], dtype="bool", data=B.data)
+ # body
+ for i0 in T.serial(10):
+ B[i0] = T.cast(T.cast(A[i0], "bool"), "int8")
+
+
+class TestLowerTE(BaseCompare):
+ """FlattenBuffer should do nothing on TE-based functions"""
+
+ def before(self):
+ x = te.placeholder((1,))
+ y = te.compute((1,), lambda i: x[i] + 2)
+ s = te.create_schedule(y.op)
+ mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+ return mod["main"]
+
+ expected = before
+
+
+class TestFlattenInsideBlock(BaseCompare):
+ """Flattening access inside a block flattens the accessed region."""
+
+ def before():
+ A = T.alloc_buffer([32, 32])
+ for i, j in T.grid(32, 32):
+ with T.block("block"):
+ T.reads(A[i, j])
+ T.evaluate(A[i, j])
+
+ def expected():
+ A = T.alloc_buffer([1024])
+ for i, j in T.grid(32, 32):
+ with T.block("block"):
+ T.reads(A[i * 32 + j])
+ T.evaluate(A[i * 32 + j])
+
+
+class TestNoChangeTo2DPhysicalBuffer(BaseCompare):
+ """Flattening preserves axis separators."""
+
+ def before():
+ A = T.alloc_buffer([32, 32], axis_separators=[1])
+ for i, j in T.grid(32, 32):
+ T.evaluate(A[i, j])
+
+ expected = before
+
+
+class TestFlattenAllocBufferWithAxisSeparators(BaseCompare):
+ """Flattening preserves axis separators"""
+
+ def before():
+ A = T.alloc_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3])
+ for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+ T.evaluate(A[i0, i1, i2, i3, i4, i5])
+
+ def expected():
+ A = T.alloc_buffer([30, 1001], axis_separators=[1])
+ for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+ T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])
+
+
+class TestFlattenDeclBufferWithAxisSeparators(BaseCompare):
+ """Flattening preserves axis separators
+
+ Like TestFlattenAllocBufferWithAxisSeparators, but the allocation
+ is done using Allocate/DeclBuffer, rather than through
+ BlockNode::alloc_buffers.
+ """
+
+ def before():
+ A = T.decl_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3])
+ for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+ T.evaluate(A[i0, i1, i2, i3, i4, i5])
+
+ def expected():
+ A_data = T.allocate([30, 1001], dtype="float32", scope="global")
+ A = T.buffer_decl(
+ [30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data
+ )
+ for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+ T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])
+
+
+def test_lower_2d_physical_memory():
+ """Axis separators should preserve 2-d buffers through lowering.
-def test_boolean_handling():
- _check(boolean_handling_before, boolean_handling_after)
+ A catch-all test to ensure that defining axis_separators is
+ sufficient to maintain non-flat buffer descriptions through all
+ lowering steps.
+ """
+
+ # This test doesn't use CompareBeforeAfter, because the after step
+ # is not currently expressible in TVMScript. This test can be
+ # re-written after https://github.com/apache/tvm/pull/12412.
+
+ @T.prim_func
+ def func():
+ buf = T.alloc_buffer(
+ [1, 1],
+ dtype="int32",
+ scope="global",
+ axis_separators=[1],
+ )
+ buf[0, 0] = 0
+
+ lowered = tvm.lower(func)["main"]
+ assert isinstance(lowered.body, tvm.tir.Allocate)
+ assert list(lowered.body.extents) == [1, 1], (
+ "Non-flat buffer allocations, "
+ "marked by axis_separators, "
+ "flattened to flat memory allocation."
+ )
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
index f8f3e3a5ac..824cef1740 100644
--- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py
+++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -54,8 +54,7 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
C = T.match_buffer(c, (16, 16), "float32")
for i in T.serial(0, 16):
- B_new_data = T.allocate([1, 16], "float32", "global")
- B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data)
+ B_new = T.decl_buffer(shape=[1, 16], dtype="float32")
for j in T.serial(0, 16):
B_new[0, j] = A[i, j] + 1.0
for j in T.serial(0, 16):
@@ -97,8 +96,7 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
T.launch_thread(i0, 4)
T.launch_thread(i1, 2)
T.launch_thread(i2, 2)
- B_data = T.allocate([1, 16], "float32", "local")
- B = T.buffer_decl(shape=[1, 16], dtype="float32", scope="local", data=B_data)
+ B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local")
for j in range(0, 16):
B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
for j in range(0, 16):
@@ -133,8 +131,7 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32)
C = T.match_buffer(c, (n, m), "float32")
for i in range(0, n):
- B_data = T.allocate([m], "float32", "global")
- B = T.buffer_decl(shape=[m], dtype="float32", data=B_data)
+ B = T.decl_buffer(shape=[m], dtype="float32")
for j in range(0, m):
B[j] = A[i, j] + 1.0
for j in range(0, m):
@@ -207,10 +204,8 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None:
D = T.match_buffer(d, (32), "float32")
for i in range(0, 32):
- B_data = T.allocate((32,), "float32", "global")
- B = T.buffer_decl(shape=(32,), dtype="float32", data=B_data)
- C_data = T.allocate((32,), "float32", "global")
- C = T.buffer_decl(shape=(32,), dtype="float32", data=C_data)
+ B = T.decl_buffer(shape=(32,), dtype="float32")
+ C = T.decl_buffer(shape=(32,), dtype="float32")
B[i] = A[i] + 1.0
C[i] = A[i] + B[i]
D[i] = C[i] * 2.0
@@ -246,12 +241,11 @@ def transformed_strided_buffer_func(
# body
for i0 in T.serial(4):
B_data = T.allocate([4, 17], "float32", "global")
- B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data)
- B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
+ B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1], data=B_data)
for i1, j in T.grid(4, 16):
- B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
+ B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
for i1, j in T.grid(4, 16):
- C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2)
+ C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
@T.prim_func