You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/12 16:23:51 UTC

[tvm] branch main updated: [microNPU] Add MergeConstants pass (#12029)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fbf80bb386 [microNPU] Add MergeConstants pass (#12029)
fbf80bb386 is described below

commit fbf80bb3869f5f5046e3a1a5bb335e21f2f9deae
Author: Nicola Lancellotti <ni...@arm.com>
AuthorDate: Tue Jul 12 17:23:44 2022 +0100

    [microNPU] Add MergeConstants pass (#12029)
    
    * [microNPU] Add MergeConstants pass
    
    Change-Id: I1ff51d8147fba8c66d442a370b9f058e9b2758d8
    
    * Fix errors and warnings
    
    Change-Id: I29f68f83a73fa00ca34ed0ab2321c53c6b761137
    
    * Address comments
    
    Change-Id: Iad59107d5abdec6b079c6fd4ab48c6bffbb5e0bb
    
    * Fix lint error
    
    Change-Id: Ie5caf506337de01e169d6f422e4682eefbd93241
---
 .../relay/backend/contrib/ethosu/tir/compiler.py   |   4 +
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py |  35 ++
 src/tir/contrib/ethosu/passes.cc                   | 643 ++++++++++++++++++++-
 .../test_ethosu/cascader/test_integration.py       |  10 +-
 .../contrib/test_ethosu/test_encode_constants.py   | 244 +++-----
 .../contrib/test_ethosu/test_merge_constants.py    | 561 ++++++++++++++++++
 tests/python/contrib/test_ethosu/test_networks.py  |  14 +-
 .../test_ethosu/test_remove_concatenates.py        |   3 -
 .../contrib/test_ethosu/test_replace_conv2d.py     |  24 -
 .../contrib/test_ethosu/test_replace_copy.py       |  37 +-
 tests/python/contrib/test_ethosu/test_scheduler.py |  24 +-
 11 files changed, 1336 insertions(+), 263 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index 0fd82378c3..85c6df4c7d 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -90,6 +90,10 @@ def lower_ethosu(sch, args, const_dict, name="main"):
         mod = tvm.tir.transform.RemoveNoOp()(mod)
         mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
         mod = ethosu_passes.HoistAllocates()(mod)
+        #  MergeConstant pass currently does not support striped schedules.
+        #  It requires further investigation.
+        if not util.is_striping_enabled():
+            mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
         mod = ethosu_passes.CopyComputeReordering()(mod)
 
         # When striping is enabled and if storage_rewrite is not run
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 76726132e0..c0b017e703 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod
         The new module with copy and compute nodes reordered.
     """
     return _ffi_api.CopyComputeReordering(max_copy_movements)
+
+
+def MergeConstants(const_dict):
+    """Merge the constants read by each compute operator into a single buffer.
+    Constants written to a buffer with local scope are not merged. Returns a callable
+    ``mod -> (new_mod, new_const_dict)``; both dicts have integer keys (presumably
+    parameter indices) and array values. ``const_dict`` itself is never modified.
+    """
+
+    def _merge_constants(mod):
+        # Read-only use of the closed-over const_dict keeps the returned pass reusable.
+        try:
+            mod["main"]
+        except Exception as err:
+            raise tvm.TVMError(
+                "Expected a single primitive function called 'main'. "
+                "Please run the MergeConstants pass in conjunction with the LowerToTIR() pass."
+            ) from err
+
+        new_const_dict = {}
+        for param in const_dict.keys():
+            new_const_dict[tvm.tir.IntImm("int64", param)] = tvm.nd.array(const_dict[param])
+        mod["main"] = mod["main"].with_attr("ethos-u.const_dict", new_const_dict)
+
+        mod = _ffi_api.MergeConstants()(mod)
+        merged_const_dict = mod["main"].attrs["ethos-u.const_dict"]
+        mod = _ffi_api.RemoveConstDictAttribute()(mod)
+
+        new_const_dict = {}
+        for param in merged_const_dict.keys():
+            new_const_dict[int(param)] = merged_const_dict[param].numpy()
+
+        return mod, new_const_dict
+
+    return _merge_constants
diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc
index 609d986dbb..b662e9dfd0 100644
--- a/src/tir/contrib/ethosu/passes.cc
+++ b/src/tir/contrib/ethosu/passes.cc
@@ -24,10 +24,13 @@
  */
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
 #include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
 
 namespace tvm {
 
@@ -42,6 +45,62 @@ namespace tir {
 namespace contrib {
 namespace ethosu {
 
+namespace {
+
+/*! Returns the Call args of a statement: Evaluate(Call), optionally wrapped in one AttrStmt */
+Array<PrimExpr> GetStmtArgs(const Stmt& stmt) {
+  auto attr{stmt.as<AttrStmtNode>()};
+  Stmt eval_stmt{attr ? attr->body : stmt};
+  auto eval{eval_stmt.as<EvaluateNode>()};
+  ICHECK(eval) << "Expected statement to be an evaluate node, but was " << eval_stmt->GetTypeKey();
+  auto call{eval->value.as<CallNode>()};
+  ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey();
+  return call->args;
+}
+/*! Statement kinds: copy to a "global"-scope buffer, any other copy, anything else */
+enum class StmtType { global_copy, local_copy, compute };
+
+/*! Classifies a statement from its op name (args[0]) and, for copies, the write scope */
+StmtType GetStmtType(const Stmt& stmt) {
+  Array<PrimExpr> args{GetStmtArgs(stmt)};
+  if (args[0].as<StringImmNode>()->value == "ethosu_copy") {  // assumes args[0] is a StringImm
+    if (args[3].as<BufferLoadNode>()->buffer.scope() == "global") {
+      return StmtType::global_copy;
+    } else {
+      return StmtType::local_copy;
+    }
+  }
+  return StmtType::compute;
+}
+/*! Returns the buffer read by the given copy statement (args[1]) */
+Buffer GetCopyReadBuffer(const Stmt& stmt) {
+  Array<PrimExpr> args{GetStmtArgs(stmt)};
+  return args[1].as<BufferLoadNode>()->buffer;
+}
+
+/*! Returns the buffer written by the given copy statement (args[3]) */
+Buffer GetCopyWriteBuffer(const Stmt& stmt) {
+  Array<PrimExpr> args{GetStmtArgs(stmt)};
+  return args[3].as<BufferLoadNode>()->buffer;
+}
+
+/*! Returns the length of the given copy statement (args[2], an IntImm) */
+int64_t GetCopyLength(const Stmt& stmt) {
+  Array<PrimExpr> args{GetStmtArgs(stmt)};
+  return args[2].as<IntImmNode>()->value;
+}
+
+/*! Returns the "pragma_compute_cycles_hint" value of a statement, or 0 if it has none */
+int64_t GetStmtCycles(const Stmt& stmt) {
+  auto attr{stmt.as<AttrStmtNode>()};
+  if (attr && attr->attr_key == "pragma_compute_cycles_hint") {
+    int64_t cycles{Downcast<Integer>(attr->value)->value};
+    return cycles;
+  }
+  return 0;
+}
+}  // namespace
+
 /*!
  * \brief This mutator moves allocates to the top of the body of the main
  * function.
@@ -154,9 +213,9 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
     // Each copy statement to a buffer with global scope is moved up
     // at most `_max_copy_movements` times.
     for (size_t index = 0; index < new_seq.size(); ++index) {
-      if (stmt_is_global_copy(new_seq[index])) {
+      if (GetStmtType(new_seq[index]) == StmtType::global_copy) {
         int lower = std::max(0, static_cast<int>(index) - _max_copy_movements);
-        for (int i = index; i > lower && !stmt_is_copy(new_seq[i - 1]); --i) {
+        for (int i = index; i > lower && (GetStmtType(new_seq[i - 1]) == StmtType::compute); --i) {
           std::swap(new_seq[i - 1], new_seq[i]);
         }
       }
@@ -167,32 +226,6 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
     return Stmt{seq_stmt_node};
   }
 
-  tvm::runtime::Array<tvm::PrimExpr> get_stmt_args(const Stmt& stmt) {
-    Stmt eval_stmt = stmt;
-    if (const auto* attr_stmt = eval_stmt.as<AttrStmtNode>()) {
-      eval_stmt = attr_stmt->body;
-    }
-
-    auto eval_node{eval_stmt.as<EvaluateNode>()};
-    ICHECK(eval_node) << "Expected statement to be an evaluate node, but was "
-                      << eval_stmt->GetTypeKey();
-    auto call_node{eval_node->value.as<CallNode>()};
-    ICHECK(call_node) << "Expected expression to be a call node, but was "
-                      << eval_node->value->GetTypeKey();
-    return call_node->args;
-  }
-
-  bool stmt_is_copy(const Stmt& stmt) {
-    auto args{get_stmt_args(stmt)};
-    return args[0].as<StringImmNode>()->value == "ethosu_copy";
-  }
-
-  bool stmt_is_global_copy(const Stmt& stmt) {
-    auto args{get_stmt_args(stmt)};
-    return args[0].as<StringImmNode>()->value == "ethosu_copy" &&
-           args[3].as<BufferLoadNode>()->buffer.scope() == "global";
-  }
-
   /*! The maximum number of movements allowed for a copy. */
   int _max_copy_movements;
 };
@@ -223,6 +256,560 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This mutator removes all allocates, keeping (and visiting) each allocate's body.
+ */
+class RemoveAllocatesMutator : public StmtExprMutator {
+ public:
+  PrimFunc operator()(PrimFunc main_func) {  // returns main_func with its body rewritten
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = this->VisitStmt(main_func->body);
+    return GetRef<PrimFunc>(prim_func_node);
+  }
+
+ private:
+  Stmt VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); }
+};
+
+/*!
+ * \brief Collects the copy/compute information consumed by the MergeConstantsMutator
+ */
+class MergeConstantsInfoExtractor : public StmtExprVisitor {
+ public:
+  class Info {
+   public:
+    /*! Allocates in the order they are visited (outermost first). */
+    std::vector<Allocate> allocates{};
+
+    /*! A list that contains in the i-th position the write buffer of the i-th statement
+     * if that statement is a copy to a buffer with global scope, and nothing otherwise */
+    std::vector<Optional<Buffer>> copy_write_buffers{};
+
+    /*! Maps a copy's write buffer to (new buffer index, offset in that buffer);
+     * both stay -1 until a compute statement reading the buffer is found */
+    std::unordered_map<const BufferNode*, std::pair<int /* new buffer index */, int /* offset */>>
+        old_to_new_write_buffer{};
+
+    /*! Maps a new buffer index (the index of its consuming compute stmt) to its length */
+    std::unordered_map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+    /*! Maps a new buffer index to the summed cycle hints of the copies merged into it */
+    std::unordered_map<int /* new buffer index */, int64_t> cycless{};
+  };
+
+  Info operator()(PrimFunc main_func) {
+    this->VisitStmt(main_func->body);
+    return std::move(_info);
+  }
+
+ private:
+  /*! The information collected by this extractor */
+  Info _info{};
+
+  void VisitStmt_(const AllocateNode* op) override {
+    _info.allocates.push_back(GetRef<Allocate>(op));
+    VisitStmt(op->body);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      StmtExprVisitor::VisitStmt_(op);
+      return;
+    }
+
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};  // NOTE(review): assumes a single multi-stmt SeqStmt
+      switch (GetStmtType(stmt)) {
+        case StmtType::global_copy: {
+          Buffer write_buffer{GetCopyWriteBuffer(stmt)};
+          _info.copy_write_buffers.push_back(write_buffer);
+          _info.old_to_new_write_buffer[write_buffer.as<BufferNode>()] = std::make_pair(-1, -1);
+          break;
+        }
+        case StmtType::local_copy: {
+          _info.copy_write_buffers.push_back(Optional<Buffer>{});
+          break;
+        }
+        case StmtType::compute: {
+          _info.copy_write_buffers.push_back(Optional<Buffer>{});
+          std::vector<Buffer> buffers{GetCopiedBuffersUsedByStmt(stmt)};
+          if (buffers.empty()) {
+            continue;
+          }
+          _info.new_buffers_length[i] = 0;
+          for (Buffer buffer : buffers) {
+            for (size_t j{i}; j-- > 0;) {  // visits i-1 down to 0 without size_t wrap-around
+              if (_info.copy_write_buffers[j] == buffer) {
+                _info.old_to_new_write_buffer[buffer.as<BufferNode>()] =
+                    std::make_pair(i, _info.new_buffers_length[i]);
+                _info.new_buffers_length[i] += GetCopyLength(seq_stmt[j]);
+                _info.cycless[i] += GetStmtCycles(seq_stmt[j]);
+                break;
+              }
+            }
+          }
+          break;
+        }
+      }
+    }
+  }
+
+  /*! Returns the distinct buffers in the stmt's args that an earlier global copy writes */
+  std::vector<Buffer> GetCopiedBuffersUsedByStmt(const Stmt& stmt) {
+    std::vector<Buffer> buffers{};
+    for (PrimExpr arg : GetStmtArgs(stmt)) {
+      if (auto buffer_load = arg.as<BufferLoadNode>()) {
+        Buffer buffer{buffer_load->buffer};
+        // Check if the buffer has already been added
+        if (std::find(buffers.begin(), buffers.end(), buffer) == buffers.end()) {
+          // Check if the buffer is copied
+          if (_info.old_to_new_write_buffer.count(buffer.as<BufferNode>())) {
+            buffers.push_back(buffer);
+          }
+        }
+      }
+    }
+    return buffers;
+  }
+};
+
+/*!
+ * \brief This mutator looks for the constants used by each compute operator
+ * and merges them into a single buffer, rebuilding the params, buffer map and const dict.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  explicit MergeConstantsMutator(MergeConstantsInfoExtractor::Info info) : _info{std::move(info)} {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Rewrite the body first: it fills the bookkeeping maps the helpers below read
+    Stmt new_body = RewritePrimFuncBody(main_func->body);
+    std::unordered_set<const VarNode*> params_to_delete{};
+    Map<Var, Buffer> new_buffer_map{MakeNewBufferMap(main_func->buffer_map, &params_to_delete)};
+    Array<Var> new_params{MakeNewParams(main_func->params, params_to_delete)};
+
+    // Make the new const dict: one concatenated entry per param group found in const_dict
+    Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
+    Array<Array<IntImm>> buffers_to_merge{
+        GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
+    Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const_dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! The information collected by the MergeConstantsInfoExtractor */
+  MergeConstantsInfoExtractor::Info _info;
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::unordered_map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::unordered_map<const BufferNode*, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps a new buffer index to the copies' read buffers merged into it, in statement order
+   */
+  std::unordered_map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! Read buffers of copies at offset > 0; their params are removed */
+  std::unordered_set<const BufferNode*> buffers_to_delete{};
+
+  Stmt RewritePrimFuncBody(Stmt body) {  // re-wraps allocates, then visits the body
+    std::unordered_map<const VarNode*, Allocate> var_to_allocate{};
+
+    // Wrap body in the recorded allocates (innermost first), except global-copy write buffers
+    std::unordered_set<const VarNode*> buffer_vars{GetVarsForWrittenCopyBuffers()};
+    for (auto it{_info.allocates.rbegin()}; it != _info.allocates.rend(); ++it) {
+      Allocate alloc{*it};
+      var_to_allocate[alloc->buffer_var.get()] = alloc;
+      if (buffer_vars.count(alloc->buffer_var.as<VarNode>()) == 0) {
+        body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body,
+                        alloc->annotations, alloc->span);
+      }
+    }
+
+    // Add one allocate per new buffer, sized to its merged length
+    for (auto it{_info.copy_write_buffers.rbegin()}; it != _info.copy_write_buffers.rend(); ++it) {
+      if (Optional<Buffer> buffer_opt = *it) {
+        Buffer old_write_buffer{buffer_opt.value()};
+        int new_buffer_index{
+            _info.old_to_new_write_buffer[old_write_buffer.as<BufferNode>()].first};
+
+        // Check if the allocate has already been created
+        if (new_buffers.count(new_buffer_index) == 0) {
+          BufferNode* new_buffer{old_write_buffer.CopyOnWrite()};
+          new_buffer->shape = {_info.new_buffers_length[new_buffer_index]};
+
+          new_buffers[new_buffer_index] = GetRef<Buffer>(new_buffer);
+
+          Allocate old_allocate{var_to_allocate[old_write_buffer->data.get()]};
+          body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(),
+                          body, old_allocate->annotations, old_allocate->span);
+        }
+      }
+    }
+
+    // Rewrite the copy and compute statements
+    return this->VisitStmt(body);
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    auto allocate{CopyOnWrite(op)};
+    allocate->body = this->VisitStmt(op->body);
+    return Stmt(allocate);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    Array<Stmt> new_seq{};
+    SeqStmt seq_stmt{GetRef<SeqStmt>(op)};
+    for (size_t i{0}; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (GetStmtType(stmt)) {
+        case StmtType::global_copy: {
+          Buffer old_write_buffer{_info.copy_write_buffers[i].value()};
+          std::pair<int, int> pair{
+              _info.old_to_new_write_buffer[old_write_buffer.as<BufferNode>()]};
+          int new_buffer_index{pair.first};  // NOTE(review): -1 if no compute reads it
+          int new_buffer_offset{pair.second};
+          UpdateBuffersToMergeAndDelete(stmt, new_buffer_index, new_buffer_offset);
+
+          if (!IsCopyToBeDeleted(new_buffer_offset)) {  // copies at offset > 0 are dropped
+            Optional<PrimExpr> cycless{GetMergedCycles(new_buffer_index)};
+            new_seq.push_back(MakeNewStmt(
+                stmt, MakeNewCopyArgs(stmt, old_write_buffer, new_buffer_index), cycless));
+          }
+          break;
+        }
+        case StmtType::local_copy: {
+          new_seq.push_back(stmt);
+          break;
+        }
+        case StmtType::compute: {
+          new_seq.push_back(MakeNewStmt(stmt, MakeNewComputeArgs(stmt)));
+          break;
+        }
+      }
+    }
+    return SeqStmt(new_seq, op->span);
+  }
+
+  /*! Returns the variables of the buffers written by copies */
+  std::unordered_set<const VarNode*> GetVarsForWrittenCopyBuffers() {
+    std::unordered_set<const VarNode*> buffer_vars{};
+    std::transform(_info.old_to_new_write_buffer.begin(), _info.old_to_new_write_buffer.end(),
+                   std::inserter(buffer_vars, buffer_vars.begin()),
+                   [](std::pair<const BufferNode*, std::pair<int, int>> pair) -> const VarNode* {
+                     return pair.first->data.as<VarNode>();
+                   });
+    return buffer_vars;
+  }
+
+  /*! Returns the cycles of the new buffer at the given index */
+  Optional<PrimExpr> GetMergedCycles(int new_buffer_index) {
+    auto it = _info.cycless.find(new_buffer_index);
+    if (it != _info.cycless.end()) {
+      return Integer(it->second);
+    }
+    return Optional<PrimExpr>{};
+  }
+
+  /*! Returns true if a copy must be deleted, false otherwise */
+  bool IsCopyToBeDeleted(int new_buffer_offset) { return new_buffer_offset > 0; }
+
+  Array<PrimExpr> MakeNewCopyArgs(const Stmt& stmt, const Buffer& old_write_buffer,
+                                  int new_buffer_index) {  // copy of the merged length
+    Array<PrimExpr> args{GetStmtArgs(stmt)};
+    int new_length{_info.new_buffers_length[new_buffer_index]};
+
+    Array<PrimExpr> new_args{};
+    for (size_t i = 0; i < args.size(); ++i) {
+      switch (i) {
+        case 1: /* read_address */ {
+          auto buffer_load = args[1].as<BufferLoadNode>();
+          Buffer buffer{buffer_load->buffer};
+          Buffer new_buffer{buffer->data,
+                            buffer->dtype,
+                            {new_length},
+                            buffer->strides,
+                            buffer->elem_offset,
+                            buffer->name,
+                            buffer->data_alignment,
+                            buffer->offset_factor,
+                            buffer->buffer_type,
+                            buffer->axis_separators,
+                            buffer->span};
+          old_to_new_read_buffers[buffer.as<BufferNode>()] = new_buffer;
+          new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span));
+          break;
+        }
+        case 2: /* length */ {
+          new_args.push_back(new_length);
+          break;
+        }
+        case 3: /* write_address */ {
+          new_args.push_back(MakeNewBufferLoad(old_write_buffer, 0, true).value());
+          break;
+        }
+        default:
+          new_args.push_back(args[i]);
+          break;
+      }
+    }
+    return new_args;
+  }
+
+  Array<PrimExpr> MakeNewComputeArgs(const Stmt& stmt) {  // redirect merged-buffer loads
+    Array<PrimExpr> args{GetStmtArgs(stmt)};
+    Array<PrimExpr> new_args{};
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (auto buffer_load = args[i].as<BufferLoadNode>()) {
+        BufferLoad new_buffer_load{
+            MakeNewBufferLoad(buffer_load->buffer, buffer_load->indices[0], false)
+                .value_or(GetRef<BufferLoad>(buffer_load))};
+        new_args.push_back(new_buffer_load);
+      } else {
+        new_args.push_back(args[i]);
+      }
+    }
+    return new_args;
+  }
+
+  Stmt MakeNewStmt(const Stmt& stmt, const Array<PrimExpr>& new_args,  // same op, new args
+                   Optional<PrimExpr> cycless = Optional<PrimExpr>{}) {
+    auto attr{stmt.as<AttrStmtNode>()};
+    Stmt eval_stmt{attr ? attr->body : stmt};
+    auto eval{eval_stmt.as<EvaluateNode>()};
+    ICHECK(eval) << "Expected statement to be an evaluate node, but was "
+                 << eval_stmt->GetTypeKey();
+    auto call{eval->value.as<CallNode>()};
+    ICHECK(call) << "Expected expression to be a call node, but was " << eval->value->GetTypeKey();
+
+    Call new_call{call->dtype, call->op, new_args, call->span};
+    Evaluate new_eval{new_call, eval->span};
+
+    if (attr) {
+      ICHECK(attr->attr_key == "pragma_compute_cycles_hint");
+      PrimExpr value = cycless.value_or(attr->value);
+      return AttrStmt{attr->node, attr->attr_key, value, new_eval, attr->span};
+    } else {
+      return std::move(new_eval);
+    }
+  }
+
+  Optional<BufferLoad> MakeNewBufferLoad(const Buffer& write_buffer, const PrimExpr& old_index,
+                                         bool only_old_index) {  // empty if not copied
+    auto it = _info.old_to_new_write_buffer.find(write_buffer.as<BufferNode>());
+    if (it != _info.old_to_new_write_buffer.end()) {
+      std::pair<int, int> pair{it->second};
+      int new_buffer_index{pair.first};
+      PrimExpr new_index{only_old_index ? old_index : (pair.second + old_index)};
+      return BufferLoad{new_buffers[new_buffer_index], {new_index}};
+    }
+    return Optional<BufferLoad>{};
+  }
+
+  Map<tir::Var, Buffer> MakeNewBufferMap(const Map<tir::Var, Buffer>& buffer_map,
+                                         std::unordered_set<const VarNode*>* params_to_delete) {
+    Map<tir::Var, Buffer> new_buffer_map{};  // drops deleted, widens merged read buffers
+    for (std::pair<Var, Buffer> pair : buffer_map) {
+      Var var{pair.first};
+      Buffer buffer{pair.second};
+
+      if (buffers_to_delete.count(buffer.as<BufferNode>()) == 1) {
+        params_to_delete->insert(var.as<VarNode>());
+      } else if (old_to_new_read_buffers.count(buffer.as<BufferNode>()) == 1) {
+        new_buffer_map.Set(var, old_to_new_read_buffers[buffer.as<BufferNode>()]);
+      } else {
+        new_buffer_map.Set(var, buffer);
+      }
+    }
+    return new_buffer_map;
+  }
+
+  Array<tir::Var> MakeNewParams(const Array<tir::Var>& params,
+                                const std::unordered_set<const VarNode*>& params_to_delete) {
+    std::vector<Var> new_params{};  // params minus params_to_delete, order kept
+    for (Var var : params) {
+      if (params_to_delete.count(var.as<VarNode>()) == 0) {
+        new_params.push_back(var);
+      }
+    }
+    return new_params;
+  }
+
+  void UpdateBuffersToMergeAndDelete(const Stmt& stmt, int new_buffer_index,
+                                     int new_buffer_offset) {
+    Array<PrimExpr> args{GetStmtArgs(stmt)};  // NOTE(review): unused
+    Buffer read_buffer{GetCopyReadBuffer(stmt)};
+
+    if (buffers_to_merge.count(new_buffer_index) == 0) {
+      buffers_to_merge[new_buffer_index] = std::vector<Buffer>{read_buffer};
+    } else {
+      buffers_to_merge[new_buffer_index].push_back(read_buffer);
+    }
+
+    if (new_buffer_offset > 0) {  // its param is removed by MakeNewBufferMap
+      buffers_to_delete.insert(read_buffer.as<BufferNode>());
+    }
+  }
+
+  /*! Returns an array whose elements are the indices of the function arguments to be merged.
+   * Example: if a function has four arguments and the second and the third ones must
+   * be merged then the array is: [[0], [1, 2], [3]] */
+  Array<Array<IntImm>> GetArgsToMerge(const Map<Var, Buffer>& buffer_map,
+                                      const Array<Var>& params) {
+    std::unordered_map<const BufferNode*, Var> buffer_to_var{};
+    for (std::pair<Var, Buffer> var_buffer : buffer_map) {
+      buffer_to_var[var_buffer.second.as<BufferNode>()] = var_buffer.first;
+    }
+
+    std::unordered_map<const VarNode*, int> var_to_index{};
+    for (int i = 0; i < static_cast<int>(params.size()); ++i) {
+      var_to_index[params[i].as<VarNode>()] = i;
+    }
+
+    std::vector<Array<IntImm>> vector{};
+    for (std::pair<int, std::vector<Buffer>> index_vector : buffers_to_merge) {
+      std::vector<IntImm> indices{};
+      for (Buffer buffer : index_vector.second) {
+        const VarNode* var{buffer_to_var[buffer.as<BufferNode>()].as<VarNode>()};
+        IntImm index{DataType::Int(64), var_to_index[var]};
+        var_to_index.erase(var);  // what remains are the unmerged params
+        auto it = std::find_if(indices.begin(), indices.end(),
+                               [&](IntImm value) { return value->value == index->value; });
+        if (it == indices.end()) {
+          indices.push_back(index);
+        }
+      }
+      vector.push_back(Array<IntImm>{indices});
+    }
+
+    for (std::pair<const VarNode*, int> var_index : var_to_index) {
+      vector.push_back(Array<IntImm>{IntImm(DataType::Int(64), var_index.second)});
+    }
+    std::sort(vector.begin(), vector.end(),  // by first param index
+              [](Array<IntImm> a, Array<IntImm> b) { return a[0]->value < b[0]->value; });
+    return vector;
+  }
+
+  Array<Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
+      const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
+    Array<Array<IntImm>> new_args_to_merge{};  // groups whose first index is a key
+    for (Array<IntImm> args : args_to_merge) {
+      IntImm key{args[0]};
+      auto it = std::find_if(const_dict.begin(), const_dict.end(),
+                             [&](std::pair<tvm::IntImm, runtime::NDArray> pair) {
+                               return pair.first->value == key->value;
+                             });
+      if (it != const_dict.end()) {
+        new_args_to_merge.push_back(args);
+      }
+    }
+    return new_args_to_merge;
+  }
+
+  Map<IntImm, runtime::NDArray> MakeNewConstDict(const Array<Array<IntImm>>& args_to_merge,
+                                                 Map<IntImm, runtime::NDArray> const_dict) {
+    Map<IntImm, runtime::NDArray> new_const_dict{};  // one uint8 blob per group
+    if (args_to_merge.size() == 0) {
+      return new_const_dict;
+    }
+
+    int64_t key = args_to_merge[0][0]->value;  // keys are consecutive from here
+    for (Array<IntImm> args : args_to_merge) {
+      int64_t size = 0;
+      for (IntImm arg : args) {
+        auto it = std::find_if(const_dict.begin(), const_dict.end(),
+                               [&](auto pair) { return pair.first->value == arg->value; });
+        runtime::NDArray arg_constant{(*it).second};  // NOTE(review): it is unchecked
+        size += runtime::GetDataSize(*arg_constant.operator->());
+      }
+
+      runtime::NDArray constant = runtime::NDArray::Empty({size}, DataType::UInt(8), {kDLCPU, 0});
+
+      size_t offset = 0;
+      for (IntImm arg : args) {
+        auto it = std::find_if(const_dict.begin(), const_dict.end(),
+                               [&](auto pair) { return pair.first->value == arg->value; });
+        runtime::NDArray arg_constant{(*it).second};
+        size_t nbytes = runtime::GetDataSize(*arg_constant.operator->());
+        arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes);
+        offset += nbytes;
+      }
+      new_const_dict.Set(IntImm(DataType::Int(64), key), constant);
+      key += 1;
+    }
+    return new_const_dict;
+  }
+};
+
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ * \return tvm::transform::Pass reading and rewriting main's "ethos-u.const_dict" attribute
+ */
+tvm::transform::Pass MergeConstants() {
+  auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
+    ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
+        << "Expected a single primitive function called 'main'. Please run the "
+           "MergeConstants pass in conjunction with the LowerToTIR() pass.";
+    Optional<Map<IntImm, runtime::NDArray>> const_dict{
+        f->attrs.GetAttr("ethos-u.const_dict", Optional<Map<IntImm, runtime::NDArray>>{})};
+    ICHECK(const_dict) << "Expected a ethos-u.const_dict attribute";
+    // Collect info before RemoveAllocatesMutator drops the allocates the extractor records.
+    MergeConstantsInfoExtractor::Info info{MergeConstantsInfoExtractor()(f)};
+    f = RemoveAllocatesMutator()(f);
+    return MergeConstantsMutator(info)(f, const_dict.value());
+  };
+  return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.contrib.ethos-u.MergeConstants",
+                                                 {});
+}
+
+TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.MergeConstants").set_body_typed(MergeConstants);
+
+/*!
+ * \brief Returns the given PrimFunc without the "ethos-u.const_dict" attribute (WithoutAttr).
+ * The body is never visited, so the StmtExprMutator base is not actually used.
+ */
+class RemoveConstDictAttributeMutator : public StmtExprMutator {
+ public:
+  RemoveConstDictAttributeMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func) {
+    return WithoutAttr(std::move(main_func), "ethos-u.const_dict");
+  }
+};
+
+tvm::transform::Pass RemoveConstDictAttribute() {  // PrimFunc pass dropping "ethos-u.const_dict"
+  auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
+    return RemoveConstDictAttributeMutator()(f);
+  };
+  return tvm::tir::transform::CreatePrimFuncPass(
+      pass_func, 0, "tir.contrib.ethos-u.RemoveConstDictAttribute", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.RemoveConstDictAttribute")  // called from Python
+    .set_body_typed(RemoveConstDictAttribute);
+
 }  // namespace ethosu
 }  // namespace contrib
 }  // namespace tir
diff --git a/tests/python/contrib/test_ethosu/cascader/test_integration.py b/tests/python/contrib/test_ethosu/cascader/test_integration.py
index 8e1f020861..14cc8fbc61 100644
--- a/tests/python/contrib/test_ethosu/cascader/test_integration.py
+++ b/tests/python/contrib/test_ethosu/cascader/test_integration.py
@@ -109,9 +109,8 @@ def test_single_conv_compute_cycles_hint():
     for single convolution.
     """
     primfunc = _compile_model(_create_single_conv2d())
-    ops = primfunc.body.body.body.seq
-
-    compute_cycles_hints = [2304, 640, 320]
+    ops = primfunc.body.body.seq
+    compute_cycles_hints = [2944, 320]
     for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
         assert op.attr_key == "pragma_compute_cycles_hint"
         assert op.value == compute_cycle_hint
@@ -123,9 +122,8 @@ def test_double_conv_compute_cycles_hint():
     for double convolution.
     """
     primfunc = _compile_model(_create_double_conv2d())
-    ops = primfunc.body.body.body.body.body.body.seq
-
-    compute_cycles_hints = [2304, 640, 768, 640, 320, 240]
+    ops = primfunc.body.body.body.body.seq
+    compute_cycles_hints = [2944, 1408, 320, 240]
     for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
         assert op.attr_key == "pragma_compute_cycles_hint"
         assert op.value == compute_cycle_hint
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 15b719f33c..fd9f373739 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -37,34 +37,23 @@ class WeightStreamOnlyU55:
     def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer1 = T.buffer_decl([128], "uint8")
-        buffer2 = T.buffer_decl([32], "uint8")
-        buffer3 = T.buffer_decl([112], "uint8")
-        buffer4 = T.buffer_decl([32], "uint8")
-        buffer5 = T.buffer_decl([112], "uint8")
-        buffer6 = T.buffer_decl([32], "uint8")
-        buffer7 = T.buffer_decl([112], "uint8")
+        buffer1 = T.buffer_decl([160], "uint8")
+        buffer3 = T.buffer_decl([144], "uint8")
+        buffer5 = T.buffer_decl([144], "uint8")
+        buffer7 = T.buffer_decl([144], "uint8")
         buffer8 = T.buffer_decl([32], "uint8")
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data)
         # body
-        p1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p3 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p4 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        buffer9 = T.buffer_decl([112], "uint8", data=p1.data)
-        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 32, p2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p3[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 32, p4[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, buffer9[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 32, p2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, T.int8(-1), T.int8(-1), 12, p4[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p3[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 32, p4[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 112, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, T.int8(-1), T.int8(-1), 12, p4[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        p1 = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p2 = T.allocate([144], "uint8", "global", annotations={"disable_lower_builtin":True})
+        buffer9 = T.buffer_decl([144], "uint8", data=p1.data)
+        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 160, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 144, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, T.int8(-1), T.int8(-1), 12, p1[128], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 144, buffer9[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 144, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 112, T.int8(-1), T.int8(-1), 12, buffer9[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, T.int8(-1), T.int8(-1), 12, p2[112], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 
 
@@ -75,34 +64,22 @@ class WeightStreamOnlyU65:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         # buffer definition
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
-        buffer_encoded_1 = T.buffer_decl([160], dtype="uint8")
-        buffer_encoded_1_1 = T.buffer_decl([32], dtype="uint8")
-        buffer_encoded_2_1 = T.buffer_decl([160], dtype="uint8")
-        buffer_encoded_3_1 = T.buffer_decl([32], dtype="uint8")
-        buffer_encoded_4_1 = T.buffer_decl([176], dtype="uint8")
-        buffer_encoded_5_1 = T.buffer_decl([32], dtype="uint8")
-        buffer_encoded_6_1 = T.buffer_decl([160], dtype="uint8")
-        buffer_encoded_7_1 = T.buffer_decl([32], dtype="uint8")
+        buffer_encoded_1 = T.buffer_decl([192], dtype="uint8")
+        buffer_encoded_2_1 = T.buffer_decl([192], dtype="uint8")
+        buffer_encoded_4_1 = T.buffer_decl([208], dtype="uint8")
+        buffer_encoded_6_1 = T.buffer_decl([192], dtype="uint8")
         # body
-        placeholder_global = T.allocate([176], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_global_2 = T.allocate([160], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_d_global_2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_global_1 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data)
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 160, placeholder_global_1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1_1[0], 32, placeholder_d_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 160, placeholder_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3_1[0], 32, placeholder_d_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 80, placeholder_global_1[80], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 176, placeholder_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5_1[0], 32, placeholder_d_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 160, placeholder_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7_1[0], 32, placeholder_d_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 96, placeholder_global[96], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        p1 = T.allocate([208], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p2 = T.allocate([192], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p3 = T.buffer_decl([192], dtype="uint8", data=p1.data)
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 192, p3[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 192, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 80, p3[80], 80, 12, p3[160], 16, p3[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 208, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, p2[80], 80, 12, p2[160], 16, p2[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 192, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 96, p1[96], 80, 12, p1[176], 16, p1[192], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, p2[80], 80, 12, p2[160], 16, p2[176], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 # fmt: on
 
@@ -113,12 +90,12 @@ class WeightStreamOnlyU65:
         (
             "ethos-u55-128",
             WeightStreamOnlyU55,
-            [128, 32, 112, 32, 112, 32, 112, 32],
+            [160, 144, 144, 144],
         ),
         (
             "ethos-u65-512",
             WeightStreamOnlyU65,
-            [160, 32, 160, 32, 176, 32, 160, 32],
+            [192, 192, 208, 192],
         ),
     ],
 )
@@ -160,7 +137,7 @@ def test_weight_stream_only(accelerator, reference_mod, reference_const_sizes):
         tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
 
         test_const_size = [value.size for value in list(consts.values())]
-        assert reference_const_sizes == test_const_size
+        assert sorted(reference_const_sizes) == sorted(test_const_size)  # order-insensitive comparison
 
 
 # fmt: off
@@ -170,21 +147,14 @@ class RereadWeightsU55:
     def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer1 = T.buffer_decl([304], "uint8")
-        buffer2 = T.buffer_decl([80], "uint8")
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data)
+        buffer1 = T.buffer_decl([384], "uint8")
         # body
-        p1 = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p2 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p3 = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p4 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 304, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 304, p3[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p4[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p2[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 304, T.int8(-1), T.int8(-1), 12, p4[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        p1 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p2 = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin":True})
+        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 384, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 304, T.int8(-1), T.int8(-1), 12, p1[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 304, T.int8(-1), T.int8(-1), 12, p2[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 
 
@@ -195,21 +165,14 @@ class RereadWeightsU65:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         # buffer definition
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
-        placeholder_encoded_1 = T.buffer_decl([368], "uint8")
-        placeholder_encoded_1_2 = T.buffer_decl([96], "uint8")
+        placeholder_encoded_1 = T.buffer_decl([464], "uint8")
         # body
-        placeholder_global = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_d_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_global_1 = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_d_global_1 = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 368, placeholder_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1_2[0], 96, placeholder_d_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 368, placeholder_global_1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1_2[0], 96, placeholder_d_global_1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 192, placeholder_global[192], 176, 12, placeholder_d_global[0], 48, placeholder_d_global[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 192, placeholder_global_1[192], 176, 12, placeholder_d_global_1[0], 48, placeholder_d_global_1[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        p1 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p2 = T.allocate([464], "uint8", "global", annotations={"disable_lower_builtin":True})
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 464, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
 
     __tvm_meta__ = None
 # fmt: on
@@ -221,12 +184,12 @@ class RereadWeightsU65:
         (
             "ethos-u55-128",
             RereadWeightsU55,
-            [304, 80],
+            [384],
         ),
         (
             "ethos-u65-512",
             RereadWeightsU65,
-            [368, 96],
+            [464],
         ),
     ],
 )
@@ -268,7 +231,7 @@ def test_re_read_weights(accelerator, reference_mod, reference_const_sizes):
         tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
 
         test_const_size = [value.size for value in list(consts.values())]
-        assert reference_const_sizes == test_const_size
+        assert sorted(reference_const_sizes) == sorted(test_const_size)  # order-insensitive comparison
 
 
 # fmt: off
@@ -282,8 +245,6 @@ class DirectReadOnlyU55:
         buffer_1 = T.buffer_decl([160], "uint8")
         buffer_2 = T.buffer_decl([160], "uint8")
         buffer_3 = T.buffer_decl([80], "uint8")
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data)
         # body
         ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -302,8 +263,6 @@ class DirectReadOnlyU65:
         placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8")
         placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8")
         placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8")
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
         # body
         ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -364,87 +323,64 @@ def test_direct_read_only(accelerator, reference_mod, reference_const_sizes):
         tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
 
         test_const_size = [value.size for value in list(consts.values())]
-        assert reference_const_sizes == test_const_size
+        assert sorted(reference_const_sizes) == sorted(test_const_size)  # order-insensitive comparison
 
 
 # fmt: off
 @tvm.script.ir_module
 class MixedReadU55:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,), "uint8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer1 = T.buffer_decl([80], "uint8")
-        buffer2 = T.buffer_decl([32], "uint8")
-        buffer3 = T.buffer_decl([80], "uint8")
-        buffer4 = T.buffer_decl([32], "uint8")
-        buffer5 = T.buffer_decl([80], "uint8")
-        buffer6 = T.buffer_decl([32], "uint8")
-        buffer7 = T.buffer_decl([80], "uint8")
-        buffer8 = T.buffer_decl([32], "uint8")
+        buffer1 = T.buffer_decl([112], "uint8")
+        buffer3 = T.buffer_decl([112], "uint8")
+        buffer5 = T.buffer_decl([112], "uint8")
         buffer9 = T.buffer_decl([592], "uint8")
         buffer10 = T.buffer_decl([160], "uint8")
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data)
+        buffer11 = T.buffer_decl([2048], "int8")
         # body
-        p1 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p1 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
         p3 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
-        p4 = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p5 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 80, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 32, p2[0], dtype="handle"))
+        p2 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
+        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 80, p4[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 32, p5[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 80, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 32, p2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 80, T.int8(-1), T.int8(-1), 12, p5[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 80, p4[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 32, p5[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p2[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 80, T.int8(-1), T.int8(-1), 12, p5[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 112, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 
 
 @tvm.script.ir_module
 class MixedReadU65:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+    def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data)
-        T.preflattened_buffer(placeholder, [1, 16, 16, 32], "int8", data=placeholder.data)
+
         # buffer definition
-        buffer_encoded_1 = T.buffer_decl([96], dtype="uint8")
-        buffer_encoded_1_2 = T.buffer_decl([32], dtype="uint8")
-        placeholder_encoded_1 = T.buffer_decl([608], dtype="uint8")
-        placeholder_encoded_1_2 = T.buffer_decl([160], dtype="uint8")
-        buffer_encoded_2_1 = T.buffer_decl([96], dtype="uint8")
-        buffer_encoded_3_1 = T.buffer_decl([32], dtype="uint8")
-        buffer_encoded_4_1 = T.buffer_decl([96], dtype="uint8")
-        buffer_encoded_5_1 = T.buffer_decl([32], dtype="uint8")
-        buffer_encoded_6_1 = T.buffer_decl([96], dtype="uint8")
-        buffer_encoded_7_1 = T.buffer_decl([32], dtype="uint8")
-        placeholder_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_global_2 = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True})
-        placeholder_d_global_2 = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 96, placeholder_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1_2[0], 32, placeholder_d_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_1[0], 304, placeholder_encoded_1[304], 304, 12, placeholder_encoded_1_2[0], 80, placeholder_encoded_1_2[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2_1[0], 96, placeholder_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3_1[0], 32, placeholder_d_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4_1[0], 96, placeholder_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5_1[0], 32, placeholder_d_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6_1[0], 96, placeholder_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7_1[0], 32, placeholder_d_global_2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        buffer1 = T.buffer_decl([128], dtype="uint8")
+        buffer2 = T.buffer_decl([128], dtype="uint8")
+        buffer3 = T.buffer_decl([128], dtype="uint8")
+        buffer4 = T.buffer_decl([608], dtype="uint8")
+        buffer5 = T.buffer_decl([160], dtype="uint8")
+        buffer6 = T.buffer_decl([2048], dtype="int8")
+        p1 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
+        p3 = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
+        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 128, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, p3[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 # fmt: on
 
@@ -455,12 +391,12 @@ class MixedReadU65:
         (
             "ethos-u55-128",
             MixedReadU55,
-            [592, 160, 80, 32, 80, 32, 80, 32, 80, 32],
+            [592, 160, 112, 112, 112, 112],
         ),
         (
             "ethos-u65-512",
             MixedReadU65,
-            [608, 160, 96, 32, 96, 32, 96, 32, 96, 32],
+            [608, 160, 128, 128, 128, 128],
         ),
     ],
 )
@@ -512,7 +448,7 @@ def test_mixed_read(accelerator, reference_mod, reference_const_sizes):
         tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
 
         test_const_size = [value.size for value in list(consts.values())]
-        assert reference_const_sizes == test_const_size
+        assert sorted(reference_const_sizes) == sorted(test_const_size)  # order-insensitive; list.sort() returns None
 
 
 def test_constant_as_input():
@@ -543,7 +479,7 @@ def test_constant_as_input():
 
     # Check tile address for the scalar constant input hasn't been
     # overwritten.
-    extern_calls = tir_mod["main"].body.body.body.body.body
+    extern_calls = tir_mod["main"].body.body.body.body
     binary_elementwise = extern_calls[-1].value
     args = binary_elementwise.args
 
diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py
new file mode 100644
index 0000000000..caf09abdb0
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_merge_constants.py
@@ -0,0 +1,561 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import tvm
+from tvm.script import tir as T
+from tvm.relay.backend.contrib.ethosu.tir.passes import MergeConstants
+import numpy as np
+
+
+def check_const_dictionaries(const_dict, new_const_dict):
+    assert list(const_dict) == list(new_const_dict)  # same keys in the same order
+    for key in const_dict:
+        actual, expected = const_dict[key], new_const_dict[key]
+        assert len(actual) == len(expected)
+        for lhs, rhs in zip(actual, expected):
+            assert lhs == rhs
+
+
+def test_only_one_operator():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body
+            p1 = T.allocate([128], "uint8", "global")
+            p4 = T.allocate([32], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(160,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body
+            p4 = T.allocate([160], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    # fmt: on
+    const_dict = {
+        0: np.array([0, 10], dtype=np.uint8),
+        1: np.array([1, 11], dtype=np.uint8),
+    }
+    new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))}
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_all_operators_with_weights():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer4: T.Buffer[(112,), "uint8"], buffer5: T.Buffer[(32,), "uint8"], buffer6: T.Buffer[(112,), "uint8"], buffer7: T.Buffer[(32,), "uint8"], buffer8: T.Buffer[(112,), "uint8"], buffer9: T.Buffer[(32,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body
+            p1 = T.allocate([128], "uint8", "global")
+            p2 = T.allocate([112], "uint8", "global")
+            p3 = T.allocate([112], "uint8", "global")
+            p4 = T.allocate([32], "uint8", "global")
+            p5 = T.allocate([32], "uint8", "global")
+            p6 = T.allocate([32], "uint8", "global")
+            p7 = T.allocate([112], "uint8", "global")
+            p8 = T.allocate([32], "uint8", "global")  # must hold the 32-byte copy from buffer9
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p5[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 112, p3[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 32, p6[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, 12, p5[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 112, p7[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer9[0], 32, p8[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, 12, p6[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body
+            p4 = T.allocate([160], "uint8", "global")
+            p7 = T.allocate([144], "uint8", "global")
+            p10 = T.allocate([144], "uint8", "global")
+            p11 = T.allocate([144], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    # fmt: on
+
+    const_dict = {
+        0: np.array([0], dtype=np.uint8),
+        1: np.array([1], dtype=np.uint8),
+        2: np.array([2], dtype=np.uint8),
+        3: np.array([3], dtype=np.uint8),
+        4: np.array([4], dtype=np.uint8),
+        5: np.array([5], dtype=np.uint8),
+        6: np.array([6], dtype=np.uint8),
+        7: np.array([7], dtype=np.uint8),
+    }
+    new_const_dict = {
+        0: np.concatenate((const_dict[0], const_dict[1])),
+        1: np.concatenate((const_dict[2], const_dict[3])),
+        2: np.concatenate((const_dict[4], const_dict[5])),
+        3: np.concatenate((const_dict[6], const_dict[7])),
+    }
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_operators_with_and_without_weights():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(80,), "uint8"], buffer3: T.Buffer[(64,), "uint8"]) -> None:
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})  
+            buffer0 = T.buffer_decl([390336], "int8")
+            buffer1 = T.buffer_decl([97156], "int8")
+            buffer6 = T.buffer_decl([390336], "int8")
+            # body
+            p2 = T.allocate([80], "uint8", "global")
+            p3 = T.allocate([64], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(144,), "uint8"]) -> None:
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})  
+            buffer0 = T.buffer_decl([390336], "int8")
+            buffer1 = T.buffer_decl([97156], "int8")
+            buffer6 = T.buffer_decl([390336], "int8")
+            # body
+            p3 = T.allocate([144], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 144, p3[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, buffer0[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, buffer6[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p3[0], 80, 0, p3[80], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    # fmt: on
+
+    const_dict = {
+        0: np.array([0], dtype=np.uint8),
+        1: np.array([1], dtype=np.uint8),
+    }
+    new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))}
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_copy_to_buffer_with_local_scope():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(buffer1: T.Buffer[(64,), "uint8"], 
+        buffer2: T.Buffer[(48,), "uint8"], 
+        buffer3: T.Buffer[(256,), "uint8"],
+        buffer4: T.Buffer[(256,), "uint8"],
+        buffer5: T.Buffer[(16,), "uint8"],
+        buffer6: T.Buffer[(48,), "uint8"],
+        buffer7: T.Buffer[(256,), "uint8"],
+        buffer8: T.Buffer[(64,), "uint8"],
+        buffer9: T.Buffer[(256,), "int8"],
+        ) -> None:
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # body
+            p1 = T.allocate([48], "uint8", "global")
+            p2 = T.allocate([48], "uint8", "global")
+            p3 = T.allocate([256], "int8", "local")
+            p5 = T.allocate([16], "uint8", "global")
+            p6 = T.allocate([48], "uint8", "global")
+            p7 = T.allocate([256], "int8", "local")
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 48, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 48, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local
+            T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 16, p5[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 48, p6[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer1[0], 0, 0, 0, T.float32(0.00392081), -128, "NHWC", 16, 4, 1, "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.00839574), -128, "NHCWB16", 64, 16, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, 0, p2[0], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 256, p7[0], dtype="handle")) # Local
+            T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.0078125), 0, "NHCWB16", 64, 16, 1, "int8", 4, 4, 4, 4, 0, 4, buffer8[0], 0, 0, 0, T.float32(0.00372155), -128, "NHWC", 16, 4, 1, 1, 1, 1, 1, 1, 1, p5[0], 16, 0, p6[0], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(buffer1: T.Buffer[(64,), "uint8"], 
+            buffer2: T.Buffer[(96,), "uint8"], 
+            buffer4: T.Buffer[(256,), "uint8"],
+            buffer5: T.Buffer[(64,), "uint8"],
+            buffer7: T.Buffer[(256,), "uint8"],
+            buffer8: T.Buffer[(64,), "uint8"],
+            buffer9: T.Buffer[(256,), "int8"],
+            ) -> None:
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # body
+            p1 = T.allocate([96], "uint8", "global")
+            p2 = T.allocate([64], "uint8", "global")
+            p3 = T.allocate([256], "int8", "local")
+            p7 = T.allocate([256], "int8", "local")
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 256, p3[0], dtype="handle")) # Local
+            T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 64, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer1[0], 0, 0, 0, T.float32(0.00392081), -128, "NHWC", 16, 4, 1, "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.00839574), -128, "NHCWB16", 64, 16, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, 0, p1[48], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 256, p7[0], dtype="handle")) # Local
+            T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 4, 4, 4, 4, 0, 4, buffer9[0], 0, 0, 0, T.float32(0.0078125), 0, "NHCWB16", 64, 16, 1, "int8", 4, 4, 4, 4, 0, 4, buffer8[0], 0, 0, 0, T.float32(0.00372155), -128, "NHWC", 16, 4, 1, 1, 1, 1, 1, 1, 1, p2[0], 16, 0, p2[16], 48, 0, 0, 0, 0, "TANH", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    # fmt: on
+
+    const_dict = {
+        1: np.array([1], dtype=np.uint8),
+        2: np.array([2], dtype=np.uint8),
+        3: np.array([3], dtype=np.uint8),
+        4: np.array([4], dtype=np.uint8),
+        5: np.array([5], dtype=np.uint8),
+        6: np.array([6], dtype=np.uint8),
+    }
+    new_const_dict = {
+        1: np.concatenate((const_dict[1], const_dict[2])),
+        2: const_dict[3],
+        3: np.concatenate((const_dict[4], const_dict[5])),
+        4: const_dict[6],
+    }
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_no_copies():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main() -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            placeholder = T.buffer_decl([20], "int8")
+            ethosu_write = T.buffer_decl([16], "int8")
+            # body
+            ethosu_write_4 = T.allocate([16], "int8", "global")
+            T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main() -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            placeholder = T.buffer_decl([20], "int8")
+            ethosu_write = T.buffer_decl([16], "int8")
+            # body
+            ethosu_write_4 = T.allocate([16], "int8", "global")
+            T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 1, 4, 4, 1, 0, 4, placeholder[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "int8", 1, 4, 1, 1, 0, 4, placeholder[16], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 1, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(0.00783747), -128, "NHWC", 1, 4, 1, "MAX", 0, "CLIP", -128, 127, "TFL", 1, 4, 4, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_identity", "int8", 1, 4, 4, 1, 0, 4, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "int8", 1, 4, 4, 1, 0, 4, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 4, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    # fmt: on
+
+    const_dict = {}
+    new_const_dict = {}
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_copies_to_the_same_buffer():  # both copy pairs merge into one 160-byte copy
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"]) -> None:
+            # function attributes: legacy TE schedule, no aliasing
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body: p1 and p4 are refilled from the same sources before each conv2d
+            p1 = T.allocate([128], "uint8", "global")
+            p4 = T.allocate([32], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+
+    # Expected: buffer2 and buffer3 are concatenated into a single 160-byte buffer2.
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(160,), "uint8"]) -> None:
+            # function attributes: legacy TE schedule, no aliasing
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body: each conv2d reads the merged constants at p5[0] (128 B) and p5[128] (32 B)
+            p5 = T.allocate([160], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p5[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p5[0], 128, 12, p5[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    # fmt: on
+
+    const_dict = {
+        0: np.array([0], dtype=np.uint8),
+        1: np.array([1], dtype=np.uint8),
+    }
+    new_const_dict = {0: np.concatenate((const_dict[0], const_dict[1]))}
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_read_from_the_same_buffer():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            # body
+            p1 = T.allocate([368], "uint8", "global")
+            p2 = T.allocate([96], "uint8", "global") 
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        __tvm_meta__ = None
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # body
+            p1 = T.allocate([464], "uint8", "global")
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    __tvm_meta__ = None
+    # fmt: on
+
+    const_dict = {
+        1: np.array([1], dtype=np.uint8),
+        2: np.array([2], dtype=np.uint8),
+    }
+    new_const_dict = {1: np.concatenate((const_dict[1], const_dict[2]))}
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_cycle_count():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"], buffer4: T.Buffer[(112,), "uint8"], buffer5: T.Buffer[(32,), "uint8"], buffer6: T.Buffer[(112,), "uint8"], buffer7: T.Buffer[(32,), "uint8"], buffer8: T.Buffer[(112,), "uint8"], buffer9: T.Buffer[(32,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            v1a = T.var("int32")
+            v1b = T.var("int32")
+            v1c = T.var("int32")
+            v2a = T.var("int32")
+            v2b = T.var("int32")
+            v2c = T.var("int32")
+            v3a = T.var("int32")
+            v3b = T.var("int32")
+            v3c = T.var("int32")
+            v4a = T.var("int32")
+            v4b = T.var("int32")
+            v4c = T.var("int32")
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body
+            p1 = T.allocate([128], "uint8", "global")
+            p2 = T.allocate([112], "uint8", "global")
+            p3 = T.allocate([112], "uint8", "global")
+            p4 = T.allocate([32], "uint8", "global")
+            p5 = T.allocate([32], "uint8", "global")
+            p6 = T.allocate([32], "uint8", "global")
+            p7 = T.allocate([112], "uint8", "global")
+            p8 = T.allocate([3], "uint8", "global")
+            with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 100):
+                T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p1[0], dtype="handle"))
+            with T.attr(T.iter_var(v1b, None, "DataPar", ""), "pragma_compute_cycles_hint", 101):
+                T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 32, p4[0], dtype="handle"))
+            with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 102):
+                T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 112, p2[0], dtype="handle"))
+            with T.attr(T.iter_var(v2b, None, "DataPar", ""), "pragma_compute_cycles_hint", 103):
+                T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p5[0], dtype="handle"))
+            with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 128, 12, p4[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 104):
+                T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 112, p3[0], dtype="handle"))
+            with T.attr(T.iter_var(v3b, None, "DataPar", ""), "pragma_compute_cycles_hint", 105):
+                T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 32, p6[0], dtype="handle"))
+            with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 112, 12, p5[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 106):
+                T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 112, p7[0], dtype="handle"))
+            with T.attr(T.iter_var(v4b, None, "DataPar", ""), "pragma_compute_cycles_hint", 107):
+                T.evaluate(T.call_extern("ethosu_copy", buffer9[0], 32, p8[0], dtype="handle"))
+            with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 112, 12, p6[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p8[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            v1a = T.var("int32")
+            v1c = T.var("int32")
+            v2a = T.var("int32")
+            v2c = T.var("int32")
+            v3a = T.var("int32")
+            v3c = T.var("int32")
+            v4a = T.var("int32")
+            v4c = T.var("int32")
+            buffer1 = T.buffer_decl([8192], "int8")
+            buffer10 = T.buffer_decl([2048], "int8")
+            # body
+            p4 = T.allocate([160], "uint8", "global")
+            p7 = T.allocate([144], "uint8", "global")
+            p10 = T.allocate([144], "uint8", "global")
+            p11 = T.allocate([144], "uint8", "global")
+            with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 201):
+                T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle"))
+            with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 205):
+                T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle"))
+            with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 209):
+                T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle"))
+            with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 213):
+                T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle"))
+            with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    # fmt: on
+
+    const_dict = {
+        0: np.array([0], dtype=np.uint8),
+        1: np.array([1], dtype=np.uint8),
+        2: np.array([2], dtype=np.uint8),
+        3: np.array([3], dtype=np.uint8),
+        4: np.array([4], dtype=np.uint8),
+        5: np.array([5], dtype=np.uint8),
+        6: np.array([6], dtype=np.uint8),
+        7: np.array([7], dtype=np.uint8),
+    }
+    new_const_dict = {
+        0: np.concatenate((const_dict[0], const_dict[1])),
+        1: np.concatenate((const_dict[2], const_dict[3])),
+        2: np.concatenate((const_dict[4], const_dict[5])),
+        3: np.concatenate((const_dict[6], const_dict[7])),
+    }
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_multiple_prim_funcs():  # MergeConstants must reject a module with two PrimFuncs
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main():
+            T.evaluate(0)
+
+        @T.prim_func
+        def abc():  # a second PrimFunc besides 'main'
+            T.evaluate(0)
+    # fmt: on
+
+    err_rgx = (
+        r"Expected a single primitive function called 'main'. "
+        r"Please run the MergeConstants pass in conjunction with the LowerToTIR\(\) pass."
+    )
+    with pytest.raises(tvm.TVMError, match=err_rgx):
+        MergeConstants({})(InputModule)
+
+
+def test_no_main_prim_func():  # a lone PrimFunc not named 'main' must be rejected
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def abs():  # any name other than 'main'
+            T.evaluate(0)
+    # fmt: on
+
+    err_rgx = (
+        r"Expected a single primitive function called 'main'. "
+        r"Please run the MergeConstants pass in conjunction with the LowerToTIR\(\) pass."
+    )
+    with pytest.raises(tvm.TVMError, match=err_rgx):
+        MergeConstants({})(InputModule)
diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py
index 075565cd92..02643f6c1d 100644
--- a/tests/python/contrib/test_ethosu/test_networks.py
+++ b/tests/python/contrib/test_ethosu/test_networks.py
@@ -44,13 +44,13 @@ MOBILENET_V2_URL = (
 @pytest.mark.parametrize(
     "accel_type, model_url, workspace_size",
     [
-        ("ethos-u65-256", MOBILENET_V1_URL, 1892704),
-        ("ethos-u65-256", MOBILENET_V2_URL, 2257984),
-        ("ethos-u55-256", MOBILENET_V1_URL, 1892704),
-        ("ethos-u55-256", MOBILENET_V2_URL, 2257984),
-        ("ethos-u55-128", MOBILENET_V2_URL, 2257984),
-        ("ethos-u55-64", MOBILENET_V2_URL, 2257984),
-        ("ethos-u55-32", MOBILENET_V2_URL, 2258000),
+        ("ethos-u65-256", MOBILENET_V1_URL, 1793376),
+        ("ethos-u65-256", MOBILENET_V2_URL, 2218160),
+        ("ethos-u55-256", MOBILENET_V1_URL, 1793376),
+        ("ethos-u55-256", MOBILENET_V2_URL, 2218160),
+        ("ethos-u55-128", MOBILENET_V2_URL, 2218160),
+        ("ethos-u55-64", MOBILENET_V2_URL, 2218160),
+        ("ethos-u55-32", MOBILENET_V2_URL, 2218160),
     ],
 )
 def test_networks_without_usmp(accel_type, model_url, workspace_size):
diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py
index cc996e5941..d2c759a0ae 100644
--- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py
+++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py
@@ -41,9 +41,6 @@ class ReferenceModule:
         buffer_5 = T.buffer_decl([160], "uint8")
         buffer_6 = T.buffer_decl([2992], "uint8")
         buffer_7 = T.buffer_decl([160], "uint8")
-        T.preflattened_buffer(placeholder, [1, 8, 12, 16], "int8", data=placeholder.data)
-        T.preflattened_buffer(placeholder_1, [1, 8, 10, 16], "int8", data=placeholder_1.data)
-        T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data)
         # body
         T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
index 63f9fc44c7..46a3c5a15b 100644
--- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
@@ -373,8 +373,6 @@ class Conv2dDoubleCascade1:
         buffer_1 = T.buffer_decl([80], "uint8")
         buffer_2 = T.buffer_decl([320], "uint8")
         buffer_3 = T.buffer_decl([160], "uint8")
-        T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data)
         # body
         ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -394,8 +392,6 @@ class Conv2dDoubleCascade2:
         buffer_1 = T.buffer_decl([320], "uint8")
         buffer_2 = T.buffer_decl([1312], "uint8")
         buffer_3 = T.buffer_decl([2608], "uint8")
-        T.preflattened_buffer(placeholder_5, [1, 8, 8, 3], 'int8', data=placeholder_5.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data)
         # body
         ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -415,8 +411,6 @@ class Conv2dDoubleCascade3:
         buffer_1 = T.buffer_decl([80], "uint8")
         buffer_2 = T.buffer_decl([320], "uint8")
         buffer_3 = T.buffer_decl([880], "uint8")
-        T.preflattened_buffer(placeholder_5, [1, 16, 16, 3], 'int8', data=placeholder_5.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 20, 4, 8], 'int8', data=ethosu_write_1.data)
         # body
         ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -438,8 +432,6 @@ class Conv2dDoubleCascade4:
         buffer_1 = T.buffer_decl([352], "uint8")
         buffer_2 = T.buffer_decl([272], "uint8")
         buffer_3 = T.buffer_decl([11040], "uint8")
-        T.preflattened_buffer(placeholder_5, [1, 8, 1, 8, 16], 'int8', data=placeholder_5.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 2, 8, 16], 'int8', data=ethosu_write_1.data)
         # body
         ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -459,8 +451,6 @@ class Conv2dDoubleCascade5:
         buffer_1 = T.buffer_decl([320], "uint8")
         buffer_2 = T.buffer_decl([304], "uint8")
         buffer_3 = T.buffer_decl([80], "uint8")
-        T.preflattened_buffer(placeholder, [1, 8, 8, 3], 'int8', data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 32, 32, 8], 'int8', data=ethosu_write.data)
         # body
         ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle"))
@@ -480,8 +470,6 @@ class Conv2dDoubleCascade6:
         buffer_1 = T.buffer_decl([352], "uint8")
         buffer_2 = T.buffer_decl([11040], "uint8")
         buffer_3 = T.buffer_decl([272], "uint8")
-        T.preflattened_buffer(placeholder, [1, 8, 1, 8, 16], 'int8', data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 32, 2, 32, 16], 'int8', data=ethosu_write.data)
         # body
         ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle"))
@@ -641,8 +629,6 @@ class Conv2dInlineCopy1:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([848], "uint8")
         buffer_1 = T.buffer_decl([160], "uint8")
-        T.preflattened_buffer(placeholder_3, [1, 10, 12, 8], 'int8', data=placeholder_3.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 16], 'int8', data=ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
@@ -656,8 +642,6 @@ class Conv2dInlineCopy2:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([656], "uint8")
-        T.preflattened_buffer(placeholder_3, [1, 7, 9, 5], 'int8', data=placeholder_3.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 3, 5, 16], 'int8', data=ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
@@ -700,8 +684,6 @@ class Conv2dInlineReshape1:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
-        T.preflattened_buffer(placeholder_3, [4, 6, 8, 1], 'int8', data=placeholder_3.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -716,8 +698,6 @@ class Conv2dInlineReshape2:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
-        T.preflattened_buffer(placeholder_3, [1, 24, 8], 'int8', data=placeholder_3.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -732,8 +712,6 @@ class Conv2dInlineReshape3:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
-        T.preflattened_buffer(placeholder_3, [192, 1], 'int8', data=placeholder_3.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -748,8 +726,6 @@ class Conv2dInlineReshape4:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer = T.buffer_decl([160], "uint8")
         buffer_1 = T.buffer_decl([848], "uint8")
-        T.preflattened_buffer(placeholder_3, [192], 'int8', data=placeholder_3.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data)
         # body
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py
index 932df71d24..6b97b38d80 100644
--- a/tests/python/contrib/test_ethosu/test_replace_copy.py
+++ b/tests/python/contrib/test_ethosu/test_replace_copy.py
@@ -34,16 +34,11 @@ class ReferenceModule:
     def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_decl([80], "uint8")
-        buffer_1 = T.buffer_decl([304], "uint8")
-        T.preflattened_buffer(placeholder_3, [1, 16, 16, 32], dtype="int8", data=placeholder_3.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 8], dtype="int8", data=ethosu_write_1.data)
+        buffer_1 = T.buffer_decl([384], "uint8")
         # body
-        placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True})
-        placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True})
-        T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 304, placeholder_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer[0], 80, placeholder_d_global[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        placeholder_global = T.allocate([384], "uint8", "global", annotations={"disable_lower_builtin": True})
+        T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 384, placeholder_global[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_global[304], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 # fmt: on
 
@@ -80,23 +75,15 @@ class WeightStream:
     def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(4096,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        buffer = T.buffer_decl([416], "uint8")
-        buffer_1 = T.buffer_decl([112], "uint8")
-        buffer_2 = T.buffer_decl([272], "uint8")
-        buffer_3 = T.buffer_decl([64], "uint8")
-        T.preflattened_buffer(placeholder_5, [1, 16, 16, 32], dtype="int8", data=placeholder_5.data)
-        T.preflattened_buffer(ethosu_write_1, [1, 16, 16, 16], dtype="int8", data=ethosu_write_1.data)
+        buffer = T.buffer_decl([528], "uint8")
+        buffer_2 = T.buffer_decl([336], "uint8")
         # body
-        placeholder_global_unrolled_iter_0 = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True})
-        placeholder_d_global_unrolled_iter_0 = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True})
-        placeholder_global_unrolled_iter_1 = T.allocate([272], "uint8", "global", annotations={"disable_lower_builtin": True})
-        placeholder_d_global_unrolled_iter_1 = T.allocate([64],  "uint8", "global", annotations={"disable_lower_builtin": True})
-        T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global_unrolled_iter_0[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global_unrolled_iter_0[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global_unrolled_iter_1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global_unrolled_iter_1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_0[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_0[0], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_1[0], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        placeholder_d_global = T.allocate([528], "uint8", "global", annotations={"disable_lower_builtin": True})
+        placeholder_d_global_1 = T.allocate([336], "uint8", "global", annotations={"disable_lower_builtin": True})
+        T.evaluate(T.call_extern("ethosu_copy", buffer[0], 528, placeholder_d_global[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 336, placeholder_d_global_1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global[416], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_d_global_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_1[272], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 # fmt: on
 
diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py
index 4baea26e59..ba050de2b4 100644
--- a/tests/python/contrib/test_ethosu/test_scheduler.py
+++ b/tests/python/contrib/test_ethosu/test_scheduler.py
@@ -182,24 +182,16 @@ class DiamondGraphTir:
     @T.prim_func
     def main(placeholder: T.Buffer[(301056,), "int8"], ethosu_write: T.Buffer[(75264,), "int8"]) -> None:
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
-        T.preflattened_buffer(placeholder, [1, 56, 56, 96], dtype='int8', data=placeholder.data)
-        T.preflattened_buffer(ethosu_write, [1, 56, 56, 24], dtype='int8', data=ethosu_write.data)
-        buffer1 = T.buffer_decl([2608], "uint8")
-        buffer2 = T.buffer_decl([240], "uint8")
-        buffer3 = T.buffer_decl([736], "uint8")
-        buffer4 = T.buffer_decl([240], "uint8")
-        p1 = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p2 = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p3 = T.allocate([736], "uint8", "global", annotations={"disable_lower_builtin":True})
-        p4 = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True})
+        buffer1 = T.buffer_decl([2848], "uint8")
+        buffer3 = T.buffer_decl([976], "uint8")
+        p1 = T.allocate([2848], "uint8", "global", annotations={"disable_lower_builtin":True})
+        p2 = T.allocate([976], "uint8", "global", annotations={"disable_lower_builtin":True})
         p5 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True})
         p6 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin":True})
-        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2608, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 240, p2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 736, p3[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 240, p4[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p2[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p3[0], 736, T.int8(-1), T.int8(-1), 12, p4[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 2848, p1[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 976, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p1[0], 2608, T.int8(-1), T.int8(-1), 12, p1[2608], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, p2[0], 736, T.int8(-1), T.int8(-1), 12, p2[736], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, p5[0], 0, 0, 0,T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, p6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, ethosu_write[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 # fmt: on