Posted to commits@tvm.apache.org by lu...@apache.org on 2022/11/16 14:19:41 UTC

[tvm] branch main updated: [TIR] Remove PrimFuncNode::preflattened_buffer_map (#10940)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 78b53221f8 [TIR] Remove PrimFuncNode::preflattened_buffer_map (#10940)
78b53221f8 is described below

commit 78b53221f8dd8c1d2bbeff9d34803db33ca254dd
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Nov 16 08:19:34 2022 -0600

    [TIR] Remove PrimFuncNode::preflattened_buffer_map (#10940)
    
    `PrimFuncNode::preflattened_buffer_map` was introduced in
    https://github.com/apache/tvm/pull/9727 to maintain a record of the
    pre-flattened buffer shapes until they can be used in
    `MakePackedAPI`.  This commit instead keeps the pre-flattened shapes
    in `PrimFuncNode::buffer_map`, while the body of the function uses a
    flattened alias of each buffer, as described in
    [RFC#70](https://github.com/apache/tvm-rfcs/pull/70).
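    
    As a quick illustration (a minimal sketch, not part of this patch:
    the buffer name, shape, and placeholder body below are made-up
    assumptions), the Python `tvm.tir.PrimFunc` constructor now takes
    only `buffer_map`, with no separate pre-flattened map:
    
        # Hypothetical usage of the updated constructor; only the removal
        # of the preflattened_buffer_map argument is taken from this patch.
        from tvm import tir
    
        A = tir.decl_buffer((32, 32), "float32", name="A")   # assumed shape/name
        handle = tir.Var("A_handle", "handle")                # assumed parameter
        body = tir.Evaluate(tir.IntImm("int32", 0))           # placeholder body
    
        # Before: PrimFunc(params, body, ret_type, buffer_map,
        #                  preflattened_buffer_map, attrs, span)
        # After:  PrimFunc(params, body, ret_type, buffer_map, attrs, span)
        func = tir.PrimFunc([handle], body, buffer_map={handle: A})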
---
 include/tvm/script/ir_builder/tir/frame.h          |   3 -
 include/tvm/script/ir_builder/tir/ir.h             |  20 --
 include/tvm/tir/function.h                         |  43 ++---
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py |  77 +++++---
 python/tvm/script/ir_builder/tir/ir.py             |  69 -------
 python/tvm/script/parser_v1/context_maintainer.py  |   3 -
 python/tvm/script/parser_v1/parser.py              |   1 -
 python/tvm/script/parser_v1/tir/__init__.pyi       |  12 --
 python/tvm/script/parser_v1/tir/special_stmt.py    |  73 --------
 python/tvm/tir/function.py                         |   7 -
 src/printer/tir_text_printer.cc                    |  10 -
 src/printer/tvmscript_printer.cc                   |  20 --
 src/relay/backend/aot/aot_lower_main.cc            |   2 +-
 src/relay/backend/aot_executor_codegen.cc          |   2 +-
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  |   2 +-
 .../contrib/example_target_hooks/relay_to_tir.cc   |   2 +-
 src/script/ir_builder/tir/frame.cc                 |   1 -
 src/script/ir_builder/tir/ir.cc                    |  22 ---
 src/tir/analysis/device_constraint_utils.cc        |  22 +--
 src/tir/contrib/ethosu/passes.cc                   |   6 +-
 src/tir/ir/function.cc                             |  10 +-
 src/tir/transforms/bf16_legalize.cc                |  29 ---
 src/tir/transforms/flatten_buffer.cc               |  16 +-
 src/tir/transforms/legalize_packed_calls.cc        |   4 +-
 src/tir/transforms/make_packed_api.cc              |   4 +-
 .../plan_update_buffer_allocation_location.cc      |   6 +-
 src/tir/transforms/storage_flatten.cc              | 201 ++++++++++++---------
 src/tir/usmp/transform/assign_pool_info.cc         |   4 +-
 .../convert_pool_allocations_to_offsets.cc         |  10 +-
 src/tir/usmp/transform/create_io_allocates.cc      |   5 +-
 .../contrib/test_ethosu/test_encode_constants.py   |  33 +++-
 .../contrib/test_ethosu/test_hoist_allocates.py    |  31 ++--
 .../contrib/test_ethosu/test_merge_constants.py    |  44 +++--
 .../test_ethosu/test_remove_concatenates.py        |   7 +-
 .../contrib/test_ethosu/test_replace_conv2d.py     |  48 +++--
 .../contrib/test_ethosu/test_replace_copy.py       |   8 +-
 tests/python/contrib/test_ethosu/test_scheduler.py |   4 +-
 .../test_hexagon/test_2d_physical_buffers.py       |   2 +-
 .../unittest/test_aot_legalize_packed_call.py      |  26 +--
 tests/python/unittest/test_arith_domain_touched.py |  24 ++-
 .../python/unittest/test_auto_scheduler_feature.py |  16 +-
 tests/python/unittest/test_lower_build.py          |  36 ++--
 .../unittest/test_tir_transform_flatten_buffer.py  |  44 ++---
 .../unittest/test_tir_transform_loop_partition.py  |  73 ++++----
 ...test_tir_transform_renormalize_split_pattern.py |  42 ++---
 .../unittest/test_tir_transform_thread_sync.py     |   4 +-
 ...ransform_convert_pool_allocations_to_offsets.py |  72 --------
 .../python/unittest/test_tvmscript_error_report.py |  20 --
 .../unittest/test_tvmscript_ir_builder_tir.py      |   5 -
 .../python/unittest/test_tvmscript_syntax_sugar.py |  17 --
 50 files changed, 462 insertions(+), 780 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index b95d575360..ee80322362 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -75,8 +75,6 @@ class PrimFuncFrameNode : public TIRFrameNode {
   Optional<Type> ret_type;
   /*! \brief Maps some parameters to specific Buffer data structures. */
   Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
-  /*! \brief The buffer map prior to flattening. */
-  Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
   /*! \brief Additional attributes storing the meta-data */
   Optional<Map<String, ObjectRef>> attrs;
   /*! \brief The variable map bound to thread env. */
@@ -90,7 +88,6 @@ class PrimFuncFrameNode : public TIRFrameNode {
     v->Visit("args", &args);
     v->Visit("ret_type", &ret_type);
     v->Visit("buffer_map", &buffer_map);
-    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
     v->Visit("attrs", &attrs);
     v->Visit("env_threads", &env_threads);
     v->Visit("root_alloc_buffers", &root_alloc_buffers);
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index d9e1a1b490..5cba879205 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -114,26 +114,6 @@ Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = Data
                    int align = -1, int offset_factor = 0, String buffer_type = "default",
                    Array<IntImm> axis_separators = {});
 
-/*!
- * \brief The pre-flattened buffer statement.
- * \param postflattened_buffer The original buffer to be flattened.
- * \param shape The type of the buffer prior to flattening.
- * \param dtype The data type in the content of the buffer.
- * \param data The pointer to the head of the data.
- * \param strides The strides of each dimension.
- * \param elem_offset The offset in terms of number of dtype elements (including lanes).
- * \param storage_scope The optional storage scope of buffer data pointer.
- * \param align The alignment requirement of data pointer in bytes.
- * \param offset_factor The factor of elem_offset field.
- * \param buffer_type The buffer type.
- * \param axis_separators The separators between input axes when generating flattened output axes.
- */
-void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
-                        DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
-                        Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
-                        String storage_scope = "global", int align = -1, int offset_factor = 0,
-                        String buffer_type = "default", Array<IntImm> axis_separators = {});
-
 /*!
  * \brief The block declaration statement.
  * \param name The name of the block.
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index d793d84fc6..cf92f97360 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -88,33 +88,22 @@ class PrimFuncNode : public BaseFuncNode {
    *  While we could have express parameter unpacking and constraint using
    *  normal statements, making buffer_map as first class citizen of PrimFunc
    *  will make program analysis much easier.
-   */
-  Map<tir::Var, Buffer> buffer_map;
-
-  /*! \brief The buffer map prior to flattening.
-   *
-   * This contains the buffers as they exists prior to flattening, and
-   * is used for validating an input tensor passed into the packed
-   * API.  Any buffer that is present in `buffer_map` but not present
-   * in `preflattened_buffer_map` is assumed to be the same before
-   * and after flattening (e.g. a 1-d tensor that is backed by 1-d
-   * flat memory).
    *
-   * TODO(Lunderberg): Remove preflattened_buffer_map, and instead
-   * declare each flattened buffer as aliasing the original tensor
-   * shape.  This should include improving the StmtExprMutator to
-   * provide easier interactions with Buffer objects, so that the
-   * bookkeeping of relationships between buffers doesn't need to be
-   * repeated across several transforms.
+   *  Prior to buffer flattening, which is performed either in
+   *  StorageFlatten for TE-based schedules or in FlattenBuffer for
+   *  TIR-based schedules, these buffer objects are used directly in
+   *  the body of the function.  After buffer flattening, these buffer
+   *  objects remain unflattened for use in argument validation, but
+   *  all usage in the body of the function is done through a
+   *  flattened alias of the buffer.
    */
-  Map<tir::Var, Buffer> preflattened_buffer_map;
+  Map<tir::Var, Buffer> buffer_map;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("params", &params);
     v->Visit("body", &body);
     v->Visit("ret_type", &ret_type);
     v->Visit("buffer_map", &buffer_map);
-    v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
     v->Visit("attrs", &attrs);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -123,7 +112,6 @@ class PrimFuncNode : public BaseFuncNode {
   bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
     // visit params and buffer_map first as they contains defs.
     return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
-           equal(preflattened_buffer_map, other->preflattened_buffer_map) &&
            equal(ret_type, other->ret_type) && equal(body, other->body) &&
            equal(attrs, other->attrs);
   }
@@ -131,7 +119,6 @@ class PrimFuncNode : public BaseFuncNode {
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce.DefHash(params);
     hash_reduce(buffer_map);
-    hash_reduce(preflattened_buffer_map);
     hash_reduce(ret_type);
     hash_reduce(body);
     hash_reduce(attrs);
@@ -169,21 +156,13 @@ class PrimFunc : public BaseFunc {
    * PrimFunc.  (e.g. a buffer of shape ``[1024]`` originally
    * generated as a tensor of shape ``[32, 32]``)
    *
-   * \param preflattened_buffer_map The buffer map for
-   * parameter buffer unpacking.  This contains buffer
-   * objects as they are expected to be passed in by the
-   * callee.  (e.g. a buffer of shape ``[32, 32]`` originally
-   * generated as a tensor of shape ``[32, 32]``)
-   *
    * \param attrs Additional function attributes.
    *
    * \param span The location of this object in the source code.
    */
-  TVM_DLL PrimFunc(
-      Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
-      Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
-      Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = Optional<Map<tir::Var, Buffer>>(),
-      DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+  TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
+                   Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
+                   DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index cc94c6e816..e15d126dd9 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -299,7 +299,6 @@ def DivideConstants(const_dict):
             new_body,
             f.ret_type,
             new_buffer_map,
-            f.preflattened_buffer_map,
             f.attrs,
             f.span,
         )
@@ -327,7 +326,7 @@ def EncodeConstants(const_dict):
     """
     new_const_dict = {}
 
-    def collect_encoding_definitions(stmt, old_buffer_to_const):
+    def collect_encoding_definitions(stmt, old_buffer_var_to_const):
         # Map from copy destination to copy source.
         copy_map = {}
         # List of buffer copies that occurred
@@ -376,7 +375,7 @@ def EncodeConstants(const_dict):
         def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func):
             """Encode the weights or align the bias either for one or two cores,
             depending on the variant."""
-            constant = old_buffer_to_const[buffer1]
+            constant = old_buffer_var_to_const[buffer1.data]
 
             # If we have just one core, encode the whole constant
             if buffer2 is None:
@@ -471,7 +470,12 @@ def EncodeConstants(const_dict):
         }
 
     def transform_stmt(
-        stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx
+        stmt,
+        buf_remap,
+        var_remap,
+        pointer_to_buffer,
+        new_buffer_var_to_const,
+        new_buffer_to_split_idx,
     ):
         def _visit_rewrite(stmt):
             if isinstance(stmt, tvm.tir.Call):
@@ -485,7 +489,7 @@ def EncodeConstants(const_dict):
                     # encoded buffer, the current should be a length.
                     if (
                         isinstance(prev_arg, tvm.tir.BufferLoad)
-                        and prev_arg.buffer in new_buffer_to_const
+                        and prev_arg.buffer.data in new_buffer_var_to_const
                     ):
                         buffer_size = np.prod(list(prev_arg.buffer.shape))
                         arg = buffer_size
@@ -554,28 +558,56 @@ def EncodeConstants(const_dict):
             ["tir.Call", "tir.Allocate", "tir.BufferLoad", "tir.AttrStmt"],
         )
 
+    def _collect_parameter_buffer_aliases(prim_func):
+        buffer_vars = {}
+        for param in prim_func.params:
+            if param in prim_func.buffer_map:
+                buf = prim_func.buffer_map[param]
+                buffer_vars[buf.data] = {buf}
+
+        def visit(node):
+            if isinstance(node, (tvm.tir.BufferStore, tvm.tir.BufferLoad, tvm.tir.DeclBuffer)):
+                buf = node.buffer
+                if buf.data in buffer_vars:
+                    buffer_vars[buf.data].add(buf)
+
+        tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit)
+        return buffer_vars
+
     def _ftransform(f, mod, ctx):
+        param_buffer_var_usage = _collect_parameter_buffer_aliases(f)
+
         # Step 0: Unpack the constant dictionary in terms of the
         # functions buffers.
-        old_buffer_to_const = {}
+        old_buffer_var_to_const = {}
         for i, param in enumerate(f.params):
             if i in const_dict:
-                old_buffer_to_const[f.buffer_map[param]] = const_dict[i]
+                old_buffer_var_to_const[f.buffer_map[param].data] = const_dict[i]
 
         # Step 1: Collect information on the buffers that will be
         # replaced by encodings.
-        buffer_information = collect_encoding_definitions(f.body, old_buffer_to_const)
+        buffer_information = collect_encoding_definitions(f.body, old_buffer_var_to_const)
 
         # Step 2: Generate variable/buffer remaps, based on the
         # collected information.
         buf_remap = {}
-        new_buffer_to_const = {}
+        new_buffer_var_to_const = {}
         new_buffer_to_split_idx = {}
 
+        def define_remap(old_buf, new_buf):
+            try:
+                old_buffers = param_buffer_var_usage[old_buf.data]
+            except KeyError:
+                old_buffers = [old_buf]
+
+            for old_buffer in old_buffers:
+                buf_remap[old_buffer] = new_buf
+
         # Any encoded buffers must be replaced
         for info in buffer_information["constant_buffer_replacements"]:
-            buf_remap[info["old_buffer"]] = info["new_buffer"]
-            new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"]
+            define_remap(info["old_buffer"], info["new_buffer"])
+
+            new_buffer_var_to_const[info["new_buffer"].data] = info["encoded_constants"]
 
             if info["split_idx"]:
                 new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"]
@@ -596,9 +628,11 @@ def EncodeConstants(const_dict):
                     name=copy_dest.name,
                     scope=copy_dest.scope(),
                 )
-                buf_remap[copy_dest] = new_dest
-                if copy_source in new_buffer_to_const:
-                    new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source]
+                define_remap(copy_dest, new_dest)
+                if copy_source.data in new_buffer_var_to_const:
+                    new_buffer_var_to_const[new_dest.data] = new_buffer_var_to_const[
+                        copy_source.data
+                    ]
 
                 if copy_source in new_buffer_to_split_idx:
                     new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source]
@@ -615,7 +649,7 @@ def EncodeConstants(const_dict):
             buf_remap,
             var_remap,
             pointer_to_buffer,
-            new_buffer_to_const,
+            new_buffer_var_to_const,
             new_buffer_to_split_idx,
         )
 
@@ -626,10 +660,10 @@ def EncodeConstants(const_dict):
             if buffer in buf_remap:
                 buffer = buf_remap[buffer]
 
-            if buffer in new_buffer_to_const:
-                new_const_dict[i] = new_buffer_to_const[buffer].flatten()
-            elif buffer in old_buffer_to_const:
-                new_const_dict[i] = old_buffer_to_const[buffer].flatten()
+            if buffer.data in new_buffer_var_to_const:
+                new_const_dict[i] = new_buffer_var_to_const[buffer.data].flatten()
+            elif buffer.data in old_buffer_var_to_const:
+                new_const_dict[i] = old_buffer_var_to_const[buffer.data].flatten()
 
             new_buffer_map[param] = buffer
 
@@ -638,7 +672,6 @@ def EncodeConstants(const_dict):
             new_body,
             f.ret_type,
             new_buffer_map,
-            f.preflattened_buffer_map,
             f.attrs,
             f.span,
         )
@@ -873,7 +906,6 @@ def CreatePrimFuncWithoutConstants(const_dict):
     def _ftransform(f, mod, ctx):
         new_params = list()
         new_buffer_map = dict()
-        new_preflattened_buffer_map = dict()
         for param_idx in const_dict.keys():
             # We are using buffer_var to key the constants as
             # PrimFunc params of constants will be removed.
@@ -882,14 +914,11 @@ def CreatePrimFuncWithoutConstants(const_dict):
             if i not in const_dict.keys():
                 new_params.append(param)
                 new_buffer_map[param] = f.buffer_map[param]
-                if param in f.preflattened_buffer_map:
-                    new_preflattened_buffer_map[param] = f.preflattened_buffer_map[param]
         return tvm.tir.PrimFunc(
             new_params,
             f.body,
             f.ret_type,
             new_buffer_map,
-            new_preflattened_buffer_map,
             f.attrs,
             f.span,
         )
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 0678925e2f..842e21378f 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -314,74 +314,6 @@ def match_buffer(
     )
 
 
-def preflattened_buffer(
-    postflattened: Buffer,
-    shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: List[PrimExpr] = None,
-    elem_offset: PrimExpr = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: List[int] = None,
-) -> None:
-    """The pre-flattened buffer statement.
-
-    Parameters
-    ----------
-    postflattened : Buffer
-        The original buffer to be flattened.
-
-    shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral]
-        The type of the buffer prior to flattening.
-
-    dtype : str
-        The data type in the content of the buffer.
-
-    data : Var
-        The pointer to the head of the data.
-
-    strides : List[PrimExpr]
-        The strides of each dimension.
-
-    elem_offset : PrimExpr
-        The offset in terms of number of dtype elements (including lanes).
-
-    scope : str
-        The optional storage scope of buffer data pointer.
-
-    align : int
-        The alignment requirement of data pointer in bytes.
-
-    offset_factor : int
-        The factor of elem_offset field.
-
-    buffer_type : str
-        The buffer type.
-
-    axis_separators : List[int]
-        The separators between input axes when generating flattened output axes.
-    """
-    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
-    if strides is None:
-        strides = []
-    _ffi_api.PreflattenedBuffer(  # type: ignore[attr-defined] # pylint: disable=no-member
-        postflattened,
-        shape,
-        dtype,
-        data,
-        strides,
-        elem_offset,
-        scope,
-        align,
-        offset_factor,
-        buffer_type,
-        axis_separators,
-    )
-
-
 def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
     """The block declaration statement.
 
@@ -1697,7 +1629,6 @@ __all__ += [
     "func_attr",
     "func_ret",
     "match_buffer",
-    "preflattened_buffer",
     "block",
     "init",
     "where",
diff --git a/python/tvm/script/parser_v1/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py
index f7f16855c7..b84b7d3980 100644
--- a/python/tvm/script/parser_v1/context_maintainer.py
+++ b/python/tvm/script/parser_v1/context_maintainer.py
@@ -129,8 +129,6 @@ class ContextMaintainer:
     """List[Var]: The function parameters"""
     func_buffer_map: Mapping[Var, Buffer] = {}
     """Mapping[Var, Buffer]: The function buffer map"""
-    func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
-    """Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
     func_dict_attr: Mapping[str, Object] = {}
     """Mapping[str, Object]: The function attrs"""
     func_var_env_dict: Mapping[Var, str] = {}
@@ -160,7 +158,6 @@ class ContextMaintainer:
         # function context
         self.func_params = []
         self.func_buffer_map = {}
-        self.func_preflattened_buffer_map = {}
         self.func_dict_attr = {}
         self.func_var_env_dict = {}
         # parser and analyzer
diff --git a/python/tvm/script/parser_v1/parser.py b/python/tvm/script/parser_v1/parser.py
index c34aae2345..ce8c1fe161 100644
--- a/python/tvm/script/parser_v1/parser.py
+++ b/python/tvm/script/parser_v1/parser.py
@@ -501,7 +501,6 @@ class TVMScriptParser(Transformer):
             body,
             ret_type,
             buffer_map=self.context.func_buffer_map,
-            preflattened_buffer_map=self.context.func_preflattened_buffer_map,
             attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
             span=tvm_span_from_synr(node.span),
         )
diff --git a/python/tvm/script/parser_v1/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi
index a64eed055a..beefaf4c75 100644
--- a/python/tvm/script/parser_v1/tir/__init__.pyi
+++ b/python/tvm/script/parser_v1/tir/__init__.pyi
@@ -117,18 +117,6 @@ def store(
 ) -> None: ...
 def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...
 def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ...
-def preflattened_buffer(
-    buf: Buffer,
-    shape: Sequence[PrimExpr],
-    dtype: str = "float32",
-    data: Optional[Ptr] = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-) -> Buffer: ...
 
 """
 Intrinsics - tvm builtin
diff --git a/python/tvm/script/parser_v1/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
index 7cbf474410..f558eb6b7f 100644
--- a/python/tvm/script/parser_v1/tir/special_stmt.py
+++ b/python/tvm/script/parser_v1/tir/special_stmt.py
@@ -904,79 +904,6 @@ class FuncAttr(SpecialStmt):
         super().__init__(func_attr, def_symbol=False)
 
 
-@register
-class PreflattenedBufferMap(SpecialStmt):
-    """Special Stmt for declaring the PrimFunc::preflattened_buffer_map
-
-    Example
-    -------
-    .. code-block:: python
-         A0 = T.match_buffer(A, (48,), dtype="float32")
-         T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
-    """
-
-    def __init__(self):
-        def preflattened_buffer(
-            postflattened,
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            span=None,
-        ):
-
-            param = None
-            for key, value in self.context.func_buffer_map.items():
-                if value.same_as(postflattened):
-                    param = key
-                    break
-
-            assert (
-                param is not None
-            ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
-
-            if data is None:
-                data = self.context.func_buffer_map[param].data
-
-            buffer_name: str = f"{postflattened.name}_preflatten"
-            if align != -1:
-                if isinstance(align, IntImm):
-                    align = align.value
-                else:
-                    assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"
-
-            if offset_factor != 0:
-                if isinstance(offset_factor, IntImm):
-                    offset_factor = offset_factor.value
-                else:
-                    assert isinstance(
-                        offset_factor, int
-                    ), f"offset_factor: want int or IntImm, got {offset_factor!r}"
-
-            preflattened = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                span=span,
-            )
-
-            self.context.func_preflattened_buffer_map[param] = preflattened
-
-        super().__init__(preflattened_buffer, def_symbol=False)
-
-
 @register
 class TargetAttrValue(SpecialStmt):
     """Special Stmt for target attr value.
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 4628ae3626..c5cc922a3e 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -49,9 +49,6 @@ class PrimFunc(BaseFunc):
     buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer]
         The buffer binding map.
 
-    preflattened_buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]]
-        The buffer binding map, prior to any flattening.
-
     attrs: Optional[tvm.Attrs]
         Attributes of the function, can be None
 
@@ -65,14 +62,12 @@ class PrimFunc(BaseFunc):
         body,
         ret_type=None,
         buffer_map=None,
-        preflattened_buffer_map=None,
         attrs=None,
         span=None,
     ):
 
         param_list = []
         buffer_map = {} if buffer_map is None else buffer_map
-        preflattened_buffer_map = {} if preflattened_buffer_map is None else preflattened_buffer_map
         for x in params:
             x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
             if isinstance(x, Buffer):
@@ -90,7 +85,6 @@ class PrimFunc(BaseFunc):
             body,
             ret_type,
             buffer_map,
-            preflattened_buffer_map,
             attrs,
             span,
         )  # type: ignore
@@ -116,7 +110,6 @@ class PrimFunc(BaseFunc):
             new_body,
             self.ret_type,
             self.buffer_map,
-            self.preflattened_buffer_map,
             self.attrs,
             span,
         )
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index e50559ac10..fc3f49d76f 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -152,16 +152,6 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
         2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
   }
 
-  if (op->preflattened_buffer_map.size() != 0) {
-    // print preflattened_buffer_map
-    std::vector<Doc> preflattened_buffer_map_doc;
-    for (auto& v : op->preflattened_buffer_map) {
-      preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second));
-    }
-    doc << Doc::Indent(2, Doc::NewLine()
-                              << "preflattened_buffer_map = {"
-                              << PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}");
-  }
   doc << PrintBody(op->body);
   return doc;
 }
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index f1d68ee438..0dc6240bc6 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1672,26 +1672,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
     body << Print((*it).first) << ", " << memo_buf_decl_[buf];
     body << ")" << Doc::NewLine();
   }
-  // print preflattened buffer map
-  for (const auto& param : op->params) {
-    auto pf_buf_it = op->preflattened_buffer_map.find(param);
-    if (pf_buf_it != op->preflattened_buffer_map.end()) {
-      const Buffer& preflattened = (*pf_buf_it).second;
-
-      auto buf_it = op->buffer_map.find(param);
-      ICHECK(buf_it != op->buffer_map.end()) << "Found pre-flattened buffer " << preflattened->name
-                                             << " with no corresponding post-flatten buffer.";
-      const Buffer& postflattened = (*buf_it).second;
-
-      // Call Print() without assigning in order to fill memo_buf_decl_.
-      Print(preflattened);
-      buf_not_in_headers_.insert(preflattened.get());
-      ICHECK(memo_buf_decl_.count(preflattened));
-
-      body << tir_prefix_ << ".preflattened_buffer(" << Print(postflattened) << ", "
-           << memo_buf_decl_.at(preflattened) << ")" << Doc::NewLine();
-    }
-  }
   // print body
   body << "# body" << Doc::NewLine();
 
diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc
index 82393c535c..2a4dfb84dd 100644
--- a/src/relay/backend/aot/aot_lower_main.cc
+++ b/src/relay/backend/aot/aot_lower_main.cc
@@ -504,7 +504,7 @@ class AOTMainLowerer : public MixedModeVisitor {
     tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});
 
     // Make the PrimFunc
-    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {},
+    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_,
                          DictAttrs(dict_attrs));
   }
 
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 786b3f81a5..3c0ab7c16f 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -803,7 +803,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});
 
     // Make the PrimFunc
-    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {},
+    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_,
                          DictAttrs(dict_attrs));
   }
 
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index da51e6b762..1ea020e884 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -108,7 +108,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
     }
 
     tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map,
-                                   Map<tir::Var, tir::Buffer>(), DictAttrs(dict_attrs));
+                                   DictAttrs(dict_attrs));
     ir_module_->Add(global_var, replacement_func);
   }
 
diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
index eb6cf1cce4..ad2b06695c 100644
--- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
+++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
@@ -152,7 +152,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
       };
 
       tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
-                                                     buffer_map, {}, DictAttrs(dict_attrs));
+                                                     buffer_map, DictAttrs(dict_attrs));
 
       // Switch to TIRToRuntime hook for testing
       Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index f48ee52506..1e63201a40 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -34,7 +34,6 @@ void PrimFuncFrameNode::ExitWithScope() {
       /*body=*/AsStmt(stmts),
       /*ret_type=*/ret_type.value_or(TupleType::Empty()),
       /*buffer_map=*/buffer_map,
-      /*preflattened_buffer_map=*/preflattened_buffer_map,
       /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue<DictAttrs>());
   func = tvm::tir::ScriptComplete(func, root_alloc_buffers);
   IRBuilder builder = IRBuilder::Current();
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 78107136d4..822e8e4683 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -58,7 +58,6 @@ PrimFuncFrame PrimFunc() {
   n->args.clear();
   n->ret_type = NullOpt;
   n->buffer_map.clear();
-  n->preflattened_buffer_map.clear();
   n->attrs = NullOpt;
   n->env_threads.clear();
   n->root_alloc_buffers.clear();
@@ -137,26 +136,6 @@ Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optio
   return buffer;
 }
 
-void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape, DataType dtype,
-                        Optional<Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
-                        String storage_scope, int align, int offset_factor, String buffer_type_str,
-                        Array<IntImm> axis_separators) {
-  PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer");
-  for (auto const& p : frame->buffer_map) {
-    if (p.second.same_as(postflattened_buffer)) {
-      String buffer_name(postflattened_buffer->name + "_preflatten");
-      Buffer buffer =
-          BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset,
-                     storage_scope, align, offset_factor, buffer_type_str, axis_separators);
-      details::Namer::Name(buffer, buffer_name);
-      frame->preflattened_buffer_map.Set(p.first, buffer);
-      return;
-    }
-  }
-  LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
-             << " does not exist.";
-}
-
 BlockFrame Block(String name, bool no_realize) {
   ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
   n->name = name;
@@ -595,7 +574,6 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init);
diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc
index 32b59ce54b..d0933e0691 100644
--- a/src/tir/analysis/device_constraint_utils.cc
+++ b/src/tir/analysis/device_constraint_utils.cc
@@ -210,8 +210,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
 
     // Start with a copy of the current prim_func buffer map.
     Map<Var, Buffer> new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end());
-    Map<Var, Buffer> new_preflattened_buffer_map(prim_func->preflattened_buffer_map.begin(),
-                                                 prim_func->preflattened_buffer_map.end());
     bool any_change = false;
 
     // For each constrained parameter...
@@ -225,23 +223,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
         any_change = true;
       }
       new_buffer_map.Set(param, new_buffer);
-
-      // Rewrite the pre-flattened buffers to account for constraint.
-      // This only has an impact if the IRModule being analyzed has
-      // already been run through the StorageFlatten or FlattenBuffer
-      // passes.
-      if (auto opt = prim_func->preflattened_buffer_map.Get(param)) {
-        Buffer pf_buffer = opt.value();
-        if (pf_buffer.same_as(buffer)) {
-          new_preflattened_buffer_map.Set(param, new_buffer);
-        } else {
-          const Buffer new_buffer = RewriteBuffer(pf_buffer, virtual_device);
-          if (!new_buffer.same_as(pf_buffer)) {
-            any_change = true;
-          }
-          new_preflattened_buffer_map.Set(param, new_buffer);
-        }
-      }
     }
     // Make sure we have accounted for all prim_func parameters.
     CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
@@ -259,8 +240,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
 
     if (any_change) {
       return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type,
-                      std::move(new_buffer_map), std::move(new_preflattened_buffer_map),
-                      prim_func->attrs, prim_func->span);
+                      std::move(new_buffer_map), prim_func->attrs, prim_func->span);
     } else {
       return prim_func;
     }
diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc
index d51ffbf833..369c4adc85 100644
--- a/src/tir/contrib/ethosu/passes.cc
+++ b/src/tir/contrib/ethosu/passes.cc
@@ -152,9 +152,8 @@ class HoistAllocatesMutator : public StmtExprMutator {
                    current_alloc->span);
     }
 
-    PrimFunc new_main_func =
-        PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map,
-                 main_func->preflattened_buffer_map, main_func->attrs);
+    PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type,
+                                      main_func->buffer_map, main_func->attrs);
     return new_main_func;
   }
 
@@ -523,7 +522,6 @@ class MergeConstantsMutator : public StmtExprMutator {
     prim_func_node->body = std::move(new_body);
     prim_func_node->buffer_map = std::move(new_buffer_map);
     prim_func_node->params = std::move(new_params);
-    prim_func_node->preflattened_buffer_map = {};
     PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
 
     // Add the new const dict as an attribute
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index c609ad158e..d4802e2876 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -29,9 +29,7 @@ namespace tvm {
 namespace tir {
 // Get the function type of a PrimFunc
 PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
-                   Map<tir::Var, Buffer> buffer_map,
-                   Optional<Map<tir::Var, Buffer>> preflattened_buffer_map, DictAttrs attrs,
-                   Span span) {
+                   Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
   // Assume void-return type for now
   // TODO(tvm-team) consider type deduction from body.
   if (!ret_type.defined()) {
@@ -42,7 +40,6 @@ PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
   n->body = std::move(body);
   n->ret_type = std::move(ret_type);
   n->buffer_map = std::move(buffer_map);
-  n->preflattened_buffer_map = preflattened_buffer_map.value_or(Map<tir::Var, Buffer>());
   n->attrs = std::move(attrs);
   n->checked_type_ = n->func_type_annotation();
   n->span = std::move(span);
@@ -129,9 +126,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_GLOBAL("tir.PrimFunc")
     .set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type,
-                       Map<tir::Var, Buffer> buffer_map,
-                       Map<tir::Var, Buffer> preflattened_buffer_map, DictAttrs attrs, Span span) {
-      return PrimFunc(params, body, ret_type, buffer_map, preflattened_buffer_map, attrs, span);
+                       Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
+      return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
     });
 
 TVM_REGISTER_GLOBAL("tir.TensorIntrin")
diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc
index 5dc08f31c2..040c48c796 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/bf16_legalize.cc
@@ -308,37 +308,8 @@ class BF16LowerRewriter : public StmtExprMutator {
       }
     }
 
-    // Most passes do not change the preflattened buffer map, nor
-    // should they change it.  This is an exception, because the Var
-    // associated with the `BufferNode::data` in
-    // `PrimFunc::buffer_map` may be replaced, and the corresponding
-    // Var in the `PrimFunc::preflattened_buffer_map` must also be
-    // replaced.
-    Map<Var, Buffer> new_preflattened_buffer_map;
-    for (auto& itr : op->preflattened_buffer_map) {
-      auto param_var = itr.first;
-      auto oldbuf = itr.second;
-      if (oldbuf->dtype.is_bfloat16()) {
-        auto it = new_buffer_map.find(param_var);
-        ICHECK(it != new_buffer_map.end())
-            << "PrimFunc parameter " << param_var->name_hint
-            << " is associated with the pre-flattened buffer " << oldbuf->name
-            << ", but isn't associated with any post-flatten buffer.";
-        const Buffer& flatbuf = (*it).second;
-        DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
-        auto newbuf = Buffer(flatbuf->data, dtype, oldbuf->shape, oldbuf->strides,
-                             oldbuf->elem_offset, oldbuf->name, oldbuf->data_alignment,
-                             oldbuf->offset_factor, oldbuf->buffer_type);
-        buffer_remap_[oldbuf] = newbuf;
-        new_preflattened_buffer_map.Set(param_var, newbuf);
-      } else {
-        new_preflattened_buffer_map.Set(param_var, oldbuf);
-      }
-    }
-
     if (buffer_remap_.size() != 0) {
       op->buffer_map = new_buffer_map;
-      op->preflattened_buffer_map = new_preflattened_buffer_map;
     }
   }
 
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index 5441120491..d51a44887f 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -37,22 +37,18 @@ namespace tir {
 class BufferFlattener : public StmtExprMutator {
  public:
   static PrimFunc Flatten(PrimFunc func) {
-    Map<Var, Buffer> preflattened_buffer_map =
-        Merge(func->buffer_map, func->preflattened_buffer_map);
-    auto pass = BufferFlattener(func->buffer_map);
+    auto pass = BufferFlattener();
     auto writer = func.CopyOnWrite();
     writer->body = pass.VisitStmt(func->body);
-    writer->preflattened_buffer_map = preflattened_buffer_map;
-    writer->buffer_map = pass.updated_extern_buffer_map_;
+    // The buffers in func->buffer_map are deliberately left
+    // unflattened, as they are used for validation of user-provided
+    // arguments.  The flattened buffers used in the updated
+    // function body alias the argument buffers.
     return func;
   }
 
  private:
-  explicit BufferFlattener(const Map<Var, Buffer>& extern_buffer_map) {
-    for (const auto& kv : extern_buffer_map) {
-      updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second));
-    }
-  }
+  BufferFlattener() {}
 
   Stmt VisitStmt_(const BlockNode* op) final {
     ICHECK_EQ(op->match_buffers.size(), 0)
diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc
index 344e6c7ae3..fed76876f6 100644
--- a/src/tir/transforms/legalize_packed_calls.cc
+++ b/src/tir/transforms/legalize_packed_calls.cc
@@ -74,9 +74,9 @@ class PackedCallLegalizer : public StmtExprMutator {
             tvm::runtime::Map<tvm::tir::Var, tvm::tir::Buffer>::iterator param_buf_it;
             if (prim_func != nullptr) {
               auto param_var = prim_func->params[i - 1];
-              param_buf_it = prim_func->preflattened_buffer_map.find(param_var);
+              param_buf_it = prim_func->buffer_map.find(param_var);
             }
-            if (prim_func != nullptr && param_buf_it != prim_func->preflattened_buffer_map.end()) {
+            if (prim_func != nullptr && param_buf_it != prim_func->buffer_map.end()) {
               Buffer param = (*param_buf_it).second;
               PrimExpr shape = tvm::tir::Call(
                   DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape);
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index 5b9bac03ab..c1611a23a0 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -209,9 +209,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
       continue;
     }
 
-    if (func_ptr->preflattened_buffer_map.count(param)) {
-      buffer_def.emplace_back(v_arg, func_ptr->preflattened_buffer_map[param]);
-    } else if (func_ptr->buffer_map.count(param)) {
+    if (func_ptr->buffer_map.count(param)) {
       buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
     } else {
       var_def.emplace_back(v_arg, param);
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index db59824bf1..90150ebd3c 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -52,21 +52,21 @@ class BufferAllocationLocator : public StmtExprMutator {
  public:
   explicit BufferAllocationLocator(const PrimFunc& func) {
     Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
+    std::unordered_set<const VarNode*> arg_buffer_vars;
     CollectUnmanagedAllocations collector;
     collector(func->body);
     unmanaged_allocations_ = collector.unmanaged_allocations;
 
-    std::unordered_set<const BufferNode*> arg_buffers;
     for (const auto& kv : func->buffer_map) {
       const Buffer& buffer = kv.second;
-      arg_buffers.emplace(buffer.get());
+      arg_buffer_vars.emplace(buffer->data.get());
       buffer_data_to_buffer_.Set(buffer->data, buffer);
     }
     // create buffers to be allocated at each stmts
     for (const auto& kv : buffer_lca) {
       const Buffer& buffer = kv.first;
       const StmtNode* stmt = kv.second.get();
-      if (arg_buffers.count(buffer.get())) {
+      if (arg_buffer_vars.count(buffer->data.get())) {
         continue;
       }
       if (!unmanaged_allocations_.count(buffer->data.get())) {
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index ab1b062ad6..eb0409e555 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -402,6 +402,7 @@ class BufferStrideLegalize : public StmtExprMutator {
 
       auto fptr = func.CopyOnWrite();
       fptr->body = pass(std::move(fptr->body));
+      fptr->buffer_map = pass.UpdatedExternBufferMap();
       if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
         func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
       }
@@ -420,7 +421,6 @@ class BufferStrideLegalize : public StmtExprMutator {
         BufferEntry entry;
         entry.remap_to = with_strides;
         entry.in_scope = true;
-        entry.is_external = true;
         buf_map_[buf] = entry;
       }
       updated_extern_buffer_map_.Set(kv.first, with_strides);
@@ -443,51 +443,54 @@ class BufferStrideLegalize : public StmtExprMutator {
   Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
 
   Buffer WithStrides(Buffer buf) {
-    auto it = buf_map_.find(buf);
+    auto cache_key = buf;
+
+    auto it = buf_map_.find(cache_key);
     if (it != buf_map_.end()) {
       const BufferEntry& entry = it->second;
       ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
       return entry.remap_to;
     }
 
+    Array<PrimExpr> shape = buf->shape;
+
     if (buf->strides.size()) {
       ICHECK_EQ(buf->strides.size(), buf->shape.size())
           << "Buffer " << buf << " has inconsistent strides/shape.";
-      return buf;
-    }
-
-    // Keeping this to have matched behavior to previous version.
-    // There are many parts of the codebase that assume that a strided
-    // array cannot be compact.  For example, ArgBinder::BindBuffer
-    // and tir.Specialize.
-    if (dim_align_.count(buf) == 0) {
-      return buf;
-    }
-
-    // Can't define the strides for a buffer without a known shape.
-    Array<PrimExpr> shape = buf->shape;
-    if (shape.size() == 0) {
-      return buf;
-    }
-
-    std::vector<PrimExpr> rstrides;
-    const std::vector<DimAlignInfo>& avec = dim_align_[buf];
-    int first_dim = 0;
-    PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
-    for (size_t i = shape.size(); i != 0; --i) {
-      size_t dim = i - 1;
-      if (dim < avec.size() && avec[dim].align_factor != 0) {
-        PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
-        PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
-        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
-        stride = bound_analyzer_->Simplify(stride);
+    } else if (dim_align_.count(buf) == 0) {
+      // Keeping this to have matched behavior to previous version.
+      // There are many parts of the codebase that assume that a
+      // strided array cannot be compact.  For example,
+      // ArgBinder::BindBuffer and tir.Specialize.  To avoid breaking
+      // these, do not define the strides unless required for a
+      // non-compact array.
+    } else if (shape.size() == 0) {
+      // Can't define the strides for a buffer without a known shape.
+    } else {
+      // With everything checked, can now define the updated strides
+      std::vector<PrimExpr> rstrides;
+      const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+      int first_dim = 0;
+      PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+      for (size_t i = shape.size(); i != 0; --i) {
+        size_t dim = i - 1;
+        if (dim < avec.size() && avec[dim].align_factor != 0) {
+          PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+          PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+          stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+          stride = bound_analyzer_->Simplify(stride);
+        }
+        rstrides.push_back(stride);
+        stride = stride * shape[dim];
       }
-      rstrides.push_back(stride);
-      stride = stride * shape[dim];
+
+      buf.CopyOnWrite()->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
     }
 
-    auto ptr = buf.CopyOnWrite();
-    ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+    BufferEntry entry;
+    entry.remap_to = buf;
+    entry.in_scope = true;
+    buf_map_[cache_key] = entry;
 
     return buf;
   }
@@ -513,16 +516,10 @@ class BufferStrideLegalize : public StmtExprMutator {
       Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
       Buffer source_with_strides = WithStrides(source);
 
-      {
-        BufferEntry entry;
-        entry.remap_to = source_with_strides;
-        entry.in_scope = true;
-        entry.is_external = false;
-        buf_map_[source] = entry;
-      }
-
       Stmt body = this->VisitStmt(op->body);
 
+      buf_map_[source].in_scope = false;
+
       return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
                       op->value, body, op->span);
     } else {
@@ -560,13 +557,6 @@ class BufferStrideLegalize : public StmtExprMutator {
   Stmt VisitStmt_(const BufferRealizeNode* op) final {
     Buffer key = op->buffer;
     Buffer with_strides = WithStrides(op->buffer);
-    {
-      BufferEntry entry;
-      entry.remap_to = with_strides;
-      entry.in_scope = true;
-      entry.is_external = false;
-      buf_map_[key] = entry;
-    }
 
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
 
@@ -589,22 +579,14 @@ class BufferStrideLegalize : public StmtExprMutator {
 
   template <typename Node>
   Node VisitBufferAccess(Node node) {
-    auto alloc_key = node->buffer->data.get();
-    if (!buf_map_.count(node->buffer) && buffer_var_defines_.count(alloc_key)) {
-      BufferEntry entry;
-      entry.remap_to = WithStrides(node->buffer);
-      entry.in_scope = true;
-      entry.is_external = false;
-      buf_map_[node->buffer] = entry;
-    }
-
     auto it = buf_map_.find(node->buffer);
-    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << node->buffer;
-    const BufferEntry& e = it->second;
-    ICHECK(e.in_scope) << "Cannot access a buffer " << node->buffer->name << ", out of scope";
+    ICHECK(it == buf_map_.end() || it->second.in_scope)
+        << "Cannot access a buffer " << node->buffer->name << ", out of scope";
 
-    auto writer = node.CopyOnWrite();
-    writer->buffer = e.remap_to;
+    auto with_strides = WithStrides(node->buffer);
+    if (!with_strides.same_as(node->buffer)) {
+      node.CopyOnWrite()->buffer = with_strides;
+    }
 
     return node;
   }
@@ -623,7 +605,6 @@ class BufferStrideLegalize : public StmtExprMutator {
   struct BufferEntry {
     Buffer remap_to;
     bool in_scope;
-    bool is_external;
   };
 
   std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
@@ -846,6 +827,7 @@ class BufferBindUnwrapper : public StmtExprMutator {
       BufferEntry e;
       e.buffer = kv.second;
       e.external = true;
+      var_to_buffer_[kv.second->data.get()] = kv.second;
       buf_map_[kv.second.get()] = std::move(e);
     }
   }
@@ -1001,6 +983,7 @@ class BufferBindUnwrapper : public StmtExprMutator {
       BufferEntry e;
       e.bounds = op->bounds;
       e.buffer = op->buffer;
+      var_to_buffer_[op->buffer->data.get()] = op->buffer;
       buf_map_[key] = std::move(e);
     }
 
@@ -1089,6 +1072,7 @@ class BufferBindUnwrapper : public StmtExprMutator {
       source_info.buffer = source;
       source_info.remap = std::make_unique<RemapInfo>(remap);
 
+      var_to_buffer_[source->data.get()] = source;
       buf_map_[source.get()] = std::move(source_info);
     }
 
@@ -1160,18 +1144,70 @@ class BufferBindUnwrapper : public StmtExprMutator {
   };
 
   const BufferEntry& GetBufferEntry(Buffer buffer) {
-    auto alloc_key = buffer->data.get();
-    if (!buf_map_.count(buffer.get()) && buffer_var_defines_.count(alloc_key)) {
+    if (buf_map_.count(buffer.get())) {
+      const BufferEntry& e = buf_map_[buffer.get()];
+      ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope";
+      return e;
+    } else if (buffer_var_defines_.count(buffer->data.get())) {
+      // The buffer var was defined, but the buffer hasn't been seen
+      // before.
       BufferEntry entry;
       entry.buffer = buffer;
+      var_to_buffer_[buffer->data.get()] = buffer;
       buf_map_[buffer.get()] = std::move(entry);
-    }
+      return buf_map_[buffer.get()];
+    } else if (var_remap_.count(buffer->data.get())) {
+      // The buffer var is an alias of a bound buffer.  Only
+      // supported if the bound buffer has no offsets.  In this
+      // case, we just need to make a new aliasing buffer that
+      // shares the remapped data variable.
+      Var old_var = buffer->data;
+      Var new_var = Downcast<Var>(var_remap_[old_var.get()]);
 
-    auto it = buf_map_.find(buffer.get());
-    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer;
-    const BufferEntry& e = it->second;
-    ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope";
-    return it->second;
+      {
+        ICHECK(var_to_buffer_.count(old_var.get()))
+            << "Cannot find remap information for aliased buffer var " << old_var->name_hint
+            << ", required to verify this alias is legal.";
+        const Buffer& aliased_buffer = var_to_buffer_[old_var.get()];
+        const BufferEntry& entry = buf_map_[aliased_buffer.get()];
+        if (entry.remap) {
+          for (const auto& begin : entry.remap->begins) {
+            ICHECK(is_zero(begin)) << "Aliasing of buffer with offset is not supported";
+          }
+        }
+      }
+
+      {
+        Buffer new_buf = buffer;
+        new_buf.CopyOnWrite()->data = new_var;
+
+        RemapInfo remap_info;
+        remap_info.target = new_buf;
+        remap_info.begins = Array<PrimExpr>(buffer->shape.size(), 0);
+        remap_info.extents = buffer->shape;
+
+        BufferEntry entry;
+        entry.buffer = buffer;
+        entry.remap = std::make_unique<RemapInfo>(remap_info);
+        entry.in_scope = true;
+        var_to_buffer_[buffer->data.get()] = buffer;
+        buf_map_[buffer.get()] = std::move(entry);
+      }
+      return buf_map_[buffer.get()];
+    } else if (var_to_buffer_.count(buffer->data.get())) {
+      // This buffer is an alias of a known buffer, with no remaps.  A
+      // buffer entry should be generated and returned.
+      BufferEntry entry;
+      entry.buffer = buffer;
+      entry.in_scope = true;
+      var_to_buffer_[buffer->data.get()] = buffer;
+      buf_map_[buffer.get()] = std::move(entry);
+
+      return buf_map_[buffer.get()];
+    } else {
+      LOG(FATAL) << "Cannot find a definition or binding for buffer " << buffer->name;
+      return *static_cast<BufferEntry*>(nullptr);
+    }
   }
 
   // The buffer assignment map
@@ -1181,6 +1217,9 @@ class BufferBindUnwrapper : public StmtExprMutator {
   std::unordered_set<const VarNode*> illegal_vars_;
   // Buffer map
   std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
+  // Map from a buffer's data Var to a Buffer that uses it.  In case of
+  // aliased buffers, contains the first such buffer encountered.
+  std::unordered_map<const VarNode*, Buffer> var_to_buffer_;
   // Set of vars that have occurred in an AllocateNode, but haven't
   // yet occurred in a BufferLoad/BufferStore.
   std::unordered_set<const VarNode*> buffer_var_defines_;
@@ -1311,13 +1350,12 @@ class StorageFlattener : public StmtExprMutator {
       auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes,
                                    &bound_analyzer);
 
-      Map<Var, Buffer> preflattened_buffer_map =
-          Merge(func->buffer_map, func->preflattened_buffer_map);
-
       auto fptr = func.CopyOnWrite();
       fptr->body = pass(std::move(fptr->body));
-      fptr->preflattened_buffer_map = preflattened_buffer_map;
-      fptr->buffer_map = pass.UpdatedBufferMap();
+      // The buffers in func->buffer_map are deliberately left
+      // unflattened, as they are used for validation of user-provided
+      // arguments.  The flattened buffers used in the updated
+      // function body alias the argument buffers.
       return func;
     };
     return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {});
@@ -1345,15 +1383,12 @@ class StorageFlattener : public StmtExprMutator {
         }
       }
       e.external = true;
+      buffer_var_defines_.insert(kv.second->data.get());
       buf_map_[kv.second] = e;
-
-      updated_extern_buffer_map_.Set(kv.first, e.flattened_buffer);
     }
     cache_line_size_ = cache_line_size;
   }
 
-  Map<Var, Buffer> UpdatedBufferMap() { return updated_extern_buffer_map_; }
-
   Stmt VisitStmt_(const StoreNode* op) final {
     LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
     return Stmt();
@@ -1512,8 +1547,10 @@ class StorageFlattener : public StmtExprMutator {
         writer->dtype = DataType::Int(8);
       }
 
+      buffer_var_defines_.insert(op->buffer->data.get());
       buf_map_[key] = e;
       Stmt body = this->VisitStmt(op->body);
+      buffer_var_defines_.erase(op->buffer->data.get());
       buf_map_[key].in_scope = false;
 
       Stmt ret =
@@ -1777,8 +1814,6 @@ class StorageFlattener : public StmtExprMutator {
   std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
   // Buffer map
   std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
-  // The extern buffer map, updated to include flattened buffers.
-  Map<Var, Buffer> updated_extern_buffer_map_;
   // Collects shapes.
   std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
   // bounds populator. We really need the analyzer from it.
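
Note on the convention above: after this change StorageFlatten leaves the
entries of func->buffer_map with their logical, pre-flattened shapes, and the
function body instead refers to flat aliases that share the same backing
variable.  A hand-written TVMScript sketch of the resulting shape (the
function name and shapes below are invented for illustration, not output of
the pass):

    from tvm.script import tir as T

    @T.prim_func
    def after_flatten(A: T.Buffer[(4, 8), "float32"]) -> None:
        # The parameter keeps its logical (4, 8) shape in buffer_map; the body
        # only touches a flat alias sharing A's backing allocation.
        A_flat = T.buffer_decl([32], "float32", data=A.data)
        for i in T.serial(32):
            A_flat[i] = T.float32(0)

This is the same pattern the updated Ethos-U and unit tests below spell out by
hand with T.buffer_decl(..., data=param.data).
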
diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc
index 0671f1ea27..2bded7b487 100644
--- a/src/tir/usmp/transform/assign_pool_info.cc
+++ b/src/tir/usmp/transform/assign_pool_info.cc
@@ -166,8 +166,8 @@ IRModule PoolInfoAssigner::operator()() {
     if (kv.second->IsInstance<PrimFuncNode>()) {
       func_ = Downcast<PrimFunc>(kv.second);
       Stmt body = this->VisitStmt(func_->body);
-      PrimFunc new_prim_func = PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map,
-                                        func_->preflattened_buffer_map, func_->attrs);
+      PrimFunc new_prim_func =
+          PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs);
       mod_->Update(gv, new_prim_func);
     }
   }
diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
index 56aba654b5..439e264338 100644
--- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -242,8 +242,8 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams(
     if (emit_tvmscript_printable_) {
       original_attrs = DictAttrs();
     }
-    PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map,
-                            si.buffer_map, original_attrs);
+    PrimFunc ret =
+        PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs);
     if (!emit_tvmscript_printable_) {
       ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params);
     }
@@ -449,12 +449,12 @@ IRModule PoolAllocationToOffsetConverter::operator()() {
   // We don't need attrs of PrimFunc that might include non-printable attrs such as target
   // for unit tests where emit_tvmscript_printable_ is to be used.
   if (!emit_tvmscript_printable_) {
-    main_func = PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {},
-                         main_func->attrs);
+    main_func =
+        PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs);
     main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params);
   } else {
     main_func =
-        PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, {}, DictAttrs());
+        PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs());
   }
   module_->Update(gv, main_func);
   if (!emit_tvmscript_printable_) {
diff --git a/src/tir/usmp/transform/create_io_allocates.cc b/src/tir/usmp/transform/create_io_allocates.cc
index 59eee96163..cf75413177 100644
--- a/src/tir/usmp/transform/create_io_allocates.cc
+++ b/src/tir/usmp/transform/create_io_allocates.cc
@@ -195,9 +195,8 @@ IRModule IOAllocateCreator::operator()() {
     }
   }
   const GlobalVar& gv = mod_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main);
-  mod_->Update(gv,
-               PrimFunc(new_main_params, main_body, main_func_->ret_type, main_func_->buffer_map,
-                        main_func_->preflattened_buffer_map, main_func_->attrs, main_func_->span));
+  mod_->Update(gv, PrimFunc(new_main_params, main_body, main_func_->ret_type,
+                            main_func_->buffer_map, main_func_->attrs, main_func_->span));
   return mod_;
 }
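
The three USMP passes above only needed their PrimFunc construction calls
shortened; the Python constructor changes in the same way.  A minimal sketch,
assuming the post-change signature PrimFunc(params, body, ret_type, buffer_map,
attrs, span) with the usual defaults (the variable and buffer names are
illustrative only):

    from tvm import tir

    x = tir.Var("x", "handle")
    buf = tir.decl_buffer([16], "float32", name="buf")
    body = tir.Evaluate(tir.const(0, "int32"))
    # buffer_map is now followed directly by attrs; the former
    # preflattened_buffer_map argument between them is gone.
    func = tir.PrimFunc(params=[x], body=body, ret_type=None,
                        buffer_map={x: buf}, attrs=None)
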
 
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index c751d44b61..61128da71c 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -34,9 +34,11 @@ from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise
 @tvm.script.ir_module
 class WeightStreamOnlyU55:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data)
         buffer1 = T.buffer_decl([160], "uint8")
         buffer3 = T.buffer_decl([144], "uint8")
         buffer5 = T.buffer_decl([144], "uint8")
@@ -62,10 +64,12 @@ class WeightStreamOnlyU55:
 @tvm.script.ir_module
 class WeightStreamOnlyU65:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         # buffer definition
+        placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data)
         buffer_encoded_1 = T.buffer_decl([192], dtype="uint8")
         buffer_encoded_2_1 = T.buffer_decl([192], dtype="uint8")
         buffer_encoded_4_1 = T.buffer_decl([208], dtype="uint8")
@@ -148,10 +152,12 @@ def test_weight_stream_only(accelerator, reference_mod, reference_const_sizes):
 @tvm.script.ir_module
 class RereadWeightsU55:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer1 = T.buffer_decl([384], "uint8")
+        placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data)
         # body
         p1_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True})
         p1 = T.buffer_decl([384], "uint8", data=p1_data)
@@ -167,10 +173,12 @@ class RereadWeightsU55:
 @tvm.script.ir_module
 class RereadWeightsU65:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         # buffer definition
+        placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data)
         placeholder_encoded_1 = T.buffer_decl([464], "uint8")
         # body
         p1_data = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True})
@@ -246,13 +254,15 @@ def test_re_read_weights(accelerator, reference_mod, reference_const_sizes):
 @tvm.script.ir_module
 class DirectReadOnlyU55:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([592], "uint8")
         buffer_1 = T.buffer_decl([160], "uint8")
         buffer_2 = T.buffer_decl([160], "uint8")
         buffer_3 = T.buffer_decl([80], "uint8")
+        placeholder = T.buffer_decl([8192], "int8", data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data)
         # body
         ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
         ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data)
@@ -264,7 +274,7 @@ class DirectReadOnlyU55:
 @tvm.script.ir_module
 class DirectReadOnlyU65:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         # buffer definition
@@ -272,6 +282,8 @@ class DirectReadOnlyU65:
         placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8")
         placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8")
         placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8")
+        placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data)
         # body
         ethosu_write_2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
         ethosu_write_2 = T.buffer_decl([4096], "int8", data=ethosu_write_2_data)
@@ -340,7 +352,7 @@ def test_direct_read_only(accelerator, reference_mod, reference_const_sizes):
 @tvm.script.ir_module
 class MixedReadU55:
     @T.prim_func
-    def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer1 = T.buffer_decl([112], "uint8")
@@ -349,6 +361,8 @@ class MixedReadU55:
         buffer7 = T.buffer_decl([112], "uint8")
         buffer9 = T.buffer_decl([592], "uint8")
         buffer10 = T.buffer_decl([160], "uint8")
+        ifm = T.buffer_decl([8192], "int8", data=input_ifm.data)
+        ethosu_write = T.buffer_decl([2048], "int8", data=input_ethosu_write.data)
         # body
         p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
         p1 = T.buffer_decl([112], "uint8", data=p1_data)
@@ -371,11 +385,12 @@ class MixedReadU55:
 @tvm.script.ir_module
 class MixedReadU65:
     @T.prim_func
-    def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_ifm: T.Buffer[(1,16,16,32), "int8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-
         # buffer definition
+        ifm = T.buffer_decl([8192], dtype="int8", data=input_ifm.data)
+        ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data)
         buffer1 = T.buffer_decl([128], dtype="uint8")
         buffer2 = T.buffer_decl([128], dtype="uint8")
         buffer3 = T.buffer_decl([128], dtype="uint8")
diff --git a/tests/python/contrib/test_ethosu/test_hoist_allocates.py b/tests/python/contrib/test_ethosu/test_hoist_allocates.py
index 6c6d51fa06..1508aa441c 100644
--- a/tests/python/contrib/test_ethosu/test_hoist_allocates.py
+++ b/tests/python/contrib/test_ethosu/test_hoist_allocates.py
@@ -106,15 +106,15 @@ def test_double_convolution():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def main(placeholder: T.Buffer[(3402,), "int8"], placeholder_encoded: T.Buffer[(128,), "uint8"], placeholder_encoded_1: T.Buffer[(32,), "uint8"], placeholder_encoded_2: T.Buffer[(128,), "uint8"], placeholder_encoded_3: T.Buffer[(32,), "uint8"], ethosu_write: T.Buffer[(3402,), "int8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1, 27, 42, 3), "int8"], input_placeholder_encoded: T.Buffer[(3, 3, 2, 3), "uint8"], input_placeholder_encoded_1: T.Buffer[(3, 10), "uint8"], input_placeholder_encoded_2: T.Buffer[(3, 3, 2, 3), "uint8"], input_placeholder_encoded_3: T.Buffer[(3, 10), "uint8"], input_ethosu_write: T.Buffer[(1, 27, 42, 3), "int8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-            T.preflattened_buffer(placeholder, [1, 27, 42, 3], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(placeholder_encoded, [3, 3, 2, 3], dtype="int8")
-            T.preflattened_buffer(placeholder_encoded_1, [3, 10], dtype="uint8")
-            T.preflattened_buffer(placeholder_encoded_2, [3, 3, 2, 3], dtype="int8")
-            T.preflattened_buffer(placeholder_encoded_3, [3, 10], dtype="uint8")
-            T.preflattened_buffer(ethosu_write, [1, 27, 42, 3], dtype="int8", data=ethosu_write.data)
+            placeholder = T.buffer_decl([3402], dtype="int8", data=input_placeholder.data)
+            placeholder_encoded = T.buffer_decl([128], dtype="int8", data=input_placeholder_encoded.data)
+            placeholder_encoded_1 = T.buffer_decl([32], dtype="uint8", data=input_placeholder_encoded_1.data)
+            placeholder_encoded_2 = T.buffer_decl([128], dtype="int8", data=input_placeholder_encoded_2.data)
+            placeholder_encoded_3 = T.buffer_decl([32], dtype="uint8", data=input_placeholder_encoded_3.data)
+            ethosu_write = T.buffer_decl([3402], dtype="int8", data=input_ethosu_write.data)
             # body
             placeholder_global_data = T.allocate([128], "uint8", "global")
             placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data)
@@ -150,11 +150,10 @@ def test_identities():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def main(placeholder: T.Buffer[(24,), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1, 2, 3, 4), "int8"], T_concat: T.Buffer[(24,), "int8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-            T.preflattened_buffer(placeholder, [1, 2, 3, 4], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(T_concat, [24], dtype="int8", data=T_concat.data)
+            placeholder = T.buffer_decl([24], dtype="int8", data=input_placeholder.data)
             # body
             ethosu_write_data = T.allocate([12], "int8", "global")
             ethosu_write = T.buffer_decl([12], "int8", data=ethosu_write_data)
@@ -188,11 +187,11 @@ def test_outer_seq_stmt():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data)
             # body
             with T.allocate([128], "uint8", "global") as placeholder_global_data:
                 placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data)
@@ -238,11 +237,11 @@ def test_allocate_without_seq_stmt():
     @tvm.script.ir_module
     class Module:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"], buffer_encoded_1: T.Buffer[(32,), "uint8"], buffer_encoded_2: T.Buffer[(112,), "uint8"], buffer_encoded_3: T.Buffer[(32,), "uint8"], buffer_encoded_4: T.Buffer[(112,), "uint8"], buffer_encoded_5: T.Buffer[(32,), "uint8"], buffer_encoded_6: T.Buffer[(112,), "uint8"], buffer_encoded_7: T.Buffer[(32,), "uint8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            placeholder = T.buffer_decl([8192], dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write.data)
             # body
             placeholder_global_data = T.allocate([128], "uint8", "global")
             placeholder_global = T.buffer_decl([128], "uint8", data=placeholder_global_data)
diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py
index a5adcfceac..ed1927b849 100644
--- a/tests/python/contrib/test_ethosu/test_merge_constants.py
+++ b/tests/python/contrib/test_ethosu/test_merge_constants.py
@@ -399,12 +399,12 @@ def test_read_from_the_same_buffer():
     @tvm.script.ir_module
     class InputModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1, 16, 16, 32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
             # buffer definition
-            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([368], "uint8", "global")
             p1 = T.buffer_decl([368], "uint8", data=p1_data)
@@ -419,9 +419,12 @@ def test_read_from_the_same_buffer():
     @tvm.script.ir_module
     class ReferenceModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([464], "uint8", "global")
             p1 = T.buffer_decl([464], "uint8", data=p1_data)
@@ -446,12 +449,12 @@ def test_arbitrary_argument_order():
     @tvm.script.ir_module
     class InputModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
             # buffer definition
-            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([368], "uint8", "global")
             p1 = T.buffer_decl([368], "uint8", data=p1_data)
@@ -473,9 +476,12 @@ def test_arbitrary_argument_order():
     @tvm.script.ir_module
     class ReferenceModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([464], "uint8", "global")
             p1 = T.buffer_decl([464], "uint8", data=p1_data)
@@ -509,12 +515,12 @@ def test_arbitrary_argument_order_const_split():
     @tvm.script.ir_module
     class InputModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
             # buffer definition
-            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([368], "uint8", "global")
             p1 = T.buffer_decl([368], "uint8", data=p1_data)
@@ -536,9 +542,12 @@ def test_arbitrary_argument_order_const_split():
     @tvm.script.ir_module
     class ReferenceModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(1,16,16,8), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            placeholder = T.buffer_decl(8192, dtype="int8", data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(2048, dtype="int8", data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([464], "uint8", "global")
             p1 = T.buffer_decl([464], "uint8", data=p1_data)
@@ -572,12 +581,12 @@ def test_arbitrary_argument_order_const_split_mixed():
     @tvm.script.ir_module
     class InputModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], input_ethosu_write: T.Buffer[(2,16,16,8), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
             # buffer definition
-            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            placeholder = T.buffer_decl(8192, dtype='int8', data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(4096, dtype='int8', data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([368], "uint8", "global")
             p1 = T.buffer_decl([368], "uint8", data=p1_data)
@@ -599,9 +608,12 @@ def test_arbitrary_argument_order_const_split_mixed():
     @tvm.script.ir_module
     class ReferenceModule:
         @T.prim_func
-        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"]) -> None:
+        def main(input_placeholder: T.Buffer[(1,16,16,32), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], input_ethosu_write: T.Buffer[(2,16,16,8), "int8"]) -> None:
             # function attr dict
             T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            placeholder = T.buffer_decl(8192, dtype='int8', data=input_placeholder.data)
+            ethosu_write = T.buffer_decl(4096, dtype='int8', data=input_ethosu_write.data)
             # body
             p1_data = T.allocate([464], "uint8", "global")
             p1 = T.buffer_decl([464], "uint8", data=p1_data)
diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py
index e6414c24d4..379a35b1b4 100644
--- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py
+++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py
@@ -30,9 +30,14 @@ from .infra import make_ethosu_conv2d
 @tvm.script.ir_module
 class ReferenceModule:
     @T.prim_func
-    def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,), "int8"], T_concat: T.Buffer[(4096,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1,8,12,16), "int8"], input_placeholder_1: T.Buffer[(1,8,10,16), "int8"], input_T_concat: T.Buffer[(1,8,32,16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+
+        placeholder = T.buffer_decl(1536, dtype="int8", data=input_placeholder.data)
+        placeholder_1 = T.buffer_decl(1280, dtype="int8", data=input_placeholder_1.data)
+        T_concat = T.buffer_decl(4096, dtype="int8", data=input_T_concat.data)
+
         buffer = T.buffer_decl([2992], "uint8")
         buffer_1 = T.buffer_decl([160], "uint8")
         buffer_2 = T.buffer_decl([2992], "uint8")
diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
index ae46057369..46c6976567 100644
--- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
@@ -366,13 +366,15 @@ def test_conv2d_single(trial):
 @tvm.script.ir_module
 class Conv2dDoubleCascade1:
     @T.prim_func
-    def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None:
+    def main(input_placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([304], "uint8")
         buffer_1 = T.buffer_decl([80], "uint8")
         buffer_2 = T.buffer_decl([320], "uint8")
         buffer_3 = T.buffer_decl([160], "uint8")
+        placeholder_5 = T.buffer_decl([192], 'int8', data=input_placeholder_5.data)
+        ethosu_write_1 = T.buffer_decl([512], 'int8', data=input_ethosu_write_1.data)
         # body
         ethosu_write_2_data = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True})
         ethosu_write_2 = T.buffer_decl([1024], "int8", data=ethosu_write_2_data)
@@ -386,13 +388,15 @@ class Conv2dDoubleCascade1:
 @tvm.script.ir_module
 class Conv2dDoubleCascade2:
     @T.prim_func
-    def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512,), "int8"]) -> None:
+    def main(input_placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([80], "uint8")
         buffer_1 = T.buffer_decl([320], "uint8")
         buffer_2 = T.buffer_decl([1312], "uint8")
         buffer_3 = T.buffer_decl([2608], "uint8")
+        placeholder_5 = T.buffer_decl([192], 'int8', data=input_placeholder_5.data)
+        ethosu_write_1 = T.buffer_decl([512], 'int8', data=input_ethosu_write_1.data)
         # body
         ethosu_write_2_data = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True})
         ethosu_write_2 = T.buffer_decl([1536], "int8", data=ethosu_write_2_data)
@@ -406,13 +410,16 @@ class Conv2dDoubleCascade2:
 @tvm.script.ir_module
 class Conv2dDoubleCascade3:
     @T.prim_func
-    def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640,), "int8"]) -> None:
+    def main(input_placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], input_ethosu_write_1: T.Buffer[(1, 20, 4, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([1744], "uint8")
         buffer_1 = T.buffer_decl([80], "uint8")
         buffer_2 = T.buffer_decl([320], "uint8")
         buffer_3 = T.buffer_decl([880], "uint8")
+        placeholder_5 = T.buffer_decl([768], 'int8', data=input_placeholder_5.data)
+        ethosu_write_1 = T.buffer_decl([640], 'int8', data=input_ethosu_write_1.data)
+
         # body
         ethosu_write_2_data = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True})
         ethosu_write_2 = T.buffer_decl([2560], "int8", data=ethosu_write_2_data)
@@ -428,13 +435,15 @@ class Conv2dDoubleCascade3:
 @tvm.script.ir_module
 class Conv2dDoubleCascade4:
     @T.prim_func
-    def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 2, 8, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([1456], "uint8")
         buffer_1 = T.buffer_decl([352], "uint8")
         buffer_2 = T.buffer_decl([272], "uint8")
         buffer_3 = T.buffer_decl([11040], "uint8")
+        placeholder_5 = T.buffer_decl([1024], 'int8', data=input_placeholder_5.data)
+        ethosu_write_1 = T.buffer_decl([2048], 'int8', data=input_ethosu_write_1.data)
         # body
         ethosu_write_2_data = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True})
         ethosu_write_2 = T.buffer_decl((2304,), "int8", data=ethosu_write_2_data)
@@ -448,13 +457,15 @@ class Conv2dDoubleCascade4:
 @tvm.script.ir_module
 class Conv2dDoubleCascade5:
     @T.prim_func
-    def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 8, 8, 3), "int8"], input_ethosu_write: T.Buffer[(1, 32, 32, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([320], "uint8")
         buffer_2 = T.buffer_decl([304], "uint8")
         buffer_3 = T.buffer_decl([80], "uint8")
+        placeholder = T.buffer_decl([192], 'int8', data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([8192], 'int8', data=input_ethosu_write.data)
         # body
         ethosu_write_1_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
         ethosu_write_1 = T.buffer_decl([4096], "int8", data=ethosu_write_1_data)
@@ -468,13 +479,15 @@ class Conv2dDoubleCascade5:
 @tvm.script.ir_module
 class Conv2dDoubleCascade6:
     @T.prim_func
-    def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 8, 1, 8, 16), "int8"], input_ethosu_write: T.Buffer[(1, 32, 2, 32, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([1456], "uint8")
         buffer_1 = T.buffer_decl([352], "uint8")
         buffer_2 = T.buffer_decl([11040], "uint8")
         buffer_3 = T.buffer_decl([272], "uint8")
+        placeholder = T.buffer_decl([1024], 'int8', data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([32768], 'int8', data=input_ethosu_write.data)
         # body
         ethosu_write_1_data = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True})
         ethosu_write_1 = T.buffer_decl([12288], "int8", data=ethosu_write_1_data)
@@ -630,11 +643,13 @@ def test_conv2d_double_cascade(trial):
 @tvm.script.ir_module
 class Conv2dInlineCopy1:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024,), "int8"]) -> None:
+    def main(input_placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([848], "uint8")
         buffer_1 = T.buffer_decl([160], "uint8")
+        placeholder_3 = T.buffer_decl([960], 'int8', data=input_placeholder_3.data)
+        ethosu_write_1 = T.buffer_decl([1024], 'int8', data=input_ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
@@ -643,11 +658,13 @@ class Conv2dInlineCopy1:
 @tvm.script.ir_module
 class Conv2dInlineCopy2:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240,), "int8"]) -> None:
+    def main(input_placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], input_ethosu_write_1: T.Buffer[(1, 3, 5, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([656], "uint8")
+        placeholder_3 = T.buffer_decl([315], 'int8', data=input_placeholder_3.data)
+        ethosu_write_1 = T.buffer_decl([240], 'int8', data=input_ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
@@ -685,11 +702,13 @@ def test_conv2d_inline_copy(trial):
 @tvm.script.ir_module
 class Conv2dInlineReshape1:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None:
+    def main(input_placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
+        placeholder_3 = T.buffer_decl([192], 'int8', data=input_placeholder_3.data)
+        ethosu_write_1 = T.buffer_decl([768], 'int8', data=input_ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -699,11 +718,13 @@ class Conv2dInlineReshape1:
 @tvm.script.ir_module
 class Conv2dInlineReshape2:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None:
+    def main(input_placeholder_3: T.Buffer[(1, 24, 8), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
+        placeholder_3 = T.buffer_decl([192], 'int8', data=input_placeholder_3.data)
+        ethosu_write_1 = T.buffer_decl([768], 'int8', data=input_ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -713,11 +734,13 @@ class Conv2dInlineReshape2:
 @tvm.script.ir_module
 class Conv2dInlineReshape3:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None:
+    def main(input_placeholder_3: T.Buffer[(192, 1), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
+        placeholder_3 = T.buffer_decl([192], 'int8', data=input_placeholder_3.data)
+        ethosu_write_1 = T.buffer_decl([768], 'int8', data=input_ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -727,11 +750,12 @@ class Conv2dInlineReshape3:
 @tvm.script.ir_module
 class Conv2dInlineReshape4:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768,), "int8"]) -> None:
+    def main(placeholder_3: T.Buffer[(192,), "int8"], input_ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
+        ethosu_write_1 = T.buffer_decl([768], 'int8', data=input_ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py
index 8c7ff35272..7da3d7e5be 100644
--- a/tests/python/contrib/test_ethosu/test_replace_copy.py
+++ b/tests/python/contrib/test_ethosu/test_replace_copy.py
@@ -31,10 +31,12 @@ from .infra import make_ethosu_conv2d
 @tvm.script.ir_module
 class ReferenceModule:
     @T.prim_func
-    def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None:
+    def main(input_placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer_1 = T.buffer_decl([384], "uint8")
+        placeholder_3 = T.buffer_decl([8192], dtype="int8", data=input_placeholder_3.data)
+        ethosu_write_1 = T.buffer_decl([2048], dtype="int8", data=input_ethosu_write_1.data)
         # body
         placeholder_global_data = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True})
         placeholder_global = T.buffer_decl([384], "uint8", data=placeholder_global_data)
@@ -73,11 +75,13 @@ def test_copy():
 @tvm.script.ir_module
 class WeightStream:
     @T.prim_func
-    def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None:
+    def main(input_placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], input_ethosu_write_1: T.Buffer[(1, 16, 16, 16), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([528], "uint8")
         buffer_2 = T.buffer_decl([336], "uint8")
+        placeholder_5 = T.buffer_decl([8192], dtype="int8", data=input_placeholder_5.data)
+        ethosu_write_1 = T.buffer_decl([4096], dtype="int8", data=input_ethosu_write_1.data)
         # body
         placeholder_d_global_data = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True})
         placeholder_d_global = T.buffer_decl([528], "uint8", data=placeholder_d_global_data)
diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py
index 254abab644..fd1e1afa60 100644
--- a/tests/python/contrib/test_ethosu/test_scheduler.py
+++ b/tests/python/contrib/test_ethosu/test_scheduler.py
@@ -180,8 +180,10 @@ def test_schedule_cache_reads():
 @tvm.script.ir_module
 class DiamondGraphTir:
     @T.prim_func
-    def main(placeholder: T.Buffer[(301056,), "int8"], ethosu_write: T.Buffer[(75264,), "int8"]) -> None:
+    def main(input_placeholder: T.Buffer[(1, 56, 56, 96), "int8"], input_ethosu_write: T.Buffer[(1, 56, 56, 24), "int8"]) -> None:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        placeholder = T.buffer_decl([301056], dtype='int8', data=input_placeholder.data)
+        ethosu_write = T.buffer_decl([75264], dtype='int8', data=input_ethosu_write.data)
         buffer1 = T.buffer_decl([2848], "uint8")
         buffer3 = T.buffer_decl([976], "uint8")
         p1_data = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True})
diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
index fb41e99a9b..4aa12aedf2 100755
--- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
+++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
@@ -304,7 +304,7 @@ class TestElementWise:
     def test_param_shapes(self, ir_module, transformed_input_shape, transformed_output_shape):
         func = ir_module["main"]
         primfunc_input_shape, primfunc_output_shape = [
-            list(func.preflattened_buffer_map[param].shape) for param in func.params
+            list(func.buffer_map[param].shape) for param in func.params
         ]
         assert primfunc_input_shape == transformed_input_shape
         assert primfunc_output_shape == transformed_output_shape
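
Code that previously consulted preflattened_buffer_map for the logical
argument shapes now reads them straight from buffer_map, as the assertions
above rely on.  A small sketch (the helper name is invented for illustration):

    # "func" is any lowered PrimFunc.
    def argument_shapes(func):
        # buffer_map already holds the pre-flattened (logical) shapes of the
        # arguments, so no separate map is needed.
        return [list(func.buffer_map[param].shape) for param in func.params]
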
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
index cd0114d464..106e0f52ad 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -26,15 +26,12 @@ from tvm.script import tir as T
 class Module:
     @T.prim_func
     def tvm_test_cpacked(
-        A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
+        A: T.Buffer[(1,), "float32"],
+        B: T.Buffer[(1,), "float32"],
+        C: T.Buffer[(1,), "float32"],
+        device_context: T.Buffer[(1,), "float32"],
     ) -> T.handle:
-        A_0 = T.match_buffer(A, (1,), dtype="float32")
-        T.preflattened_buffer(A_0, (1,), dtype="float32")
-        B_0 = T.match_buffer(B, (1,), dtype="float32")
-        T.preflattened_buffer(B_0, (1,), dtype="float32")
-        C_0 = T.match_buffer(C, (1,), dtype="float32")
-        T.preflattened_buffer(C_0, (1,), dtype="float32")
-        T.evaluate(C)
+        T.evaluate(C.data)
 
     @T.prim_func
     def tir_packed_call() -> None:
@@ -59,15 +56,12 @@ class Module:
 class Expected:
     @T.prim_func
     def tvm_test_cpacked(
-        A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
+        A: T.Buffer[(1,), "float32"],
+        B: T.Buffer[(1,), "float32"],
+        C: T.Buffer[(1,), "float32"],
+        device_context: T.Buffer[(1,), "float32"],
     ) -> T.handle:
-        A_0 = T.match_buffer(A, (1,), dtype="float32")
-        T.preflattened_buffer(A_0, (1,), dtype="float32")
-        B_0 = T.match_buffer(B, (1,), dtype="float32")
-        T.preflattened_buffer(B_0, (1,), dtype="float32")
-        C_0 = T.match_buffer(C, (1,), dtype="float32")
-        T.preflattened_buffer(C_0, (1,), dtype="float32")
-        T.evaluate(C)
+        T.evaluate(C.data)
 
     @T.prim_func
     def tir_packed_call() -> None:
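
Since the arguments no longer need a separate T.preflattened_buffer
declaration, these test functions can bind buffers directly in the parameter
list instead of via T.match_buffer.  The two spellings below are equivalent
ways of binding a buffer argument (function names invented for illustration):

    from tvm.script import tir as T

    @T.prim_func
    def with_handle(a: T.handle) -> None:
        # Explicit binding of a handle parameter to a buffer.
        A = T.match_buffer(a, (1,), dtype="float32")
        T.evaluate(A.data)

    @T.prim_func
    def with_buffer(A: T.Buffer[(1,), "float32"]) -> None:
        # Sugar for the same binding, written as a parameter annotation.
        T.evaluate(A.data)
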
diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py
index 3641f06ab8..9f7eee0963 100644
--- a/tests/python/unittest/test_arith_domain_touched.py
+++ b/tests/python/unittest/test_arith_domain_touched.py
@@ -30,18 +30,6 @@ def scalar_func(a: T.handle, b: T.handle):
         A[i, j] = B[i - 1, j + 1] + A[i - 1, j - 1]
 
 
-@T.prim_func
-def vector_func(a: T.handle, b: T.handle):
-    n = T.var("int32")
-    m = 128
-    A = T.match_buffer(a, (n, m))
-    B = T.match_buffer(b, (n, m))
-
-    for i in T.serial(n):
-        for j in T.vectorized(m):
-            A[i, j] = A[i, j] + B[i, j]
-
-
 def test_domain_touched():
     func = scalar_func
     a, b = [func.buffer_map[var] for var in func.params]
@@ -81,7 +69,17 @@ def test_domain_touched():
 
 
 def test_domain_touched_vector():
-    func = tvm.lower(vector_func)["main"]
+    m = tvm.runtime.convert(128)
+
+    @T.prim_func
+    def func(a: T.handle, b: T.handle):
+        n = T.var("int32")
+        A = T.match_buffer(a, (n * m,))
+        B = T.match_buffer(b, (n * m,))
+
+        for i in T.serial(n):
+            A[i * m : (i + 1) * m : 1] = A[i * m : (i + 1) * m : 1] + B[i * m : (i + 1) * m : 1]
+
     a, b = [func.buffer_map[var] for var in func.params]
 
     assert tvm.arith._ffi_api.DomainTouched(func.body, a, True, False)[0].extent.value == 128
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py
index 4140e7732d..3f435366e1 100644
--- a/tests/python/unittest/test_auto_scheduler_feature.py
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -203,20 +203,20 @@ def test_gpu_feature():
 
 @T.prim_func
 def tir_matmul(
-    A: T.Buffer[(16384,), "float32"],
-    B: T.Buffer[(16384,), "float32"],
-    C: T.Buffer[(16384,), "float32"],
+    A: T.Buffer[(256, 256), "float32"],
+    B: T.Buffer[(256, 256), "float32"],
+    C: T.Buffer[(256, 256), "float32"],
 ) -> None:
     # function attr dict
     T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-    T.preflattened_buffer(A, [128, 128], dtype="float32", data=A.data)
-    T.preflattened_buffer(B, [128, 128], dtype="float32", data=B.data)
-    T.preflattened_buffer(C, [128, 128], dtype="float32", data=C.data)
+    A_flat = T.buffer_decl([16384], dtype="float32", data=A.data)
+    B_flat = T.buffer_decl([16384], dtype="float32", data=B.data)
+    C_flat = T.buffer_decl([16384], dtype="float32", data=C.data)
     # body
     for x, y in T.grid(128, 128):
-        C[x * 128 + y] = T.float32(0)
+        C_flat[x * 128 + y] = T.float32(0)
         for k in T.serial(128):
-            C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k]
+            C_flat[x * 128 + y] = C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k]
 
 
 def test_primfunc_without_lowering():
diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py
index bd820b617c..665697b84b 100644
--- a/tests/python/unittest/test_lower_build.py
+++ b/tests/python/unittest/test_lower_build.py
@@ -54,40 +54,44 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
 class LoweredModule:
     @T.prim_func
     def main(
-        A: T.Buffer[(16384,), "float32"],
-        B: T.Buffer[(16384,), "float32"],
-        C: T.Buffer[(16384,), "float32"],
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
     ) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True})
-        T.preflattened_buffer(A, [128, 128], data=A.data)
-        T.preflattened_buffer(B, [128, 128], data=B.data)
-        T.preflattened_buffer(C, [128, 128], data=C.data)
+        A_flat = T.buffer_decl([16384], data=A.data)
+        B_flat = T.buffer_decl([16384], data=B.data)
+        C_flat = T.buffer_decl([16384], data=C.data)
         # body
         for x, y in T.grid(128, 128):
-            C[x * 128 + y] = 0.0
+            C_flat[x * 128 + y] = 0.0
             for k in T.serial(0, 128):
-                C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k]
+                C_flat[x * 128 + y] = (
+                    C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k]
+                )
 
 
 @tvm.script.ir_module
 class LoweredTIRModule:
     @T.prim_func
     def main(
-        A: T.Buffer[(16384,), "float32"],
-        B: T.Buffer[(16384,), "float32"],
-        C: T.Buffer[(16384,), "float32"],
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
     ) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        T.preflattened_buffer(A, [128, 128], data=A.data)
-        T.preflattened_buffer(B, [128, 128], data=B.data)
-        T.preflattened_buffer(C, [128, 128], data=C.data)
+        A_flat = T.buffer_decl([16384], data=A.data)
+        B_flat = T.buffer_decl([16384], data=B.data)
+        C_flat = T.buffer_decl([16384], data=C.data)
         # body
         for x, y in T.grid(128, 128):
-            C[x * 128 + y] = 0.0
+            C_flat[x * 128 + y] = 0.0
             for k in T.serial(0, 128):
-                C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k]
+                C_flat[x * 128 + y] = (
+                    C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k]
+                )
 
 
 def test_lower_build_te_schedule():
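
The lowered modules above, like the other updated tests, follow the same aliasing pattern: the parameter keeps its logical multi-dimensional shape in the buffer_map, while the body indexes a flat alias declared over the same data pointer. A minimal, hedged sketch of that pattern (illustrative names, not taken from the commit):

    from tvm.script import tir as T

    @T.prim_func
    def scale(A: T.Buffer[(4, 8), "float32"], B: T.Buffer[(4, 8), "float32"]) -> None:
        # Flat aliases sharing the backing allocation of the 2-d parameters.
        A_flat = T.buffer_decl([32], "float32", data=A.data)
        B_flat = T.buffer_decl([32], "float32", data=B.data)
        for i in T.serial(32):
            B_flat[i] = A_flat[i] * 2.0
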
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index 870208499e..513e04dc20 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -40,9 +40,9 @@ class TestElementwise(BaseCompare):
             for j in T.serial(0, 16):
                 C[i, j] = B_new[0, j] * 2.0
 
-    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
-        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+    def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]):
+        A = T.buffer_decl(256, dtype="float32", data=input_A.data)
+        C = T.buffer_decl(256, dtype="float32", data=input_C.data)
         for i in T.serial(0, 16):
             B_new_data = T.allocate([16], "float32", scope="global")
             B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data)
@@ -71,9 +71,9 @@ class TestElementwiseWithoutDeclBuffer(BaseCompare):
             for j in T.serial(0, 16):
                 C[i, j] = B_new[0, j] * 2.0
 
-    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
-        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+    def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]):
+        A = T.buffer_decl(256, dtype="float32", data=input_A.data)
+        C = T.buffer_decl(256, dtype="float32", data=input_C.data)
         for i in T.serial(0, 16):
             B_new_data = T.allocate([16], "float32", "global")
             B_new = T.buffer_decl(16, "float32", data=B_new_data)
@@ -100,9 +100,9 @@ class TestGPU(BaseCompare):
         for j in range(0, 16):
             C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
 
-    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
-        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+    def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]):
+        A = T.buffer_decl(256, dtype="float32", data=input_A.data)
+        C = T.buffer_decl(256, dtype="float32", data=input_C.data)
 
         i0 = T.env_thread("blockIdx.x")
         i1 = T.env_thread("threadIdx.x")
@@ -134,10 +134,10 @@ class TestSymbolic(BaseCompare):
                 C[i, j] = B[j] * 2.0
 
     def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
-        A = T.match_buffer(a, n * m, "float32")
-        C = T.match_buffer(c, n * m, "float32")
-        T.preflattened_buffer(A, (n, m), "float32", data=A.data)
-        T.preflattened_buffer(C, (n, m), "float32", data=C.data)
+        input_A = T.match_buffer(a, (n, m), "float32")
+        input_C = T.match_buffer(c, (n, m), "float32")
+        A = T.buffer_decl(n * m, "float32", data=input_A.data)
+        C = T.buffer_decl(n * m, "float32", data=input_C.data)
 
         for i in range(0, n):
             B_data = T.allocate([m], "float32", scope="global")
@@ -159,9 +159,9 @@ class TestMultiAlloc(BaseCompare):
             C[i, j] = A[i, j] + B[i, j]
             D[i, j] = C[i, j] * 2.0
 
-    def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]):
-        T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
-        T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
+    def expected(input_A: T.Buffer[(4, 32), "float32"], input_D: T.Buffer[(4, 32), "float32"]):
+        A = T.buffer_decl(128, "float32", data=input_A.data)
+        D = T.buffer_decl(128, "float32", data=input_D.data)
 
         for i, j in T.grid(4, 32):
             B_data = T.allocate([128], "float32", scope="global")
@@ -185,9 +185,9 @@ class TestStrided(BaseCompare):
             for i1, j in T.grid(4, 16):
                 C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
 
-    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
-        T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
-        T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
+    def expected(input_A: T.Buffer[(16, 16), "float32"], input_C: T.Buffer[(16, 16), "float32"]):
+        A = T.buffer_decl(256, dtype="float32", data=input_A.data)
+        C = T.buffer_decl(256, dtype="float32", data=input_C.data)
         for i0 in T.serial(0, 4):
             B_new_data = T.allocate([68], "float32", scope="global")
             B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data)
@@ -206,9 +206,9 @@ class TestBoolean(BaseCompare):
         for i0 in T.serial(10):
             B[i0] = A[i0]
 
-    def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None:
-        T.preflattened_buffer(A, [10], dtype="bool", data=A.data)
-        T.preflattened_buffer(B, [10], dtype="bool", data=B.data)
+    def expected(input_A: T.Buffer[10, "bool"], input_B: T.Buffer[10, "bool"]) -> None:
+        A = T.buffer_decl(10, dtype="int8", data=input_A.data)
+        B = T.buffer_decl(10, dtype="int8", data=input_B.data)
         # body
         for i0 in T.serial(10):
             B[i0] = T.cast(T.cast(A[i0], "bool"), "int8")
diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py
index 5612815529..fe48aa7d8f 100644
--- a/tests/python/unittest/test_tir_transform_loop_partition.py
+++ b/tests/python/unittest/test_tir_transform_loop_partition.py
@@ -544,9 +544,6 @@ def partitioned_concat(
     A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"], C: T.Buffer[(32,), "float32"]
 ) -> None:
     T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-    T.preflattened_buffer(A, [16], data=A.data)
-    T.preflattened_buffer(B, [16], data=B.data)
-    T.preflattened_buffer(C, [32], data=C.data)
     for i in T.serial(0, 16):
         C[i] = A[i]
     for i in T.serial(0, 16):
@@ -581,42 +578,46 @@ def partition_from_scheduled_tir(prim_func, pass_cfg):
 
 @T.prim_func
 def partitioned_concat_3(
-    placeholder: T.Buffer[(50176,), "int8"],
-    placeholder_1: T.Buffer[(25088,), "int8"],
-    placeholder_2: T.Buffer[(25088,), "int8"],
-    T_concat: T.Buffer[(100352,), "int8"],
+    placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+    placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+    placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+    T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
 ) -> None:
-    T.preflattened_buffer(placeholder, [1, 64, 28, 28], "int8", data=placeholder.data)
-    T.preflattened_buffer(placeholder_1, [1, 32, 28, 28], "int8", data=placeholder_1.data)
-    T.preflattened_buffer(placeholder_2, [1, 32, 28, 28], "int8", data=placeholder_2.data)
-    T.preflattened_buffer(T_concat, [1, 128, 28, 28], "int8", data=T_concat.data)
+    placeholder_flat = T.buffer_decl([50176], "int8", data=placeholder.data)
+    placeholder_1_flat = T.buffer_decl([25088], "int8", data=placeholder_1.data)
+    placeholder_2_flat = T.buffer_decl([25088], "int8", data=placeholder_2.data)
+    T_concat_flat = T.buffer_decl([100352], "int8", data=T_concat.data)
     for i1, i2, i3 in T.grid(64, 28, 28):
-        T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]
+        T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3]
     for i1, i2, i3 in T.grid(32, 28, 28):
-        T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3]
+        T_concat_flat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1_flat[i1 * 784 + i2 * 28 + i3]
     for i1, i2, i3 in T.grid(32, 28, 28):
-        T_concat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2[i1 * 784 + i2 * 28 + i3]
+        T_concat_flat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2_flat[i1 * 784 + i2 * 28 + i3]
 
 
 @T.prim_func
 def concat_func_3(
-    placeholder: T.Buffer[(50176,), "int8"],
-    placeholder_1: T.Buffer[(25088,), "int8"],
-    placeholder_2: T.Buffer[(25088,), "int8"],
-    T_concat: T.Buffer[(100352,), "int8"],
+    placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+    placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+    placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+    T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
 ) -> None:
-    T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data)
-    T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data)
-    T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data)
-    T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data)
+    placeholder_flat = T.buffer_decl([50176], "int8", data=placeholder.data)
+    placeholder_1_flat = T.buffer_decl([25088], "int8", data=placeholder_1.data)
+    placeholder_2_flat = T.buffer_decl([25088], "int8", data=placeholder_2.data)
+    T_concat_flat = T.buffer_decl([100352], "int8", data=T_concat.data)
     for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
         for i2, i3 in T.grid(28, 28):
             if 96 <= i1:
-                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264]
+                T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_2_flat[
+                    i1 * 784 + i2 * 28 + i3 - 75264
+                ]
             if 64 <= i1 and i1 < 96:
-                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176]
+                T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_1_flat[
+                    i1 * 784 + i2 * 28 + i3 - 50176
+                ]
             if i1 < 64:
-                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]
+                T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3]
 
 
 def test_condition_mutually_exclusive():
@@ -628,9 +629,11 @@ def test_condition_mutually_exclusive():
 
 def test_loop_partition_unroll_hint():
     @T.prim_func
-    def main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None:
-        T.preflattened_buffer(A, [1, 3, 224, 224], "int8", data=A.data)
-        T.preflattened_buffer(B, [1, 224, 7, 16], "int8", data=B.data)
+    def main(
+        A_arg: T.Buffer[(1, 3, 224, 224), "int8"], B_arg: T.Buffer[(1, 224, 7, 16), "int8"]
+    ) -> None:
+        A = T.buffer_decl(150528, "int8", data=A_arg.data)
+        B = T.buffer_decl(25088, "int8", data=B_arg.data)
         for ax0 in T.serial(
             112,
             annotations={"pragma_loop_partition_hint": True},
@@ -640,9 +643,11 @@ def test_loop_partition_unroll_hint():
                     B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax0 * 2 + ax2 - 3]
 
     @T.prim_func
-    def partitioned_main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None:
-        T.preflattened_buffer(A, [1, 3, 224, 224], dtype="int8", data=A.data)
-        T.preflattened_buffer(B, [1, 224, 7, 16], dtype="int8", data=B.data)
+    def partitioned_main(
+        A_arg: T.Buffer[(1, 3, 224, 224), "int8"], B_arg: T.Buffer[(1, 224, 7, 16), "int8"]
+    ) -> None:
+        A = T.buffer_decl(150528, dtype="int8", data=A_arg.data)
+        B = T.buffer_decl(25088, dtype="int8", data=B_arg.data)
         # body
         for ax1, ax2, ax3 in T.grid(224, 7, 16):
             if 3 <= ax2 and ax3 < 3:
@@ -688,8 +693,6 @@ def test_loop_partition_keep_loop_annotations():
 
     @T.prim_func
     def after(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None:
-        T.preflattened_buffer(A, [160], dtype="int32", data=A.data)
-        T.preflattened_buffer(B, [160], dtype="int32", data=B.data)
         for i in T.serial(10, annotations={"key": "value"}):
             B[i] = A[i] + 1
         for i in T.serial(140, annotations={"key": "value"}):
@@ -737,10 +740,6 @@ def test_loop_partition_with_unit_loop_in_condition():
         placeholder_2: T.Buffer[25088, "int8"],
         T_concat: T.Buffer[100352, "int8"],
     ) -> None:
-        T.preflattened_buffer(placeholder, [50176], dtype="int8", data=placeholder.data)
-        T.preflattened_buffer(placeholder_1, [25088], dtype="int8", data=placeholder_1.data)
-        T.preflattened_buffer(placeholder_2, [25088], dtype="int8", data=placeholder_2.data)
-        T.preflattened_buffer(T_concat, [100352], dtype="int8", data=T_concat.data)
         for _ in T.serial(1, annotations={"preserve_unit_loop": True}):
             for i1, i2, i3 in T.grid(64, 28, 28):
                 T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]
diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
index bfa132d4ce..635badb847 100644
--- a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
+++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
@@ -25,12 +25,12 @@ from tvm.script import tir as T
 @tvm.script.ir_module
 class Before:
     @T.prim_func
-    def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
+    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
-        T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
-        T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
+        inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data)
+        weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data)
+        conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data)
         # var definition
         threadIdx_x = T.env_thread("threadIdx.x")
         blockIdx_x = T.env_thread("blockIdx.x")
@@ -44,24 +44,24 @@ class Before:
             conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0)
         for i6_0 in T.serial(16):
             for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
-                PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype= [...]
+                PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), d [...]
             for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
-                weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
+                weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
             for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
                 conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
         for ax1, ax2 in T.grid(2, 4):
-            conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
+            conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
 
 
 @tvm.script.ir_module
 class After:
     @T.prim_func
-    def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
+    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
-        T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
-        T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
+        inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data)
+        weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data)
+        conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data)
         # var definition
         threadIdx_x = T.env_thread("threadIdx.x")
         blockIdx_x = T.env_thread("blockIdx.x")
@@ -75,27 +75,27 @@ class After:
             conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0)
         for i6_0 in T.serial(16):
             for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
-                PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0),  [...]
+                PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32 [...]
             for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
-                weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
+                weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
             for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
                 conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
         for ax1, ax2 in T.grid(2, 4):
-            conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
+            conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
 
 
 @tvm.script.ir_module
 class After_simplified:
     @T.prim_func
-    def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
+    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         # var definition
         threadIdx_x = T.env_thread("threadIdx.x")
         blockIdx_x = T.env_thread("blockIdx.x")
-        T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
-        T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
-        T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
+        inputs_flat = T.buffer_decl([8192], dtype="float32", data=inputs.data)
+        weight_flat = T.buffer_decl([2097152], dtype="float32", data=weight.data)
+        conv2d_transpose_nhwc_flat = T.buffer_decl([16384], dtype="float32", data=conv2d_transpose_nhwc.data)
         # body
         T.launch_thread(blockIdx_x, 64)
         conv2d_transpose_nhwc_local = T.decl_buffer([8], "float32", scope="local")
@@ -106,13 +106,13 @@ class After_simplified:
             conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0)
         for i6_0 in T.serial(16):
             for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
-                PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32")
+                PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs_flat[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32")
             for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
-                weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
+                weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight_flat[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
             for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
                 conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
         for ax1, ax2 in T.grid(2, 4):
-            conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
+            conv2d_transpose_nhwc_flat[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
 
 # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg
 # fmt: on
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index c80cd55ea2..0c5d77d02b 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -98,10 +98,10 @@ def test_sync_else_branch():
 @tvm.testing.requires_cuda
 def test_sync_read_thread_id_independent_location():
     @T.prim_func
-    def func(p0: T.Buffer[2, "float32"], p1: T.Buffer[2, "float32"]) -> None:
+    def func(p0_arg: T.Buffer[(1, 2, 1, 1), "float32"], p1: T.Buffer[2, "float32"]) -> None:
         threadIdx_x = T.env_thread("threadIdx.x")
         blockIdx_x = T.env_thread("blockIdx.x")
-        T.preflattened_buffer(p0, [1, 2, 1, 1], dtype="float32", data=p0.data)
+        p0 = T.buffer_decl([2], dtype="float32", data=p0_arg.data)
         result_local = T.alloc_buffer([1], dtype="float32", scope="local")
         temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
         T.launch_thread(blockIdx_x, 8)
diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index 31cc6e07de..d1f86814e7 100644
--- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -75,11 +75,8 @@ class LinearStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
         placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(placeholder_4, [150528], dtype="uint8", elem_offset=0, align=64, offset_factor=1)
         placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(placeholder_5, [1], dtype="int16", elem_offset=0, align=64, offset_factor=1)
         T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(T_subtract_1, [452], dtype="int16", elem_offset=0, align=64, offset_factor=1)
         # body
         for ax0_ax1_fused_1 in T.serial(0, 224):
             for ax2_1, ax3_inner_1 in T.grid(224, 3):
@@ -90,13 +87,9 @@ class LinearStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
         placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(placeholder_65, [150528], dtype="int16", elem_offset=0, align=64, offset_factor=1)
         placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(placeholder_66, [9408], dtype="int16", elem_offset=0, align=64, offset_factor=1)
         placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(placeholder_67, [64], dtype="int32", elem_offset=0, align=64, offset_factor=1)
         T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=64, offset_factor=1)
         # body
         PaddedInput_7_data = T.allocate([157323], "int16", "global")
         PaddedInput_7 = T.buffer_decl(shape=[157323], dtype="int16", data=PaddedInput_7_data)
@@ -118,9 +111,7 @@ class LinearStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
         placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(placeholder_29, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1)
         T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1)
-        T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=64, offset_factor=1)
         # body
         tensor_2_data = T.allocate([200704], "uint8", "global")
         tensor_2 = T.buffer_decl(shape=[200704], dtype="uint8", data=tensor_2_data)
@@ -168,13 +159,9 @@ class LinearStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None:
         placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8")
-        T.preflattened_buffer(placeholder_29, [802816], dtype="uint8")
         T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16")
-        T.preflattened_buffer(T_cast_7, [177], dtype="int16")
         fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(fast_memory_6_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
         slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(slow_memory_7_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         tensor_2_let = T.buffer_decl([200704], dtype="uint8")
         with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")):
@@ -189,15 +176,10 @@ class LinearStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None:
         placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8")
-        T.preflattened_buffer(placeholder_4, [150528], dtype="uint8")
         placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16")
-        T.preflattened_buffer(placeholder_5, [1], dtype="int16")
         T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16")
-        T.preflattened_buffer(T_subtract_1, [452], dtype="int16")
         fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(fast_memory_2_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
         slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(slow_memory_3_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3):
             T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0]
@@ -205,17 +187,11 @@ class LinearStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None:
         placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16")
-        T.preflattened_buffer(placeholder_65, [150528], dtype="int16")
         placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16")
-        T.preflattened_buffer(placeholder_66, [9408], dtype="int16")
         placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32")
-        T.preflattened_buffer(placeholder_67, [64], dtype="int32")
         T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8")
-        T.preflattened_buffer(T_cast_21, [289], dtype="uint8")
         fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(fast_memory_4_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16)
         slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(slow_memory_5_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_7_let = T.buffer_decl([157323], "int16")
         with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")):
@@ -280,11 +256,8 @@ class ResnetStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True})
         placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
-        T.preflattened_buffer(placeholder_2, [360000], dtype="uint8")
         placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
-        T.preflattened_buffer(placeholder_3, [64], dtype="int32")
         T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16")
-        T.preflattened_buffer(T_cast_1, [215], dtype="int16")
         # body
         for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
             T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16")
@@ -294,13 +267,9 @@ class ResnetStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True})
         placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_13, [360000], dtype="int16")
         placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
-        T.preflattened_buffer(placeholder_14, [36864], dtype="int16")
         placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
-        T.preflattened_buffer(placeholder_15, [64], dtype="int32")
         T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16")
-        T.preflattened_buffer(T_cast_5, [215], dtype="int16")
         # body
         PaddedInput_1_data = T.allocate([379456], "int16", "global")
         PaddedInput_1 = T.buffer_decl(shape=[379456], dtype="int16", data=PaddedInput_1_data)
@@ -321,13 +290,9 @@ class ResnetStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True})
         placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_19, [360000], dtype="int16")
         placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
-        T.preflattened_buffer(placeholder_20, [16384], dtype="int16")
         placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
-        T.preflattened_buffer(placeholder_21, [256], dtype="int32")
         T_add_1 = T.match_buffer(T_add, [407], dtype="int32")
-        T.preflattened_buffer(T_add_1, [407], dtype="int32")
         # body
         PaddedInput_2_data = T.allocate([360000], "int16", "global")
         PaddedInput_2 = T.buffer_decl(shape=[360000], dtype="int16", data=PaddedInput_2_data)
@@ -349,15 +314,10 @@ class ResnetStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True})
         placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_29, [360000], dtype="int16")
         placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
-        T.preflattened_buffer(placeholder_27, [16384], dtype="int16")
         placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
-        T.preflattened_buffer(placeholder_26, [256], dtype="int32")
         placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32")
-        T.preflattened_buffer(placeholder_28, [1440000], dtype="int32")
         T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8")
-        T.preflattened_buffer(T_cast_7, [407], dtype="uint8")
         # body
         PaddedInput_3_data = T.allocate([360000], "int16", "global")
         PaddedInput_3 = T.buffer_decl(shape=[360000], dtype="int16", data=PaddedInput_3_data)
@@ -396,13 +356,9 @@ class ResnetStructure:
         # function attr dict
         T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
         placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_7, [360000], dtype="int16")
         placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
-        T.preflattened_buffer(placeholder_8, [4096], dtype="int16")
         placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
-        T.preflattened_buffer(placeholder_9, [64], dtype="int32")
         T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16")
-        T.preflattened_buffer(T_cast_3, [215], dtype="int16")
         # body
         PaddedInput_data = T.allocate([360000], "int16", "global")
         PaddedInput = T.buffer_decl([360000], "int16", data=PaddedInput_data)
@@ -426,13 +382,9 @@ class ResnetStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None:
         placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
-        T.preflattened_buffer(placeholder_2, [360000], dtype="uint8")
         placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
-        T.preflattened_buffer(placeholder_3, [64], dtype="int32")
         T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16")
-        T.preflattened_buffer(T_cast_1, [215], dtype="int16")
         global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(global_workspace_1_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
             T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16")
@@ -440,17 +392,11 @@ class ResnetStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None:
         placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_29, [360000], dtype="int16")
         placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
-        T.preflattened_buffer(placeholder_27, [16384], dtype="int16")
         placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
-        T.preflattened_buffer(placeholder_26, [256], dtype="int32")
         placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32")
-        T.preflattened_buffer(placeholder_28, [1440000], dtype="int32")
         T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8")
-        T.preflattened_buffer(T_cast_7, [407], dtype="uint8")
         global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(global_workspace_5_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_3_let = T.buffer_decl([360000], 'int16')
         with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")):
@@ -470,15 +416,10 @@ class ResnetStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None:
         placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_19, [360000], dtype="int16")
         placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
-        T.preflattened_buffer(placeholder_20, [16384], dtype="int16")
         placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
-        T.preflattened_buffer(placeholder_21, [256], dtype="int32")
         T_add_1 = T.match_buffer(T_add, [407], dtype="int32")
-        T.preflattened_buffer(T_add_1, [407], dtype="int32")
         global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(global_workspace_4_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_2_let = T.buffer_decl([360000], "int16")
         with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")):
@@ -498,15 +439,10 @@ class ResnetStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None:
         placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_7, [360000], dtype="int16")
         placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
-        T.preflattened_buffer(placeholder_8, [4096], dtype="int16")
         placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
-        T.preflattened_buffer(placeholder_9, [64], dtype="int32")
         T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16")
-        T.preflattened_buffer(T_cast_3, [215], dtype="int16")
         global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(global_workspace_2_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_let = T.buffer_decl([360000], "int16")
         with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")):
@@ -525,15 +461,10 @@ class ResnetStructurePlanned:
     @T.prim_func
     def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None:
         placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
-        T.preflattened_buffer(placeholder_13, [360000], dtype="int16")
         placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
-        T.preflattened_buffer(placeholder_14, [36864], dtype="int16")
         placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
-        T.preflattened_buffer(placeholder_15, [64], dtype="int32")
         T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16")
-        T.preflattened_buffer(T_cast_5, [215], dtype="int16")
         global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
-        T.preflattened_buffer(global_workspace_3_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         PaddedInput_1_let = T.buffer_decl([379456], "int16")
         with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")):
@@ -630,9 +561,6 @@ class TensorIntrinStructurePlanned:
         global_workspace_1_buffer_var = T.match_buffer(
             global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16
         )
-        T.preflattened_buffer(
-            global_workspace_1_buffer_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16
-        )
         dense_let = T.buffer_decl([10], "int32")
         with T.let(dense_let.data, T.address_of(global_workspace_1_buffer_var[0], dtype="handle")):
             T.evaluate(
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 32293cccdc..f542080f89 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -565,26 +565,6 @@ def test_non_integer_typed_block_iter():
     check_error(non_integer_typed_block_iter, 3)
 
 
-def test_preflattened_buffer_map_align():
-    def preflattened_buffer_map_align_nonint(foo: T.handle):
-        foo_1 = T.match_buffer(foo, [1])
-        T.preflattened_buffer(
-            foo_1, [1], align="bar"
-        )  # check_error: align: want int or IntImm, got 'bar'
-
-    check_error(preflattened_buffer_map_align_nonint, 3)
-
-
-def test_preflattened_buffer_map_offset_factor():
-    def preflattened_buffer_map_offset_factor_nonint(foo: T.handle):
-        foo_1 = T.match_buffer(foo, [1])
-        T.preflattened_buffer(
-            foo_1, [1], offset_factor="bar"
-        )  # check_error: offset_factor: want int or IntImm, got 'bar'
-
-    check_error(preflattened_buffer_map_offset_factor_nonint, 3)
-
-
 def test_illegal_buffer_slice():
     def strided_buffer_region(A: T.handle):
         # do not allow stride in buffer region
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 29e03f8bb6..7d542c7bc7 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -41,7 +41,6 @@ def test_ir_builder_tir_primfunc_base():
         body=tir.Evaluate(0),
         ret_type=None,
         buffer_map=None,
-        preflattened_buffer_map=None,
         attrs=None,
     )
 
@@ -60,7 +59,6 @@ def test_ir_builder_tir_primfunc_complete():
             T.func_attr({"key": "value"})
             T.func_ret(tvm.ir.PrimType("int64"))
             buffer_d = T.match_buffer(d, (64, 64), "int64")
-            T.preflattened_buffer(e, (32, 32), "int8", data=e.data)
             T.evaluate(0)
 
     # the prim_func generated by IRBuilder
@@ -83,9 +81,6 @@ def test_ir_builder_tir_primfunc_complete():
         body=tir.Evaluate(0),
         ret_type=tvm.ir.PrimType("int64"),
         buffer_map={c_handle: c_buffer, d_handle: d_buffer, e_handle: e_buffer},
-        preflattened_buffer_map={
-            e_handle: tir.decl_buffer((32, 32), "int8", name="e_preflatten", data=e_buffer.data)
-        },
         attrs=tvm.ir.make_node("DictAttrs", key="value"),
     )
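
With the extra map argument removed, a PrimFunc is built directly from params, body, ret_type and buffer_map, as in the expected object above. A short sketch under that assumption, using only public tvm.tir constructors and illustrative names:

    import tvm
    from tvm import tir

    x = tir.Var("x", "handle")
    x_buf = tir.decl_buffer((8,), "float32", name="x_buf")
    func = tir.PrimFunc(
        params=[x],
        body=tir.Evaluate(0),
        ret_type=None,
        buffer_map={x: x_buf},
    )
    assert func.buffer_map[x].name == "x_buf"
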
 
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index a39354b955..02b18e7e7c 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -186,23 +186,6 @@ def test_dynamic_shape_gemm():
     assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip)
 
 
-@T.prim_func
-def preflattened_buffer_map(A: T.handle, B: T.handle):
-    A_1 = T.match_buffer(A, [1])
-    T.preflattened_buffer(A_1, [1], align=1, offset_factor=2)
-    B_1 = T.match_buffer(B, [1])
-    T.preflattened_buffer(B_1, [1])
-    B_1[0] = A_1[0]
-
-
-def test_preflattened_buffer_map():
-    A_var = [
-        k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A"
-    ][0]
-    assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1
-    assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2
-
-
 @T.prim_func
 def match_buffer_int64(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32")