You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/11/10 09:01:10 UTC
[tvm] branch main updated: [microNPU] Fixed MergeConstants pass on striped networks (#13281)
This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 54bd5e1f5f [microNPU] Fixed MergeConstants pass on striped networks (#13281)
54bd5e1f5f is described below
commit 54bd5e1f5fa52c498b4a4ff13d795daf52a81bfd
Author: Sergei Smirnov <89...@users.noreply.github.com>
AuthorDate: Thu Nov 10 12:01:05 2022 +0300
[microNPU] Fixed MergeConstants pass on striped networks (#13281)
This PR fixes a bug in the MergeConstants pass on striped networks for the Ethos-U NPU.
The issue was caused by the _DivideConstants_ pass, which introduces new mod parameters and changes their order, so that in some cases the ethosu_write parameter is moved from the end of the parameter list to the middle.
E.g. from:
`[ethos-u_0_i0, p1, p2, p3, p4, p5, p6, ethosu_write]`
To:
`[ethos-u_0_i0, p1, p2, ethosu_write, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder, placeholder]`
The updated versions of the _GetArgsToMergeWithoutArgsNotInConstDict_ and _MakeNewConstDict_ methods in passes.cc now correctly modify const_dict according to the new parameter list.
---
.../relay/backend/contrib/ethosu/tir/compiler.py | 5 +-
src/tir/contrib/ethosu/passes.cc | 27 ++-
.../contrib/test_ethosu/test_encode_constants.py | 32 ++--
.../contrib/test_ethosu/test_merge_constants.py | 189 +++++++++++++++++++++
4 files changed, 224 insertions(+), 29 deletions(-)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index aaac59ad4a..4133aff6ef 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -91,10 +91,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
mod = ethosu_passes.HoistAllocates()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
- # MergeConstant pass currently does not support striped schedules.
- # It requires further investigation.
- if not util.is_striping_enabled():
- mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
+ mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod)
mod = ethosu_passes.CopyComputeReordering()(mod)
# When striping is enabled and if storage_rewrite is not run
diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc
index 2f6fa8f3ea..d51ffbf833 100644
--- a/src/tir/contrib/ethosu/passes.cc
+++ b/src/tir/contrib/ethosu/passes.cc
@@ -514,7 +514,7 @@ class MergeConstantsMutator : public StmtExprMutator {
// Make the new const dict
Array<Array<IntImm>> args_to_merge{GetArgsToMerge(main_func->buffer_map, main_func->params)};
- Array<Array<IntImm>> buffers_to_merge{
+ Map<IntImm, Array<IntImm>> buffers_to_merge{
GetArgsToMergeWithoutArgsNotInConstDict(args_to_merge, const_dict)};
Map<IntImm, runtime::NDArray> new_const_dict{MakeNewConstDict(buffers_to_merge, const_dict)};
@@ -832,9 +832,11 @@ class MergeConstantsMutator : public StmtExprMutator {
return vector;
}
- Array<Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
+ Map<IntImm, Array<IntImm>> GetArgsToMergeWithoutArgsNotInConstDict(
const Array<Array<IntImm>>& args_to_merge, const Map<IntImm, runtime::NDArray>& const_dict) {
- Array<Array<IntImm>> new_args_to_merge{};
+ Map<IntImm, Array<IntImm>> new_args_to_merge{};
+ bool first_arg_found = false;
+ int64_t new_arg_key = 0; // the updated key of the merged const_dict
for (Array<IntImm> args : args_to_merge) {
IntImm key{args[0]};
auto it = std::find_if(const_dict.begin(), const_dict.end(),
@@ -842,21 +844,29 @@ class MergeConstantsMutator : public StmtExprMutator {
return pair.first->value == key->value;
});
if (it != const_dict.end()) {
- new_args_to_merge.push_back(args);
+ if (first_arg_found == false) {
+ first_arg_found = true;
+ new_arg_key = key->value;
+ }
+ new_args_to_merge.Set(IntImm(DataType::Int(64), new_arg_key), args);
+ }
+ if (first_arg_found) {
+ new_arg_key++;
}
}
return new_args_to_merge;
}
- Map<IntImm, runtime::NDArray> MakeNewConstDict(const Array<Array<IntImm>>& args_to_merge,
+ Map<IntImm, runtime::NDArray> MakeNewConstDict(const Map<IntImm, Array<IntImm>>& args_to_merge,
Map<IntImm, runtime::NDArray> const_dict) {
Map<IntImm, runtime::NDArray> new_const_dict{};
if (args_to_merge.size() == 0) {
return new_const_dict;
}
- int64_t key = args_to_merge[0][0]->value;
- for (Array<IntImm> args : args_to_merge) {
+ for (auto const& elem : args_to_merge) {
+ IntImm key = elem.first;
+ Array<IntImm> args = elem.second;
int64_t size = 0;
for (IntImm arg : args) {
auto it = std::find_if(const_dict.begin(), const_dict.end(),
@@ -876,8 +886,7 @@ class MergeConstantsMutator : public StmtExprMutator {
arg_constant.CopyToBytes(static_cast<uint8_t*>(constant->data) + offset, nbytes);
offset += nbytes;
}
- new_const_dict.Set(IntImm(DataType::Int(64), key), constant);
- key += 1;
+ new_const_dict.Set(key, constant);
}
return new_const_dict;
}
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 6ffbf22312..c751d44b61 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -340,15 +340,15 @@ def test_direct_read_only(accelerator, reference_mod, reference_const_sizes):
@tvm.script.ir_module
class MixedReadU55:
@T.prim_func
- def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(112,), "uint8"]) -> None:
+ def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([112], "uint8")
buffer3 = T.buffer_decl([112], "uint8")
buffer5 = T.buffer_decl([112], "uint8")
+ buffer7 = T.buffer_decl([112], "uint8")
buffer9 = T.buffer_decl([592], "uint8")
buffer10 = T.buffer_decl([160], "uint8")
- buffer11 = T.buffer_decl([2048], "int8")
# body
p1_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
p1 = T.buffer_decl([112], "uint8", data=p1_data)
@@ -357,21 +357,21 @@ class MixedReadU55:
p2_data = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin":True})
p2 = T.buffer_decl([112], "uint8", data=p2_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 112, p1[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer9[0], 592, T.int8(-1), T.int8(-1), 12, buffer10[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 112, p2[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 112, p1[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
- T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 112, p2[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer11[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer7[0], 112, p2[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 80, T.int8(-1), T.int8(-1), 12, p1[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 80, T.int8(-1), T.int8(-1), 12, p2[80], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
@tvm.script.ir_module
class MixedReadU65:
@T.prim_func
- def main(placeholder: T.Buffer[(8192,), "int8"], buffer_encoded: T.Buffer[(128,), "uint8"]) -> None:
+ def main(ifm: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
@@ -381,7 +381,7 @@ class MixedReadU65:
buffer3 = T.buffer_decl([128], dtype="uint8")
buffer4 = T.buffer_decl([608], dtype="uint8")
buffer5 = T.buffer_decl([160], dtype="uint8")
- buffer6 = T.buffer_decl([2048], dtype="int8")
+ buffer6 = T.buffer_decl([128], dtype="uint8")
p1_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
p1 = T.buffer_decl([128], "uint8", data=p1_data)
p2_data = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True})
@@ -389,14 +389,14 @@ class MixedReadU65:
p3_data = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True})
p3 = T.buffer_decl([128], "uint8", data=p3_data)
T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 128, p1[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, ifm[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer4[0], 304, buffer4[304], 304, 12, buffer5[0], 80, buffer5[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 128, p3[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 128, p1[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
- T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 128, p3[0], dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
- T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, buffer6[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 128, p3[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 48, p1[48], 48, 12, p1[96], 16, p1[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, p2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 48, p3[48], 48, 12, p3[96], 16, p3[112], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
__tvm_meta__ = None
# fmt: on
diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py
index 337b5c70d1..a5adcfceac 100644
--- a/tests/python/contrib/test_ethosu/test_merge_constants.py
+++ b/tests/python/contrib/test_ethosu/test_merge_constants.py
@@ -441,6 +441,195 @@ def test_read_from_the_same_buffer():
check_const_dictionaries(const_dict, new_const_dict)
+def test_arbitrary_argument_order():
+ # fmt: off
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(96,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+ # buffer definition
+ T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+ T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+ # body
+ p1_data = T.allocate([368], "uint8", "global")
+ p1 = T.buffer_decl([368], "uint8", data=p1_data)
+ p2_data = T.allocate([96], "uint8", "global")
+ p2 = T.buffer_decl([96], "uint8", data=p2_data)
+ p3_data = T.allocate([368], "uint8", "global")
+ p3 = T.buffer_decl([368], "uint8", data=p3_data)
+ p4_data = T.allocate([96], "uint8", "global")
+ p4 = T.buffer_decl([96], "uint8", data=p4_data)
+ T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ __tvm_meta__ = None
+
+
+ @tvm.script.ir_module
+ class ReferenceModule:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+ # body
+ p1_data = T.allocate([464], "uint8", "global")
+ p1 = T.buffer_decl([464], "uint8", data=p1_data)
+ p2_data = T.allocate([464], "uint8", "global")
+ p2 = T.buffer_decl([464], "uint8", data=p2_data)
+ T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ __tvm_meta__ = None
+ # fmt: on
+
+ const_dict = {
+ 1: np.array([1], dtype=np.uint8),
+ 2: np.array([2], dtype=np.uint8),
+ 4: np.array([4], dtype=np.uint8),
+ 5: np.array([5], dtype=np.uint8),
+ }
+ new_const_dict = {
+ 1: np.concatenate((const_dict[1], const_dict[2])),
+ 3: np.concatenate((const_dict[4], const_dict[5])),
+ }
+ test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+ reference_mod = ReferenceModule
+ tvm.ir.assert_structural_equal(test_mod, reference_mod, False)
+ check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_arbitrary_argument_order_const_split():
+ # fmt: off
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(96,), "uint8"], buffer3: T.Buffer[(368,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+ # buffer definition
+ T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+ T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+ # body
+ p1_data = T.allocate([368], "uint8", "global")
+ p1 = T.buffer_decl([368], "uint8", data=p1_data)
+ p2_data = T.allocate([96], "uint8", "global")
+ p2 = T.buffer_decl([96], "uint8", data=p2_data)
+ p3_data = T.allocate([368], "uint8", "global")
+ p3 = T.buffer_decl([368], "uint8", data=p3_data)
+ p4_data = T.allocate([96], "uint8", "global")
+ p4 = T.buffer_decl([96], "uint8", data=p4_data)
+ T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 96, p2[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p2[0], 48, p2[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 368, p3[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p3[0], 192, p3[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ __tvm_meta__ = None
+
+
+ @tvm.script.ir_module
+ class ReferenceModule:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer2: T.Buffer[(464,), "uint8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+ # body
+ p1_data = T.allocate([464], "uint8", "global")
+ p1 = T.buffer_decl([464], "uint8", data=p1_data)
+ p2_data = T.allocate([464], "uint8", "global")
+ p2 = T.buffer_decl([464], "uint8", data=p2_data)
+ T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ __tvm_meta__ = None
+ # fmt: on
+
+ const_dict = {
+ 1: np.array([1], dtype=np.uint8),
+ 3: np.array([3], dtype=np.uint8),
+ 4: np.array([4], dtype=np.uint8),
+ 5: np.array([5], dtype=np.uint8),
+ }
+ new_const_dict = {
+ 1: np.concatenate((const_dict[1], const_dict[3])),
+ 3: np.concatenate((const_dict[4], const_dict[5])),
+ }
+ test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+ reference_mod = ReferenceModule
+ tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+ check_const_dictionaries(const_dict, new_const_dict)
+
+
+def test_arbitrary_argument_order_const_split_mixed():
+ # fmt: off
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(368,), "uint8"], buffer2: T.Buffer[(368,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"], buffer3: T.Buffer[(96,), "uint8"], buffer4: T.Buffer[(96,), "uint8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+ # buffer definition
+ T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
+ T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
+ # body
+ p1_data = T.allocate([368], "uint8", "global")
+ p1 = T.buffer_decl([368], "uint8", data=p1_data)
+ p2_data = T.allocate([368], "uint8", "global")
+ p2 = T.buffer_decl([368], "uint8", data=p2_data)
+ p3_data = T.allocate([96], "uint8", "global")
+ p3 = T.buffer_decl([96], "uint8", data=p3_data)
+ p4_data = T.allocate([96], "uint8", "global")
+ p4 = T.buffer_decl([96], "uint8", data=p4_data)
+ T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 368, p1[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 96, p3[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p3[0], 48, p3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 368, p2[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p4[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p4[0], 48, p4[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ __tvm_meta__ = None
+
+
+ @tvm.script.ir_module
+ class ReferenceModule:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(8192,), "int8"], buffer1: T.Buffer[(464,), "uint8"], buffer2: T.Buffer[(464,), "uint8"], ethosu_write: T.Buffer[(4096,), "int8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+ # body
+ p1_data = T.allocate([464], "uint8", "global")
+ p1 = T.buffer_decl([464], "uint8", data=p1_data)
+ p2_data = T.allocate([464], "uint8", "global")
+ p2 = T.buffer_decl([464], "uint8", data=p2_data)
+ T.evaluate(T.call_extern("ethosu_copy", buffer1[0], 464, p1[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1[0], 192, p1[192], 176, 12, p1[368], 48, p1[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 464, p2[0], dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[2048], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p2[0], 192, p2[192], 176, 12, p2[368], 48, p2[416], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
+ __tvm_meta__ = None
+ # fmt: on
+
+ const_dict = {
+ 1: np.array([1], dtype=np.uint8),
+ 2: np.array([2], dtype=np.uint8),
+ 4: np.array([4], dtype=np.uint8),
+ 5: np.array([5], dtype=np.uint8),
+ }
+ new_const_dict = {
+ 1: np.concatenate((const_dict[1], const_dict[4])),
+ 2: np.concatenate((const_dict[2], const_dict[5])),
+ }
+ test_mod, const_dict = MergeConstants(const_dict)(InputModule)
+ reference_mod = ReferenceModule
+ tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+ check_const_dictionaries(const_dict, new_const_dict)
+
+
def test_cycle_count():
# fmt: off
@tvm.script.ir_module