Posted to commits@tvm.apache.org by wr...@apache.org on 2022/09/08 15:03:11 UTC

[tvm] branch main updated: [TIR] Handle axis_separators during FlattenBuffer (#12652)

This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b2bd434ef9 [TIR] Handle axis_separators during FlattenBuffer (#12652)
b2bd434ef9 is described below

commit b2bd434ef944315a6f241803ac03c59c9aaa9847
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 8 08:02:42 2022 -0700

    [TIR] Handle axis_separators during FlattenBuffer (#12652)
    
    * [TIR] Moved tir.FlattenBuffer to occur before tir.LowerOpaqueBlock
    
    For buffers with more than one physical axis, the `axis_separators`
    are required in order to know which groups of logical axes to fuse
    into each physical axis.  The implementation in `tir.FlattenBuffer`
    assumed that all buffers were being flattened to a single physical
    axis.  Because `tir.LowerOpaqueBlock` replaces the
    `BlockNode::alloc_buffers` with `Allocate` nodes, `tir.FlattenBuffer`
    no longer has access to the axis separators and performs inconsistent
    flattening for `Allocate` as opposed to `BufferLoad`/`BufferStore`.
    This was introduced in https://github.com/apache/tvm/pull/12172, which
    decoupled the lowering/flattening steps.
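    
    As an illustration (a sketch, not code from this commit), a buffer
    with `axis_separators=[1]` is flattened to two physical axes, with
    each group of logical axes fused into one physical axis:
    
        # Hypothetical TVMScript sketch; buffer name and shape are
        # illustrative, not taken from the commit.
        A = T.alloc_buffer([2, 3, 4], "float32", axis_separators=[1])
        # Logical axis 0 becomes physical axis 0; logical axes 1 and 2
        # fuse into physical axis 1, so flattening yields shape [2, 12]
        # with axis_separators=[1].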
    
    The commit reorders `tir.FlattenBuffer` to occur before
    `tir.LowerOpaqueBlock`, to make use of the axis separators.  Any
    `Allocate` nodes that exist at that point (e.g. from hand-written
    schedules) are still flattened to 1-d physical buffers, but the
    `BlockNode::alloc_buffers` are flattened according to the axis
    separators.
    
    * Add unit test to validate non-flat memory after tvm.lower
    
    * Explicitly write T.reads for test on BufferRegion updates
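    
    For instance (an illustration, not part of the commit itself), a
    region read such as T.reads(A[i, j]) over a 16x16 buffer flattens
    to T.reads(A[i * 16 + j]), as exercised by TestFlattenInsideBlock
    in the updated tests.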
    
    * Update incorrect docstring for test
    
    * Use DeclBuffer information in FlattenBuffer
    
    The DeclBuffer node can be inserted during LowerOpaqueBlock,
    providing the missing Buffer information required to flatten the
    allocation.
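    
    A minimal sketch of the resulting structure, mirroring the updated
    unit tests (names are illustrative):
    
        A_data = T.allocate([2, 3], "float32", "global")
        A = T.buffer_decl([2, 3], "float32", data=A_data)
        # In TVMScript this pair can also be written as
        # A = T.decl_buffer([2, 3], "float32"), which lowers to the
        # Allocate + DeclBuffer pair that FlattenBuffer consumes.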
    
    * Use T.allocate in unit tests
    
    With the insertion of `DeclBuffer` nodes, `FlattenBuffer` no
    longer needs to run before `LowerOpaqueBlock`, and the passes have
    been moved back to their original order.  Reverting the tests to
    use `T.allocate` instead of `T.alloc_buffer` more closely
    represents the functions as they are being lowered.
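    
    A minimal sketch of the restored pass order (the driver code is
    illustrative; the pass names are real):
    
        mod = tvm.IRModule.from_expr(func)
        mod = tvm.tir.transform.LowerOpaqueBlock()(mod)  # inserts DeclBuffer nodes
        mod = tvm.tir.transform.FlattenBuffer()(mod)     # consumes, then strips them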
    
    * Fix usage of T.decl_buffer in updated tests
    
    * Update LowerOpaqueBlock to expect the DeclBuffer nodes
    
    * Strip DeclBuffer annotation in FlattenBuffer
    
    The DeclBuffer annotations aren't yet supported in all passes.  This
    restricts them to being introduced in LowerOpaqueBlock, then
    immediately removed in FlattenBuffer.
    
    * Strip out all DeclBuffer nodes in FlattenBuffer
    
    * Update unit tests to remove expectation of DeclBuffer nodes
---
 src/tir/transforms/flatten_buffer.cc               | 123 ++++-
 src/tir/transforms/lower_opaque_block.cc           |   1 +
 .../unittest/test_tir_transform_flatten_buffer.py  | 502 ++++++++++++---------
 .../test_tir_transform_lower_opaque_block.py       |  22 +-
 4 files changed, 417 insertions(+), 231 deletions(-)

diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index 22aef136bc..5441120491 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -21,6 +21,7 @@
  * \file flatten_buffer.cc
  */
 
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
@@ -53,6 +54,34 @@ class BufferFlattener : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const BlockNode* op) final {
+    ICHECK_EQ(op->match_buffers.size(), 0)
+        << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer.  "
+        << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer.";
+
+    Block block = GetRef<Block>(op);
+
+    Array<Buffer> alloc_buffers = op->alloc_buffers;
+    alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); });
+    if (!alloc_buffers.same_as(op->alloc_buffers)) {
+      block.CopyOnWrite()->alloc_buffers = alloc_buffers;
+    }
+
+    Array<BufferRegion> reads = op->reads;
+    reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
+    if (!reads.same_as(op->reads)) {
+      block.CopyOnWrite()->reads = reads;
+    }
+
+    Array<BufferRegion> writes = op->writes;
+    writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
+    if (!writes.same_as(op->writes)) {
+      block.CopyOnWrite()->writes = writes;
+    }
+
+    return StmtExprMutator::VisitStmt_(block.get());
+  }
+
   Stmt VisitStmt_(const AllocateNode* op) final {
     Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
     // TODO(Lunderberg): Move the handling of boolean into a
@@ -61,18 +90,70 @@ class BufferFlattener : public StmtExprMutator {
       auto writer = alloc.CopyOnWrite();
       writer->dtype = DataType::Int(8);
     }
-    // Handle multi-dimension allocations
+
     if (alloc->extents.size() == 1) {
-      return std::move(alloc);
-    } else {
-      Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
-      for (size_t i = 0; i < alloc->extents.size(); i++) {
-        flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
+      // No flattening required for buffers that are already flat
+
+      // TODO(rfc-70): Keep the DeclBuffer node as-is.  Stripping it
+      // out in the current implementation as not all lowering passes
+      // support DeclBuffer.
+      if (auto* decl_buffer = alloc->body.as<DeclBufferNode>()) {
+        alloc.CopyOnWrite()->body = std::move(decl_buffer->body);
       }
-      auto n = alloc.CopyOnWrite();
-      n->extents = flat_extent;
+
       return std::move(alloc);
     }
+
+    if (auto* decl_buffer = alloc->body.as<DeclBufferNode>();
+        decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) {
+      // N-d buffer, use the DeclBuffer inside to determine how it
+      // should be flattened.
+      auto& buffer = decl_buffer->buffer;
+      bool matching_buffer = [&]() {
+        if (alloc->dtype != buffer->dtype) {
+          return false;
+        }
+        if (alloc->extents.size() != buffer->shape.size()) {
+          return false;
+        }
+        ExprDeepEqual expr_equal;
+        for (size_t i = 0; i < alloc->extents.size(); i++) {
+          if (!expr_equal(alloc->extents[i], buffer->shape[i])) {
+            return false;
+          }
+        }
+        return true;
+      }();
+
+      if (matching_buffer) {
+        Buffer flattened = GetFlattenedBuffer(buffer);
+
+        auto n = alloc.CopyOnWrite();
+        // TODO(rfc-70): Update the DeclBuffer node instead of
+        // stripping it out.  Stripping it out in the current
+        // implementation as not all lowering passes support
+        // DeclBuffer.
+        //
+        // n->body = DeclBuffer(flattened, std::move(decl_buffer->body));
+        n->body = std::move(decl_buffer->body);
+        n->extents = flattened->shape;
+        return std::move(alloc);
+      } else {
+            ICHECK(decl_buffer->buffer->axis_separators.empty())
+                << "DeclBuffer node does not match the Allocate extents, yet requests "
+                   "non-flat physical memory, so it cannot be flattened to 1-d";
+      }
+    }
+
+    // Fallback, this is an allocation without a matching DeclBuffer
+    PrimExpr flat_extent = 1;
+    for (const auto& dim : alloc->extents) {
+      flat_extent *= dim;
+    }
+
+    auto n = alloc.CopyOnWrite();
+    n->extents = {flat_extent};
+    return std::move(alloc);
   }
 
   Buffer GetFlattenedBuffer(Buffer buf) {
@@ -141,6 +222,32 @@ class BufferFlattener : public StmtExprMutator {
     return node;
   }
 
+  BufferRegion MutateBufferRegion(BufferRegion region) {
+    Buffer orig_buf = region->buffer;
+    Buffer flattened_buf = GetFlattenedBuffer(orig_buf);
+    if (flattened_buf.same_as(orig_buf)) {
+      return region;
+    }
+
+    Array<PrimExpr> min_values;
+    Array<PrimExpr> max_values;
+    for (const auto& range : region->region) {
+      min_values.push_back(range->min);
+      max_values.push_back(range->min + range->extent - 1);
+    }
+
+    Array<PrimExpr> flattened_min = orig_buf->ElemOffset(min_values);
+    Array<PrimExpr> flattened_max = orig_buf->ElemOffset(max_values);
+
+    Array<Range> flattened_ranges;
+    ICHECK_EQ(flattened_min.size(), flattened_max.size());
+    for (size_t i = 0; i < flattened_min.size(); i++) {
+      flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1));
+    }
+
+    return BufferRegion(flattened_buf, flattened_ranges);
+  }
+
   /*! \brief Map of buffers being remapped. */
   std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
 
diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc
index a4655ebbae..ce74fdc4c1 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -57,6 +57,7 @@ class OpaqueBlockLower : public StmtExprMutator {
           new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
         }
       }
+      body = DeclBuffer(buffer, std::move(body));
       body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body));
     }
     // Step 4. Handle annotations, block annotations are not preserved by default.
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index 4cdf71889e..870208499e 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -20,223 +20,307 @@ from tvm import te
 from tvm.script import tir as T
 
 
-def _check(original, transformed):
-    func = original
-    mod = tvm.IRModule.from_expr(func)
-    mod = tvm.tir.transform.FlattenBuffer()(mod)
-    mod = tvm.tir.transform.Simplify()(mod)
-    tvm.ir.assert_structural_equal(mod["main"], transformed, True)
-
-
-@T.prim_func
-def elementwise_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i in T.serial(0, 16):
-        B_new_data = T.allocate([1, 16], "float32", "global")
-        B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data)
-        for j in T.serial(0, 16):
-            B_new[0, j] = A[i, j] + 1.0
-        for j in T.serial(0, 16):
-            C[i, j] = B_new[0, j] * 2.0
-
-
-@T.prim_func
-def flattened_elementwise_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, 256, "float32")
-    C = T.match_buffer(c, 256, "float32")
-    T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-    T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
-    for i in T.serial(0, 16):
-        B_new_data = T.allocate([16], "float32", "global")
-        B_new = T.buffer_decl(shape=[16], dtype="float32", data=B_new_data)
-        for j in T.serial(0, 16):
-            B_new[j] = A[((i * 16) + j)] + 1.0
-        for j in T.serial(0, 16):
-            C[((i * 16) + j)] = B_new[j] * 2.0
-
-
-@T.prim_func
-def gpu_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-
-    i0 = T.env_thread("blockIdx.x")
-    i1 = T.env_thread("threadIdx.x")
-    i2 = T.env_thread("vthread")
-
-    T.launch_thread(i0, 4)
-    T.launch_thread(i1, 2)
-    T.launch_thread(i2, 2)
-    B_data = T.allocate([1, 16], "float32", "local")
-    B = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_data, scope="local")
-    for j in range(0, 16):
-        B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
-    for j in range(0, 16):
-        C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
-
-
-@T.prim_func
-def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, 256, "float32")
-    C = T.match_buffer(c, 256, "float32")
-    T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-    T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
-
-    i0 = T.env_thread("blockIdx.x")
-    i1 = T.env_thread("threadIdx.x")
-    i2 = T.env_thread("vthread")
-
-    T.launch_thread(i0, 4)
-    T.launch_thread(i1, 2)
-    T.launch_thread(i2, 2)
-    B_data = T.allocate([16], "float32", "local")
-    B = T.buffer_decl(shape=[16], dtype="float32", data=B_data, scope="local")
-    for j in range(0, 16):
-        B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
-    for j in range(0, 16):
-        C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
-
-
-@T.prim_func
-def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
-    A = T.match_buffer(a, (n, m), "float32")
-    C = T.match_buffer(c, (n, m), "float32")
-
-    for i in range(0, n):
-        B_data = T.allocate([m], "float32", "global")
-        B = T.buffer_decl(shape=[m], dtype="float32", data=B_data)
-        for j in range(0, m):
-            B[j] = A[i, j] + 1.0
-        for j in range(0, m):
-            C[i, j] = B[j] * 2.0
-
-
-@T.prim_func
-def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
-    A = T.match_buffer(a, n * m, "float32")
-    C = T.match_buffer(c, n * m, "float32")
-    T.preflattened_buffer(A, (n, m), "float32", data=A.data)
-    T.preflattened_buffer(C, (n, m), "float32", data=C.data)
-
-    for i in range(0, n):
-        B_data = T.allocate([m], "float32", "global")
-        B = T.buffer_decl(shape=[m], dtype="float32", data=B_data)
-        for j in range(0, m):
-            B[j] = A[i * m + j] + 1.0
-        for j in range(0, m):
-            C[i * m + j] = B[j] * 2.0
-
-
-@T.prim_func
-def multi_alloc_func(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, (4, 32), "float32")
-    D = T.match_buffer(d, (4, 32), "float32")
-
-    for i, j in T.grid(4, 32):
-        B_data = T.allocate((4, 32), "float32", scope="global")
-        B = T.buffer_decl(shape=(4, 32), dtype="float32", data=B_data)
-        C_data = T.allocate((4, 32), "float32", scope="global")
-        C = T.buffer_decl(shape=(4, 32), dtype="float32", data=C_data)
-        B[i, j] = A[i, j] + 1.0
-        C[i, j] = A[i, j] + B[i, j]
-        D[i, j] = C[i, j] * 2.0
-
-
-@T.prim_func
-def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, 128, "float32")
-    D = T.match_buffer(d, 128, "float32")
-    T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
-    T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
-
-    for i, j in T.grid(4, 32):
-        B_data = T.allocate([128], "float32", "global")
-        B = T.buffer_decl(shape=[128], dtype="float32", data=B_data)
-        C_data = T.allocate([128], "float32", "global")
-        C = T.buffer_decl(shape=[128], dtype="float32", data=C_data)
-        B[i * 32 + j] = A[i * 32 + j] + 1.0
-        C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
-        D[i * 32 + j] = C[i * 32 + j] * 2.0
-
-
-@T.prim_func
-def strided_buffer_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in T.serial(4):
-        B_data = T.allocate([4, 17], "float32", "global")
-        B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data)
-        B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
-        for i1, j in T.grid(4, 16):
-            B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
-        for i1, j in T.grid(4, 16):
-            C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
-
-
-@T.prim_func
-def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (256,), "float32")
-    C = T.match_buffer(c, (256,), "float32")
-    T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
-    T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
-    for i0 in T.serial(0, 4):
-        B_new_data = T.allocate([68], "float32", "global")
-        B_new = T.buffer_decl(shape=[68], dtype="float32", data=B_new_data)
-        for i1 in T.serial(0, 4):
-            for j in T.serial(0, 16):
-                B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
-        for i1 in T.serial(0, 4):
-            for j in T.serial(0, 16):
-                C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
-
-
-@T.prim_func
-def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
-    for i0 in T.serial(10):
-        b[i0] = a[i0]
-
-
-@T.prim_func
-def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
-    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
-    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
-    # body
-    for i0 in T.serial(10):
-        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
-
-
-def test_elementwise():
-    _check(elementwise_func, flattened_elementwise_func)
-
-
-def test_gpu_workload():
-    _check(gpu_func, flattened_gpu_func)
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.transform.Sequential(
+        [
+            tvm.tir.transform.FlattenBuffer(),
+            tvm.tir.transform.Simplify(),
+        ]
+    )
 
 
-def test_symbolic_shape():
-    _check(symbolic_func, flattened_symbolic_func)
-
-
-def test_multi_alloc():
-    _check(multi_alloc_func, flattened_multi_alloc_func)
+class TestElementwise(BaseCompare):
+    """2-d buffers are flattened to 1-d"""
 
+    def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+        for i in T.serial(0, 16):
+            B_new = T.decl_buffer([1, 16], "float32")
+            for j in T.serial(0, 16):
+                B_new[0, j] = A[i, j] + 1.0
+            for j in T.serial(0, 16):
+                C[i, j] = B_new[0, j] * 2.0
+
+    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+        for i in T.serial(0, 16):
+            B_new_data = T.allocate([16], "float32", scope="global")
+            B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data)
+            for j in T.serial(0, 16):
+                B_new[j] = A[((i * 16) + j)] + 1.0
+            for j in T.serial(0, 16):
+                C[((i * 16) + j)] = B_new[j] * 2.0
 
-def test_strided_buffer():
-    _check(strided_buffer_func, flattened_strided_buffer_func)
 
+class TestElementwiseWithoutDeclBuffer(BaseCompare):
+    """2-d buffers are flattened to 1-d
 
-def test_lower_te():
-    x = te.placeholder((1,))
-    y = te.compute((1,), lambda i: x[i] + 2)
-    s = te.create_schedule(y.op)
-    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
-    mod = tvm.tir.transform.FlattenBuffer()(orig_mod)
-    tvm.ir.assert_structural_equal(mod, orig_mod)  # FlattenBuffer should do nothing on TE
+    Like TestElementwise, but the TIR doesn't have the DeclBuffer
+    node.  The T.buffer_decl declaration applies only while parsing
+    the TVMScript, and doesn't occur in the TIR itself.  In
+    this case, the allocation should be assumed to be targeting flat
+    memory, and should be flattened to a 1-d allocation.
+    """
 
+    def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+        for i in T.serial(0, 16):
+            B_new_data = T.allocate([1, 16], "float32", "global")
+            B_new = T.buffer_decl([1, 16], "float32", data=B_new_data)
+            for j in T.serial(0, 16):
+                B_new[0, j] = A[i, j] + 1.0
+            for j in T.serial(0, 16):
+                C[i, j] = B_new[0, j] * 2.0
+
+    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+        for i in T.serial(0, 16):
+            B_new_data = T.allocate([16], "float32", "global")
+            B_new = T.buffer_decl(16, "float32", data=B_new_data)
+            for j in T.serial(0, 16):
+                B_new[j] = A[((i * 16) + j)] + 1.0
+            for j in T.serial(0, 16):
+                C[((i * 16) + j)] = B_new[j] * 2.0
+
+
+class TestGPU(BaseCompare):
+    """Buffer flattening may have indices based on GPU thread vars"""
+
+    def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+        i0 = T.env_thread("blockIdx.x")
+        i1 = T.env_thread("threadIdx.x")
+        i2 = T.env_thread("vthread")
+
+        T.launch_thread(i0, 4)
+        T.launch_thread(i1, 2)
+        T.launch_thread(i2, 2)
+        B = T.decl_buffer([1, 16], "float32", scope="local")
+        for j in range(0, 16):
+            B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
+        for j in range(0, 16):
+            C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
+
+    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+
+        i0 = T.env_thread("blockIdx.x")
+        i1 = T.env_thread("threadIdx.x")
+        i2 = T.env_thread("vthread")
+
+        T.launch_thread(i0, 4)
+        T.launch_thread(i1, 2)
+        T.launch_thread(i2, 2)
+        B_data = T.allocate([16], "float32", scope="local")
+        B = T.buffer_decl([16], "float32", scope="local", data=B_data)
+        for j in range(0, 16):
+            B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
+        for j in range(0, 16):
+            C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
+
+
+class TestSymbolic(BaseCompare):
+    """Dynamically-sized arrrays are flattened"""
+
+    def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+        A = T.match_buffer(a, (n, m), "float32")
+        C = T.match_buffer(c, (n, m), "float32")
+
+        for i in range(0, n):
+            B = T.decl_buffer([m], "float32")
+            for j in range(0, m):
+                B[j] = A[i, j] + 1.0
+            for j in range(0, m):
+                C[i, j] = B[j] * 2.0
+
+    def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+        A = T.match_buffer(a, n * m, "float32")
+        C = T.match_buffer(c, n * m, "float32")
+        T.preflattened_buffer(A, (n, m), "float32", data=A.data)
+        T.preflattened_buffer(C, (n, m), "float32", data=C.data)
+
+        for i in range(0, n):
+            B_data = T.allocate([m], "float32", scope="global")
+            B = T.buffer_decl([m], "float32", scope="global", data=B_data)
+            for j in range(0, m):
+                B[j] = A[i * m + j] + 1.0
+            for j in range(0, m):
+                C[i * m + j] = B[j] * 2.0
+
+
+class TestMultiAlloc(BaseCompare):
+    """If multiple allocations occur, all are flattened."""
+
+    def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), "float32"]):
+        for i, j in T.grid(4, 32):
+            B = T.decl_buffer((4, 32), "float32", scope="global")
+            C = T.decl_buffer((4, 32), "float32", scope="global")
+            B[i, j] = A[i, j] + 1.0
+            C[i, j] = A[i, j] + B[i, j]
+            D[i, j] = C[i, j] * 2.0
+
+    def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]):
+        T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
+        T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
+
+        for i, j in T.grid(4, 32):
+            B_data = T.allocate([128], "float32", scope="global")
+            B = T.buffer_decl([128], "float32", scope="global", data=B_data)
+            C_data = T.allocate([128], "float32", scope="global")
+            C = T.buffer_decl([128], "float32", scope="global", data=C_data)
+            B[i * 32 + j] = A[i * 32 + j] + 1.0
+            C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
+            D[i * 32 + j] = C[i * 32 + j] * 2.0
+
+
+class TestStrided(BaseCompare):
+    """Indices for flattened buffers use the specified striding."""
+
+    def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+        for i0 in T.serial(4):
+            B = T.decl_buffer([4, 17], "float32")
+            B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
+            for i1, j in T.grid(4, 16):
+                B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
+            for i1, j in T.grid(4, 16):
+                C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
+
+    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+        T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
+        T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
+        for i0 in T.serial(0, 4):
+            B_new_data = T.allocate([68], "float32", scope="global")
+            B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data)
+            for i1 in T.serial(0, 4):
+                for j in T.serial(0, 16):
+                    B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
+            for i1 in T.serial(0, 4):
+                for j in T.serial(0, 16):
+                    C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
+
+
+class TestBoolean(BaseCompare):
+    """Boolean buffers should be replaced by a backing int8 array"""
+
+    def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None:
+        for i0 in T.serial(10):
+            B[i0] = A[i0]
+
+    def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None:
+        T.preflattened_buffer(A, [10], dtype="bool", data=A.data)
+        T.preflattened_buffer(B, [10], dtype="bool", data=B.data)
+        # body
+        for i0 in T.serial(10):
+            B[i0] = T.cast(T.cast(A[i0], "bool"), "int8")
+
+
+class TestLowerTE(BaseCompare):
+    """FlattenBuffer should do nothing on TE-based functions"""
+
+    def before(self):
+        x = te.placeholder((1,))
+        y = te.compute((1,), lambda i: x[i] + 2)
+        s = te.create_schedule(y.op)
+        mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+        return mod["main"]
+
+    expected = before
+
+
+class TestFlattenInsideBlock(BaseCompare):
+    """Flattening access inside a block flattens the accessed region."""
+
+    def before():
+        A = T.alloc_buffer([32, 32])
+        for i, j in T.grid(32, 32):
+            with T.block("block"):
+                T.reads(A[i, j])
+                T.evaluate(A[i, j])
+
+    def expected():
+        A = T.alloc_buffer([1024])
+        for i, j in T.grid(32, 32):
+            with T.block("block"):
+                T.reads(A[i * 32 + j])
+                T.evaluate(A[i * 32 + j])
+
+
+class TestNoChangeTo2DPhysicalBuffer(BaseCompare):
+    """Flattening preserves axis separators."""
+
+    def before():
+        A = T.alloc_buffer([32, 32], axis_separators=[1])
+        for i, j in T.grid(32, 32):
+            T.evaluate(A[i, j])
+
+    expected = before
+
+
+class TestFlattenAllocBufferWithAxisSeparators(BaseCompare):
+    """Flattening preserves axis separators"""
+
+    def before():
+        A = T.alloc_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3])
+        for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+            T.evaluate(A[i0, i1, i2, i3, i4, i5])
+
+    def expected():
+        A = T.alloc_buffer([30, 1001], axis_separators=[1])
+        for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+            T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])
+
+
+class TestFlattenDeclBufferWithAxisSeparators(BaseCompare):
+    """Flattening preserves axis separators
+
+    Like TestFlattenAllocBufferWithAxisSeparators, but the allocation
+    is done using Allocate/DeclBuffer, rather than through
+    BlockNode::alloc_buffers.
+    """
+
+    def before():
+        A = T.decl_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3])
+        for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+            T.evaluate(A[i0, i1, i2, i3, i4, i5])
+
+    def expected():
+        A_data = T.allocate([30, 1001], dtype="float32", scope="global")
+        A = T.buffer_decl(
+            [30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data
+        )
+        for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
+            T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])
+
+
+def test_lower_2d_physical_memory():
+    """Axis separators should preserve 2-d buffers through lowering.
 
-def test_boolean_handling():
-    _check(boolean_handling_before, boolean_handling_after)
+    A catch-all test to ensure that defining axis_separators is
+    sufficient to maintain non-flat buffer descriptions through all
+    lowering steps.
+    """
+
+    # This test doesn't use CompareBeforeAfter, because the after step
+    # is not currently expressible in TVMScript.  This test can be
+    # re-written after https://github.com/apache/tvm/pull/12412.
+
+    @T.prim_func
+    def func():
+        buf = T.alloc_buffer(
+            [1, 1],
+            dtype="int32",
+            scope="global",
+            axis_separators=[1],
+        )
+        buf[0, 0] = 0
+
+    lowered = tvm.lower(func)["main"]
+    assert isinstance(lowered.body, tvm.tir.Allocate)
+    assert list(lowered.body.extents) == [1, 1], (
+        "Non-flat buffer allocations, "
+        "marked by axis_separators, "
+        "flattened to flat memory allocation."
+    )
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
index f8f3e3a5ac..824cef1740 100644
--- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py
+++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -54,8 +54,7 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
     for i in T.serial(0, 16):
-        B_new_data = T.allocate([1, 16], "float32", "global")
-        B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data)
+        B_new = T.decl_buffer(shape=[1, 16], dtype="float32")
         for j in T.serial(0, 16):
             B_new[0, j] = A[i, j] + 1.0
         for j in T.serial(0, 16):
@@ -97,8 +96,7 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
     T.launch_thread(i0, 4)
     T.launch_thread(i1, 2)
     T.launch_thread(i2, 2)
-    B_data = T.allocate([1, 16], "float32", "local")
-    B = T.buffer_decl(shape=[1, 16], dtype="float32", scope="local", data=B_data)
+    B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local")
     for j in range(0, 16):
         B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
     for j in range(0, 16):
@@ -133,8 +131,7 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32)
     C = T.match_buffer(c, (n, m), "float32")
 
     for i in range(0, n):
-        B_data = T.allocate([m], "float32", "global")
-        B = T.buffer_decl(shape=[m], dtype="float32", data=B_data)
+        B = T.decl_buffer(shape=[m], dtype="float32")
         for j in range(0, m):
             B[j] = A[i, j] + 1.0
         for j in range(0, m):
@@ -207,10 +204,8 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None:
     D = T.match_buffer(d, (32), "float32")
 
     for i in range(0, 32):
-        B_data = T.allocate((32,), "float32", "global")
-        B = T.buffer_decl(shape=(32,), dtype="float32", data=B_data)
-        C_data = T.allocate((32,), "float32", "global")
-        C = T.buffer_decl(shape=(32,), dtype="float32", data=C_data)
+        B = T.decl_buffer(shape=(32,), dtype="float32")
+        C = T.decl_buffer(shape=(32,), dtype="float32")
         B[i] = A[i] + 1.0
         C[i] = A[i] + B[i]
         D[i] = C[i] * 2.0
@@ -246,12 +241,11 @@ def transformed_strided_buffer_func(
     # body
     for i0 in T.serial(4):
         B_data = T.allocate([4, 17], "float32", "global")
-        B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data)
-        B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
+        B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1], data=B_data)
         for i1, j in T.grid(4, 16):
-            B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
+            B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
         for i1, j in T.grid(4, 16):
-            C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2)
+            C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
 
 
 @T.prim_func