You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/11/10 09:01:10 UTC

[tvm] branch main updated: [microNPU] Fixed MergeConstants pass on striped networks (#13281)

This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 54bd5e1f5f [microNPU] Fixed MergeConstants pass on striped networks (#13281)
54bd5e1f5f is described below

commit 54bd5e1f5fa52c498b4a4ff13d795daf52a81bfd
Author: Sergei Smirnov <89...@users.noreply.github.com>
AuthorDate: Thu Nov 10 12:01:05 2022 +0300

    [microNPU] Fixed MergeConstants pass on striped networks (#13281)
    
    This PR fixes a bug in the MergeConstants pass on striped networks for the Ethos-U NPU.
    
    The issue was caused by the _DivideConstants_ pass, which introduces new mod parameters and changes their order. As a result, the ethosu_write parameter is in some cases moved from the end of the list to the middle.
    E.g. from:
    `[ethos-u_0_i0, p1, p2, p3, p4, p5, p6, ethosu_write]`
    To:
    `[ethos-u_0_i0, p1, p2, ethosu_write, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder]`
    
    The updated versions of the _GetArgsToMergeWithoutArgsNotInConstDict_ and _MakeNewConstDict_ methods in passes.cc can now correctly modify const_dict according to the new parameter list.
---
 .../relay/backend/contrib/ethosu/tir/compiler.py   |   5 +-
 src/tir/contrib/ethosu/passes.cc                   |  27 ++-
 .../contrib/test_ethosu/test_encode_constants.py   |  32 ++--
 .../contrib/test_ethosu/test_merge_constants.py    | 189 +++++++++++++++++++++
 4 files changed, 224 insertions(+), 29 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index aaac59ad4a..4133aff6ef 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -91,10 +91,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
         mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
         mod = ethosu_passes.HoistAllocates()(mod)
         mod = tvm.tir.transform.RemoveNoOp()(mod)
-        #  MergeConstant pass currently does not support striped schedules.
-        #  It requires further investigation.
-        if not util.is_striping_enabled():
-            mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
+        mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
         mod = ethosu_passes.CopyComputeReordering()(mod)
 
         # When striping is enabled and if storage_rewrite is not run
diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc
index 2f6fa8f3ea..d51ffbf833 100644
--- a/src/tir/contrib/ethosu/passes.cc
+++ b/src/tir/contrib/ethosu/passes.cc
@@ -514,7 +514,7 @@ class MergeConstantsMutator : public StmtExprMutator {
 
     // Make the new const dict
     Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
-    Array<Array<IntImm>> buffers_to_merge{
+    Map<IntImm, Array<IntImm>> buffers_to_merge{
         GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
     Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};
 
@@ -832,9 +832,11 @@ class MergeConstantsMutator : public StmtExprMutator {
     return vector;
   }
 
-  Array<Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
+  Map<IntImm, Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
       const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
-    Array<Array<IntImm>> new_args_to_merge{};
+    Map<IntImm, Array<IntImm>> new_args_to_merge{};
+    bool first_arg_found = false;
+    int64_t new_arg_key = 0;  // the updated key of the merged const_dict
     for (Array<IntImm> args : args_to_merge) {
       IntImm key{args[0]};
       auto it = std::find_if(const_dict.begin(), const_dict.end(),
@@ -842,21 +844,29 @@ class MergeConstantsMutator : public StmtExprMutator {
                                return pair.first->value == key->value;
                              });
       if (it != const_dict.end()) {
-        new_args_to_merge.push_back(args);
+        if (first_arg_found == false) {
+          first_arg_found = true;
+          new_arg_key = key->value;
+        }
+        new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args);
+      }
+      if (first_arg_found) {
+        new_arg_key++;
       }
     }
     return new_args_to_merge;
   }
 
-  Map<IntImm, runtime::NDArray> MakeNewConstDict(const Array<Array<IntImm>>& args_to_merge,
+  Map<IntImm, runtime::NDArray> MakeNewConstDict(const Map<IntImm, Array<IntImm>>& args_to_merge,
                                                  Map<IntImm, runtime::NDArray> const_dict) {
     Map<IntImm, runtime::NDArray> new_const_dict{};
     if (args_to_merge.size() == 0) {
       return new_const_dict;
     }
 
-    int64_t key = args_to_merge[0][0]->value;
-    for (Array<IntImm> args : args_to_merge) {
+    for (auto const& elem : args_to_merge) {
+      IntImm key = elem.first;
+      Array<IntImm> args = elem.second;
       int64_t size = 0;
       for (IntImm arg : args) {
         auto it = std::find_if(const_dict.begin(), const_dict.end(),
@@ -876,8 +886,7 @@ class MergeConstantsMutator : public StmtExprMutator {
         arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes);
         offset += nbytes;
       }
-      new_const_dict.Set(IntImm(DataType::Int(64), key), constant);
-      key += 1;
+      new_const_dict.Set(key, constant);
     }
     return new_const_dict;
   }
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 6ffbf22312..c751d44b61 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -340,15 +340,15 @@ def test_direct_read_only(accelerator, reference_mod, reference_const_sizes):
 @tvm.script.ir_module
 class MixedReadU55:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,), "uint8"]) -> None:
+    def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
         buffer1 = T.buffer_decl([112], "uint8")
         buffer3 = T.buffer_decl([112], "uint8")
         buffer5 = T.buffer_decl([112], "uint8")
+        buffer7 = T.buffer_decl([112], "uint8")
         buffer9 = T.buffer_decl([592], "uint8")
         buffer10 = T.buffer_decl([160], "uint8")
-        buffer11 = T.buffer_decl([2048], "int8")
         # body
         p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
         p1 = T.buffer_decl([112], "uint8", data=p1_data)
@@ -357,21 +357,21 @@ class MixedReadU55:
         p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
         p2 = T.buffer_decl([112], "uint8", data=p2_data)
         T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 112, p2[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p2[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 
 
 @tvm.script.ir_module
 class MixedReadU65:
     @T.prim_func
-    def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"]) -> None:
+    def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
         # function attr dict
         T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
 
@@ -381,7 +381,7 @@ class MixedReadU65:
         buffer3 = T.buffer_decl([128], dtype="uint8")
         buffer4 = T.buffer_decl([608], dtype="uint8")
         buffer5 = T.buffer_decl([160], dtype="uint8")
-        buffer6 = T.buffer_decl([2048], dtype="int8")
+        buffer6 = T.buffer_decl([128], dtype="uint8")
         p1_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
         p1 = T.buffer_decl([128], "uint8", data=p1_data)
         p2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
@@ -389,14 +389,14 @@ class MixedReadU65:
         p3_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
         p3 = T.buffer_decl([128], "uint8", data=p3_data)
         T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 128, p1[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, p3[0], dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
-        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 128, p3[0], dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     __tvm_meta__ = None
 # fmt: on
 
diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py
index 337b5c70d1..a5adcfceac 100644
--- a/tests/python/contrib/test_ethosu/test_merge_constants.py
+++ b/tests/python/contrib/test_ethosu/test_merge_constants.py
@@ -441,6 +441,195 @@ def test_read_from_the_same_buffer():
     check_const_dictionaries(const_dict, new_const_dict)
 
 
+def test_arbitrary_argument_order():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            # body
+            p1_data = T.allocate([368], "uint8", "global")
+            p1 = T.buffer_decl([368], "uint8", data=p1_data)
+            p2_data = T.allocate([96], "uint8", "global")
+            p2 = T.buffer_decl([96], "uint8", data=p2_data)
+            p3_data = T.allocate([368], "uint8", "global")
+            p3 = T.buffer_decl([368], "uint8", data=p3_data)
+            p4_data = T.allocate([96], "uint8", "global")
+            p4 = T.buffer_decl([96], "uint8", data=p4_data)
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        __tvm_meta__ = None
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # body
+            p1_data = T.allocate([464], "uint8", "global")
+            p1 = T.buffer_decl([464], "uint8", data=p1_data)
+            p2_data = T.allocate([464], "uint8", "global")
+            p2 = T.buffer_decl([464], "uint8", data=p2_data)
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    __tvm_meta__ = None
+    # fmt: on
+
+    const_dict = {
+        1: np.array([1], dtype=np.uint8),
+        2: np.array([2], dtype=np.uint8),
+        4: np.array([4], dtype=np.uint8),
+        5: np.array([5], dtype=np.uint8),
+    }
+    new_const_dict = {
+        1: np.concatenate((const_dict[1], const_dict[2])),
+        3: np.concatenate((const_dict[4], const_dict[5])),
+    }
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, False)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_arbitrary_argument_order_const_split():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            # body
+            p1_data = T.allocate([368], "uint8", "global")
+            p1 = T.buffer_decl([368], "uint8", data=p1_data)
+            p2_data = T.allocate([96], "uint8", "global")
+            p2 = T.buffer_decl([96], "uint8", data=p2_data)
+            p3_data = T.allocate([368], "uint8", "global")
+            p3 = T.buffer_decl([368], "uint8", data=p3_data)
+            p4_data = T.allocate([96], "uint8", "global")
+            p4 = T.buffer_decl([96], "uint8", data=p4_data)
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        __tvm_meta__ = None
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # body
+            p1_data = T.allocate([464], "uint8", "global")
+            p1 = T.buffer_decl([464], "uint8", data=p1_data)
+            p2_data = T.allocate([464], "uint8", "global")
+            p2 = T.buffer_decl([464], "uint8", data=p2_data)
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    __tvm_meta__ = None
+    # fmt: on
+
+    const_dict = {
+        1: np.array([1], dtype=np.uint8),
+        3: np.array([3], dtype=np.uint8),
+        4: np.array([4], dtype=np.uint8),
+        5: np.array([5], dtype=np.uint8),
+    }
+    new_const_dict = {
+        1: np.concatenate((const_dict[1], const_dict[3])),
+        3: np.concatenate((const_dict[4], const_dict[5])),
+    }
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_arbitrary_argument_order_const_split_mixed():
+    # fmt: off
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # buffer definition
+            T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+            T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+            # body
+            p1_data = T.allocate([368], "uint8", "global")
+            p1 = T.buffer_decl([368], "uint8", data=p1_data)
+            p2_data = T.allocate([368], "uint8", "global")
+            p2 = T.buffer_decl([368], "uint8", data=p2_data)
+            p3_data = T.allocate([96], "uint8", "global")
+            p3 = T.buffer_decl([96], "uint8", data=p3_data)
+            p4_data = T.allocate([96], "uint8", "global")
+            p4 = T.buffer_decl([96], "uint8", data=p4_data)
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 96, p3[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p3[0], 48, p3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 368, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+        __tvm_meta__ = None
+
+
+    @tvm.script.ir_module
+    class ReferenceModule:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+            # body
+            p1_data = T.allocate([464], "uint8", "global")
+            p1 = T.buffer_decl([464], "uint8", data=p1_data)
+            p2_data = T.allocate([464], "uint8", "global")
+            p2 = T.buffer_decl([464], "uint8", data=p2_data)
+            T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+    __tvm_meta__ = None
+    # fmt: on
+
+    const_dict = {
+        1: np.array([1], dtype=np.uint8),
+        2: np.array([2], dtype=np.uint8),
+        4: np.array([4], dtype=np.uint8),
+        5: np.array([5], dtype=np.uint8),
+    }
+    new_const_dict = {
+        1: np.concatenate((const_dict[1], const_dict[4])),
+        2: np.concatenate((const_dict[2], const_dict[5])),
+    }
+    test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+    reference_mod = ReferenceModule
+    tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+    check_const_dictionaries(const_dict, new_const_dict)
+
+
 def test_cycle_count():
     # fmt: off
     @tvm.script.ir_module