You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/12/01 15:45:00 UTC

[tvm] branch main updated: [microNPU] Move the compilation to use Target Hooks. (#9597)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c22d80d  [microNPU] Move the compilation to use Target Hooks. (#9597)
c22d80d is described below

commit c22d80d44eabef5df6ca80c18cc3b274e63f2fdc
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Wed Dec 1 15:44:34 2021 +0000

    [microNPU] Move the compilation to use Target Hooks. (#9597)
    
    * [microNPU] Move the compilation to use Target Hooks.
    
    This commit moves the current compilation flow
    to use target hooks, so that the generated TIR
    is provided to the unified module for unified
    optimizations.
    
    Change-Id: Ib3239a04ab201748e7f1b1ffa503cfe2aa7ccb7b
    
    * [microNPU] Move the compilation to use Target Hooks.
    
    * Fixing unpacked API tests
    * Adding use_device_api target attr to example target hooks
    
    Change-Id: I72c51caa57e9a0c2a538f40eb73939e28d4f112f
    
    * [microNPU] Move the compilation to use Target Hooks.
    
    * Modified CLZ test case to support target hooks
    * Modified reference TIR for test to include allocate annotation
    * TIR to CS translation tests are modified to run MakeUnpackedAPI
    
    Change-Id: I3a3d28777a6995e7f2b8789e14c5cb0f280dc763
    
    * [microNPU] Move the compilation to use Target Hooks.
    
    * Added missing documentation for the changes in the source module
    * Skipping device api test for packed API as microNPU does not
      support it.
    
    Change-Id: I6da1adcf8fdd3f972ec9b37ff530ff673e93058c
    
    * [microNPU] Move the compilation to use Target Hooks.
    
    * fixed tvmc test to use unpacked-api for microNPU compilation
    
    Change-Id: Ib722d91ca3b3e4c6d13075ee0873acb86f487247
    
    * [microNPU] Move the compilation to use Target Hooks.
    
    * adjust target name.
    
    Change-Id: I862957324440705fb6093939b97b1a00fa1d4b46
    
    * [microNPU] follow up on using target hooks
    
    * Fixed a few typos and cleaned up as per suggestions
    
    Change-Id: I2a744a4bc4015e1884dbef4165252aa13aa30b31
    
    * [microNPU] follow up on using target hooks
    
    Fixing some typos and changing params to
    const_dict as it is clearer
    
    Change-Id: Ia36a4635a68f6490bcc3eeaa72eeeeaadb6aa7f6
    
    * [microNPU] Move the compilation to use Target Hooks.
    
    Fixing up lookup table tests to use the new runtime module
    import structure resulting from using target hooks.
    
    Change-Id: I250aedef7cc73edad3812bb7e9aab013ed8bed5b
---
 include/tvm/tir/transform.h                        |   5 +
 python/tvm/relay/backend/contrib/ethosu/codegen.py |  98 ++++++++----
 .../relay/backend/contrib/ethosu/tir/compiler.py   |   3 +-
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py |  31 ++++
 .../backend/contrib/ethosu/tir_to_cs_translator.py |  22 +--
 python/tvm/relay/backend/contrib/ethosu/util.py    |  30 ++++
 python/tvm/runtime/object_generic.py               |   8 +-
 src/relay/backend/contrib/ethosu/codegen.cc        | 136 ++++++++++++++++
 src/relay/backend/contrib/ethosu/source_module.cc  | 123 +++++++-------
 src/relay/backend/contrib/ethosu/utils.cc          |  75 +++++++++
 src/relay/backend/contrib/ethosu/utils.h           |  96 +++++++++++
 .../backend/contrib/example_target_hooks/target.cc |   1 +
 src/relay/backend/te_compiler.cc                   |   3 +
 src/relay/backend/te_compiler.h                    |   1 +
 src/target/target_kind.cc                          |   1 -
 src/tir/transforms/lower_tvm_builtin.cc            |   5 +
 src/tir/transforms/make_unpacked_api.cc            |  20 +--
 tests/python/contrib/test_ethosu/test_codegen.py   | 176 +++++++++------------
 .../contrib/test_ethosu/test_encode_constants.py   |  16 +-
 .../contrib/test_ethosu/test_lookup_table.py       |  22 +--
 .../contrib/test_ethosu/test_replace_conv2d.py     |   8 +-
 .../contrib/test_ethosu/test_replace_copy.py       |   8 +-
 .../test_ethosu/test_tir_to_cs_translator.py       |  20 +--
 tests/python/driver/tvmc/test_compiler.py          |   4 +-
 tests/python/relay/aot/test_c_device_api.py        |   5 +-
 .../test_tir_transform_make_unpacked_api.py        |  37 ++---
 26 files changed, 652 insertions(+), 302 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 7922e97..7a6cfa3 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -288,6 +288,11 @@ TVM_DLL Pass LowerThreadAllreduce();
 TVM_DLL Pass InferFragment();
 
 /*!
+ * \brief This annotation is for nodes to be disabled for builtin lowering
+ */
+static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";
+
+/*!
  * \brief Lower builtin intrinsics.
  * \return The pass.
  */
diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index e51f170..3b412cb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -14,7 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Codegen for Arm(R) Ethos(TM)-U"""
+"""Codegen for Arm(R) Ethos(TM)-U NPU"""
+
 import tvm
 from tvm import relay
 from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
@@ -133,24 +134,6 @@ class LUTsOptimizer(Pass):
         return OptimizeLUTs().visit(func)
 
 
-@tvm._ffi.register_func("relay.ext.ethos-u")
-def ethosu_compiler(external_function):
-    """The entry-point to a compile a external relay function of
-    NPU compatible operators to generated command stream.
-    Such generated command stream would be used to create c-source r
-    runtime module that interfaces with NPU driver.
-    """
-    assert isinstance(external_function, tvm.ir.function.BaseFunc)
-    func_name = external_function.attrs["global_symbol"]
-    # There should only be a single input
-    assert len(external_function.params) == 1
-    input_size = util.calculate_size_bytes(external_function.params[0])
-    output_size = util.calculate_size_bytes(external_function.body)
-    cmms, encoded_constants, scratch_size = _compile(external_function)
-    ethosu_runtime = tvm._ffi.get_global_func("runtime.module.ethos-u.create")
-    return ethosu_runtime(func_name, cmms, encoded_constants, scratch_size, input_size, output_size)
-
-
 @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
 def constant_updater(expr, symbol):  # pylint: disable=unused-argument
     """
@@ -161,25 +144,25 @@ def constant_updater(expr, symbol):  # pylint: disable=unused-argument
     return dict()
 
 
-def _compile(ext_func):
+@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func")
+def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
     """
-    This is the main wrapper that accepts an external
-    relay function and runs all the passes to lower it down
-    to command stream
+    This is the hook for python-based lowering of relay function
+    that gets offloaded to the microNPU.
+
     Parameters
     ----------
-    ext_func : tvm.relay.function.Function
-        The partitioned relay function
+    ext_func : relay.Function
+        This is the partitioned relay function
+
     Returns
     -------
-    cs : str
-        An hex string of the bytes of command stream
-    encoded_constants : str
-        An hex string of the bytes that includes concat'd
-        encoded weights, encoded biases and scales.
-    scratch_size : int
-        The size of the scratch buffer needed.
+    primfunc : tir.PrimFunc
+        This returns the scheduled PrimFunc
     """
+    assert len(ext_func.params) == 1
+    input_size = util.calculate_size_bytes(ext_func.params[0])
+    output_size = util.calculate_size_bytes(ext_func.body)
     mod = tvm.IRModule()
     mod["main"] = ext_func
     mod = LegalizeEthosU()(mod)
@@ -189,6 +172,51 @@ def _compile(ext_func):
     # this should be a single intelligent and a composite scheduler
     # that can perform scheduling based on user inputs such as
     # scratch memory size.
-    tir_mod, params = lower_to_tir(mod["main"], copy_constants())
-    cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(tir_mod, params)
-    return cmms, encoded_constants, scratch_size
+    tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants())
+
+    for idx in const_dict.keys():
+        const_dict[idx] = tvm.nd.array(const_dict[idx])
+
+    primfunc = tir_mod["main"]
+    primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"])
+    primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
+    primfunc = primfunc.with_attr("ethos-u.input_size", input_size)
+    primfunc = primfunc.with_attr("ethos-u.output_size", output_size)
+    return primfunc
+
+
+@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")
+def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact:
+    """
+    This is the hook for python-based lowering of TIR PrimFunc
+    that has undergone unified optimization to compilation
+    artifact destined for the microNPU.
+
+    Parameters
+    ----------
+    primfunc : tir.PrimFunc
+        TIR PrimFunc that has undergone unified optimizations
+
+    Returns
+    -------
+    CompilationArtifact
+        This is a structure that holds the binary artifacts
+        for the microNPU
+    """
+    symbol = str(primfunc.attrs["global_symbol"])
+    const_dict = primfunc.attrs["ethos-u.constants"]
+    input_size = primfunc.attrs["ethos-u.input_size"]
+    output_size = primfunc.attrs["ethos-u.output_size"]
+    tir_mod = tvm.IRModule()
+    tir_mod[symbol] = primfunc
+
+    const_dict_with_int_keys = dict()
+    for idx in const_dict.keys():
+        const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy()
+
+    cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(
+        tir_mod, const_dict_with_int_keys
+    )
+    return util.CompilationArtifact(
+        cmms, encoded_constants, scratch_size, input_size, output_size, symbol
+    )
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index b68a5ad..b3ffecb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -21,7 +21,7 @@ from tvm import relay
 from tvm.relay.expr_functor import ExprMutator
 from tvm.driver.build_module import schedule_to_module
 
-from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
+from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants, AnnotateAllocates
 from .scheduler import schedule
 
 
@@ -88,6 +88,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
         mod, const_dict = EncodeConstants(const_dict)(mod)
         mod = tvm.tir.transform.StorageRewrite()(mod)
         mod = tvm.tir.transform.RemoveNoOp()(mod)
+        mod = AnnotateAllocates()(mod)
     return mod, const_dict
 
 
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index cb46ba3..41a6832 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -488,3 +488,34 @@ def EncodeConstants(const_dict):
         return new_func, new_const_dict
 
     return _encode_constants
+
+
+# This needs to be kept in sync with kDisableLowerTVMBuiltin in include/tvm/tir/transform.h
+DISABLE_LOWER_BUILTIN = "disable_lower_builtin"
+
+
+def AnnotateAllocates():
+    """
+    This is pass to annotate all allocate
+    nodes of the PrimFuncs of the microNPU
+    to be not lowered to built-ins.
+    """
+
+    def _post_transform(allocate):
+        return tvm.tir.Allocate(
+            buffer_var=allocate.buffer_var,
+            dtype=allocate.dtype,
+            extents=allocate.extents,
+            condition=allocate.condition,
+            body=allocate.body,
+            annotations={DISABLE_LOWER_BUILTIN: True},
+        )
+
+    def _ftransform(f, mod, ctx):
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(f.body, None, _post_transform, ["tir.Allocate"])
+        )
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.ethosu.annotate_allocates"
+    )
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index e1af7f1..c8e3d34 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -173,16 +173,16 @@ def extract_buffer_info(
     primfunc = mod.functions.items()[0][1]
     for idx, const_data in param_dict.items():
         param = primfunc.params[idx]
-        buffer_info[primfunc.buffer_map[param].data] = BufferInfo(
+        buffer_info[param] = BufferInfo(
             const_data, const_data.shape, const_data.dtype, BufferType.constant
         )
 
     for param in primfunc.params:
-        if primfunc.buffer_map[param].data not in buffer_info.keys():
-            buffer_info[primfunc.buffer_map[param].data] = BufferInfo(
+        if param not in buffer_info.keys():
+            buffer_info[param] = BufferInfo(
+                None,
+                None,
                 None,
-                primfunc.buffer_map[param].shape,
-                primfunc.buffer_map[param].dtype,
                 BufferType.input_or_output,
             )
 
@@ -253,7 +253,7 @@ def assign_addresses(buffer_info, npu_ops):
     def replace_npu_address_range_with_address(npu_addr_range):
         assert isinstance(npu_addr_range.address, tvm.tir.Load)
         buffer = npu_addr_range.address.buffer_var
-        assert buffer in buffer_addresses.keys()
+        assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found"
         address, buffer_type = buffer_addresses[buffer]
         return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length)
 
@@ -299,11 +299,6 @@ def assign_addresses(buffer_info, npu_ops):
                 size_in_bytes = util.round_up(size_in_bytes, 16)
                 constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes))
         else:
-            size_in_bytes = int(
-                (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape))
-            )
-            # Every memory address the NPU access have to be 16 byte aligned
-            size_in_bytes = util.round_up(size_in_bytes, 16)
             if info.btype == BufferType.input_or_output:
                 buffer_type = classify_io(_buffer)
                 assert buffer_type in (BufferType.input, BufferType.output)
@@ -315,6 +310,11 @@ def assign_addresses(buffer_info, npu_ops):
                 address = arch_config.lut_start_address
                 buffer_addresses[_buffer] = (address, info.btype)
             else:
+                size_in_bytes = int(
+                    (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape))
+                )
+                # Every memory address the NPU access have to be 16 byte aligned
+                size_in_bytes = util.round_up(size_in_bytes, 16)
                 assert info.btype == BufferType.scratch
                 address = scratch_size
                 scratch_size += size_in_bytes
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index 589ab21..45a82d5 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -28,6 +28,9 @@ import numpy as np  # type: ignore
 
 import tvm  # type: ignore
 from tvm import relay
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from . import _ffi_api
 
 
 class QConv2DArgs(Enum):
@@ -209,3 +212,30 @@ def calculate_size_bytes(expr):
     element_size = type_info.bits // 8
     elements = np.prod(list(expr.checked_type.shape))
     return element_size * elements
+
+
+@register_object("relay.ext.ethos-u.CompilationArtifact")
+class CompilationArtifact(Object):
+    """
+    This is a structure to hold binary artifacts
+    for the microNPU.
+    """
+
+    def __init__(
+        self,
+        command_stream: str,
+        encoded_constants: str,
+        scratch_size: int,
+        input_size: int,
+        output_size: int,
+        function_name: str,
+    ):
+        self.__init_handle_by_constructor__(
+            _ffi_api.CompilationArtifact,  # type: ignore # pylint: disable=no-member
+            command_stream,
+            encoded_constants,
+            scratch_size,
+            input_size,
+            output_size,
+            function_name,
+        )
diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py
index 974523d..7a55d3e 100644
--- a/python/tvm/runtime/object_generic.py
+++ b/python/tvm/runtime/object_generic.py
@@ -68,9 +68,13 @@ def convert_to_object(value, span=None):
     if isinstance(value, dict):
         vlist = []
         for item in value.items():
-            if not isinstance(item[0], ObjectTypes) and not isinstance(item[0], string_types):
+            if (
+                not isinstance(item[0], ObjectTypes)
+                and not isinstance(item[0], string_types)
+                and not isinstance(item[0], Number)
+            ):
                 raise ValueError("key of map must already been a container type")
-            vlist.append(item[0])
+            vlist.append(convert_to_object(item[0]))
             vlist.append(convert_to_object(item[1]))
         return _ffi_api.Map(*vlist)
     if isinstance(value, ObjectGeneric):
diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc
new file mode 100644
index 0000000..d618a49
--- /dev/null
+++ b/src/relay/backend/contrib/ethosu/codegen.cc
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/contrib/ethosu/codegen.cc
+ *
+ * \brief This file contains the target hooks for Arm(R) Ethos(TM)-U NPU
+ * Codegen.
+ */
+
+#include <tvm/ir/error.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/function.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "../../../op/make_op.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace ethosu {
+
+/*!
+ * \brief This mutator lowers each external
+ * relay function to a TIR PrimFunc
+ */
+class RelayToTIRMutator : public MixedModeMutator {
+ public:
+  explicit RelayToTIRMutator(IRModule ir_module) : ir_module_(ir_module) {}
+
+  IRModule operator()() {
+    GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
+    Function main_func = Downcast<Function>(ir_module_->Lookup(main_global_var));
+
+    // Copy everything across and mutate the body
+    Function mutated_main =
+        Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
+                 main_func->type_params, main_func->attrs, main_func->span);
+
+    ir_module_->Update(main_global_var, mutated_main);
+    ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_);
+    return ir_module_;
+  }
+
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    Call call = Downcast<Call>(post);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      auto codegen_name = func->GetAttr<String>(attr::kCompiler);
+      if (codegen_name.defined() && codegen_name == "ethos-u") {
+        auto relay_to_tir_func_pf =
+            tvm::runtime::Registry::Get("relay.ext.ethos-u.relay_to_tir_func");
+        ICHECK(relay_to_tir_func_pf);
+        tir::PrimFunc prim_func = (*relay_to_tir_func_pf)(func);
+        prim_func = WithAttr(prim_func, tvm::attr::kTarget, Target("ethos-u"));
+        String symbol_name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
+        GlobalVar gv(symbol_name);
+        Array<RelayExpr> args = call->args;
+        gv->checked_type_ = func->checked_type();
+        ir_module_->Update(gv, prim_func);
+        device_contexts_.Set(gv, codegen_name.value());
+        return Call(gv, args, call->attrs, call->type_args);
+      }
+    }
+    return post;
+  }
+
+ private:
+  IRModule ir_module_;
+  Map<GlobalVar, String> device_contexts_;
+};
+
+tvm::transform::Pass RelayToTIR() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [=](IRModule ir_module, transform::PassContext pass_context) {
+        return RelayToTIRMutator(ir_module)();
+      };
+  return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.ethos-u.RelayToTIR", {});
+}
+
+/*!
+ * \brief This function lowers the IRModule with PrimFunc
+ * with the target of the microNPU to a C-source runtime module
+ */
+runtime::Module TIRToRuntime(IRModule mod, Target target) {
+  Array<CompilationArtifact> compile_artifacts;
+  for (const auto& kv : mod->functions) {
+    const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(kv.second);
+    Optional<Map<Integer, runtime::NDArray>> params =
+        prim_func->GetAttr<Map<Integer, runtime::NDArray>>("ethos-u.constants");
+    ICHECK(params) << "microNPU params should be present";
+    auto primfunc_to_artifact_pf =
+        tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact");
+    ICHECK(primfunc_to_artifact_pf);
+    CompilationArtifact ca = (*primfunc_to_artifact_pf)(prim_func);
+    compile_artifacts.push_back(ca);
+  }
+  auto ca_to_runtime = tvm::runtime::Registry::Get("runtime.module.ethos-u.create");
+  return (*ca_to_runtime)(compile_artifacts);
+}
+
+TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU)
+    .set_attr<Bool>("use_device_api", Bool(true))
+    .set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
+    .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
+
+}  // namespace ethosu
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc
index b7b359a..f56544a 100644
--- a/src/relay/backend/contrib/ethosu/source_module.cc
+++ b/src/relay/backend/contrib/ethosu/source_module.cc
@@ -41,34 +41,35 @@
 #include <vector>
 
 #include "../../../../runtime/file_utils.h"
+#include "utils.h"
 
 namespace tvm {
 namespace runtime {
 
+using CompilationArtifact = relay::contrib::ethosu::CompilationArtifact;
+
 // The runtime.Module that contains the host-side c code
 // required for invoking the NPU with the command stream
 class EthosUModuleNode : public ModuleNode {
  public:
   /*!
-   * \brief The ethos runtime module.
+   * \brief The microNPU runtime module.
    *
-   * \param func_name_ name of the should be codegen'd function
-   * \param cmms_hex_ command stream for the NPU in hex
-   * \param weights_bias_hex_ the encoded biases and weights for the NPU in hex
-   * \param scratch_size_ the size of the scratch memory required for command stream
-   * \param input_size_ the size (in bytes) for the input tensor
-   * \param output_size_ the size (in bytes) for the output tensor
+   * \param compilation_artifacts
+   *    This is an array of CompilationArtifacts that is produced via
+   *    lowering each PrimFunc to command stream. Here, those artifacts
+   *    will be used to create the c-source.
    */
-  explicit EthosUModuleNode(const String& func_name_, const String& cmms_hex_,
-                            const String& weights_bias_hex_, const Integer& scratch_size_,
-                            const Integer& input_size_, const Integer& output_size_) {
-    func_name = func_name_;
-    cmms_hex = std::move(cmms_hex_);
-    weights_bias_hex = std::move(weights_bias_hex_);
-    scratch_size = scratch_size_->value;
-    input_size = input_size_->value;
-    output_size = output_size_->value;
-    c_source = GenerateSource();
+  explicit EthosUModuleNode(Array<CompilationArtifact> compilation_artifacts)
+      : compilation_artifacts_(compilation_artifacts) {
+    c_source += "#include <stdio.h>\n";
+    c_source += "#include <stdlib.h>\n";
+    c_source += "#include <tvm/runtime/crt/module.h>\n";
+    c_source += "#include <tvm_ethosu_runtime.h>\n\n";
+    for (const CompilationArtifact& compilation_artifact : compilation_artifacts) {
+      c_source += GenerateSource(compilation_artifact);
+      c_source += "\n\n";
+    }
   }
 
   /*!
@@ -79,7 +80,6 @@ class EthosUModuleNode : public ModuleNode {
    */
   void SaveToFile(const std::string& file_name, const std::string& format) final {
     std::string fmt = GetFileFormat(file_name, format);
-    LOG(INFO) << "format=" << fmt << ";;\n";
     ICHECK_EQ(fmt, "c") << "Can only save to format="
                         << "c";
     std::ofstream out(file_name);
@@ -89,7 +89,7 @@ class EthosUModuleNode : public ModuleNode {
 
   std::string GetSource(const std::string& format) final { return c_source; }
 
-  std::string GetCS() { return cmms_hex; }
+  Array<CompilationArtifact> GetArtifacts() { return compilation_artifacts_; }
 
   /*!
    * \brief Get a PackedFunc from the module.
@@ -102,7 +102,11 @@ class EthosUModuleNode : public ModuleNode {
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
     if (name == "get_func_names") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        *rv = Array<String>{this->func_name};
+        Array<String> func_names;
+        for (const CompilationArtifact& ca : compilation_artifacts_) {
+          func_names.push_back(ca->function_name);
+        }
+        *rv = func_names;
       });
     }
     return PackedFunc();
@@ -110,21 +114,14 @@ class EthosUModuleNode : public ModuleNode {
 
   const char* type_key() const override { return "c"; }
 
-  static Module Create(String func_name, String cmms_hex, String weights_bias_hex,
-                       Integer scratch_size, Integer input_size, Integer output_size) {
-    auto n = make_object<EthosUModuleNode>(func_name, cmms_hex, weights_bias_hex, scratch_size,
-                                           input_size, output_size);
+  static Module Create(Array<CompilationArtifact> compilation_artifacts) {
+    auto n = make_object<EthosUModuleNode>(compilation_artifacts);
     return Module(n);
   }
 
  private:
-  String c_source;
-  String func_name;
-  String cmms_hex;
-  String weights_bias_hex;
-  size_t scratch_size;
-  size_t input_size;
-  size_t output_size;
+  std::string c_source;
+  Array<CompilationArtifact> compilation_artifacts_;
   int indent_{0};
 
   /*!
@@ -151,10 +148,10 @@ class EthosUModuleNode : public ModuleNode {
    * \return string of code that updates the base_addrs array with the base address of the given
    * array
    */
-  std::string SetBaseAddress(int index, std::string name) {
+  std::string SetBaseAddress(int index, std::string name, std::string size) {
     std::stringstream ss;
     ss << "  base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n";
-    ss << "  base_addrs_size[" << index << "] = " << name << "_size;\n";
+    ss << "  base_addrs_size[" << index << "] = " << size << ";\n";
     return ss.str();
   }
 
@@ -211,43 +208,39 @@ class EthosUModuleNode : public ModuleNode {
    *
    * \return string of code that offloads a subgraph to the NPU
    */
-  std::string GenerateSource() {
-    std::string func_no_dashes = func_name;
+  std::string GenerateSource(relay::contrib::ethosu::CompilationArtifact compilation_artifact) {
+    std::string func_no_dashes = compilation_artifact->function_name;
     std::replace(func_no_dashes.begin(), func_no_dashes.end(), '-', '_');
     std::stringstream ss;
 
-    ss << "#include <stdio.h>\n";
-    ss << "#include <stdlib.h>\n";
-    ss << "#include <tvm/runtime/crt/module.h>\n";
-    ss << "#include <tvm_ethosu_runtime.h>\n";
-    ss << "\n";
-    size_t weights_size = (weights_bias_hex.size() / 2);
-    ss << "static const size_t weights_size = " << std::to_string(weights_size) << ";\n";
-    ss << "static const size_t scratch_size = " << std::to_string(scratch_size) << ";\n";
+    size_t weights_size = (compilation_artifact->encoded_constants.size() / 2);
+    size_t scratch_size = compilation_artifact->scratch_size;
     ss << "// Update linker script to place .rodata.tvm in memory that can be accessed by the "
           "NPU\n";
     if (weights_size > 0) {
-      ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t weights["
-         << weights_size << "] = \"";
-      ss << GetHexString(weights_bias_hex);
+      ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t "
+         << func_no_dashes << "_weights[" << weights_size << "] = \"";
+      ss << GetHexString(compilation_artifact->encoded_constants);
       ss << "\";\n";
     } else {
-      ss << "static int8_t* weights = NULL;\n";
+      ss << "static int8_t* " << func_no_dashes << "_weights = NULL;\n";
     }
-    ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t cms_data_data["
-       << cmms_hex.size() / 2 << "] = \"";
-    ss << GetHexString(cmms_hex);
+    ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t " << func_no_dashes
+       << "_cms_data_data[" << compilation_artifact->command_stream.size() / 2 << "] = \"";
+    ss << GetHexString(compilation_artifact->command_stream);
     ss << "\";\n";
-    ss << "static const size_t cms_data_size = sizeof(cms_data_data);\n";
     ss << "\n";
 
     PrintExternCPrefix(ss);
     ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, "
        << "size_t in0_size, int8_t* out0, size_t out0_size, void* resource_handle) {\n";
     ss << "  int num_tensors = 5;\n";
-    ss << "  void* cms_data = (void*)(cms_data_data);\n";
+    ss << "  void* cms_data = (void*)(" << func_no_dashes << "_cms_data_data);\n";
     ss << "  int64_t device_type = kDLCPU;\n";
     ss << "  int64_t device_id = 0;\n";
+    ss << "  const size_t weights_size = " << std::to_string(weights_size) << ";\n";
+    ss << "  const size_t scratch_size = " << std::to_string(scratch_size) << ";\n";
+    ss << "  const size_t cms_data_size = sizeof(" << func_no_dashes << "_cms_data_data);\n";
     if (scratch_size > 0) {
       ss << "  int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, "
             "(uint64_t)scratch_size, 0, 16);\n";
@@ -257,11 +250,11 @@ class EthosUModuleNode : public ModuleNode {
     ss << "  size_t base_addrs_size[num_tensors];\n";
     ss << "  uint64_t base_addrs[num_tensors];\n";
     ss << "\n";
-    ss << SetBaseAddress(0, "weights");
-    ss << SetBaseAddress(1, "scratch");
-    ss << SetBaseAddress(2, "scratch");
-    ss << SetBaseAddress(3, "in0");
-    ss << SetBaseAddress(4, "out0");
+    ss << SetBaseAddress(0, func_no_dashes + "_weights", "weights_size");
+    ss << SetBaseAddress(1, "scratch", "scratch_size");
+    ss << SetBaseAddress(2, "scratch", "scratch_size");
+    ss << SetBaseAddress(3, "in0", "in0_size");
+    ss << SetBaseAddress(4, "out0", "out0_size");
     ss << "\n";
     ss << "  int32_t result = TVMEthosULaunch(resource_handle, cms_data, cms_data_size, "
           "base_addrs, base_addrs_size, num_tensors);\n";
@@ -277,8 +270,8 @@ class EthosUModuleNode : public ModuleNode {
     ss << "// Wrapper function is provided to allow for easier debugging\n";
     ss << "inline static int32_t " + func_no_dashes +
               "_wrapper_(void* input, void* output, void* resource_handle) {\n";
-    ss << "  size_t input_data_size = " << input_size << ";\n";
-    ss << "  size_t output_data_size = " << output_size << ";\n";
+    ss << "  size_t input_data_size = " << compilation_artifact->input_size << ";\n";
+    ss << "  size_t output_data_size = " << compilation_artifact->output_size << ";\n";
     ss << "  return " + func_no_dashes +
               "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size, " +
               "resource_handle);\n";
@@ -286,7 +279,7 @@ class EthosUModuleNode : public ModuleNode {
     PrintExternCPostfix(ss);
     ss << "\n";
     PrintExternCPrefix(ss);
-    PrintRuntimeFunctionHeader(ss, func_name);
+    PrintRuntimeFunctionHeader(ss, func_no_dashes);
     EnterScope();
     PrintIndents(ss);
     ss << "return " << func_no_dashes << "_wrapper_(input, output, resource_handle);\n";
@@ -313,14 +306,12 @@ inline EthosUModuleNode* EthosUModule::operator->() {
 }
 
 TVM_REGISTER_GLOBAL("runtime.module.ethos-u.create")
-    .set_body_typed([](String func_name, String cmms_hex, String weights_bias_hex,
-                       Integer scratch_size, Integer input_size, Integer output_size) {
-      return EthosUModuleNode::Create(func_name, cmms_hex, weights_bias_hex, scratch_size,
-                                      input_size, output_size);
+    .set_body_typed([](Array<CompilationArtifact> compilation_artifacts) {
+      return EthosUModuleNode::Create(compilation_artifacts);
     });
 
-TVM_REGISTER_GLOBAL("runtime.module.ethos-u.getcs").set_body_typed([](EthosUModule mod) {
-  return mod->GetCS();
+TVM_REGISTER_GLOBAL("runtime.module.ethos-u.get_artifacts").set_body_typed([](EthosUModule mod) {
+  return mod->GetArtifacts();
 });
 
 }  // namespace runtime
diff --git a/src/relay/backend/contrib/ethosu/utils.cc b/src/relay/backend/contrib/ethosu/utils.cc
new file mode 100644
index 0000000..7e6c1c2
--- /dev/null
+++ b/src/relay/backend/contrib/ethosu/utils.cc
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/contrib/ethosu/utils.cc
+ * \brief Utilities for microNPU codegen
+ */
+
+#include "utils.h"
+
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace ethosu {
+
+CompilationArtifact::CompilationArtifact(String command_stream, String encoded_constants,
+                                         Integer scratch_size, Integer input_size,
+                                         Integer output_size, String function_name) {
+  auto compilation_artifact_node = make_object<CompilationArtifactNode>();
+  compilation_artifact_node->command_stream = command_stream;
+  compilation_artifact_node->encoded_constants = encoded_constants;
+  compilation_artifact_node->scratch_size = scratch_size;
+  compilation_artifact_node->input_size = input_size;
+  compilation_artifact_node->output_size = output_size;
+  compilation_artifact_node->function_name = function_name;
+  data_ = std::move(compilation_artifact_node);
+}
+
+TVM_REGISTER_NODE_TYPE(CompilationArtifactNode);
+TVM_REGISTER_GLOBAL("relay.ext.ethos-u.CompilationArtifact")
+    .set_body_typed([](String command_stream, String encoded_constants, Integer scratch_size,
+                       Integer input_size, Integer output_size, String function_name) {
+      return CompilationArtifact(command_stream, encoded_constants, scratch_size, input_size,
+                                 output_size, function_name);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<CompilationArtifactNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const CompilationArtifactNode*>(ref.get());
+      p->stream << "CompilationArtifactNode(\n"  // one field per line, each indented two spaces
+                << "  command_stream=" << node->command_stream
+                << ",\n  encoded_constants=" << node->encoded_constants
+                << ",\n  scratch_size=" << node->scratch_size
+                << ",\n  input_size=" << node->input_size
+                << ",\n  output_size=" << node->output_size
+                << ",\n  function_name=" << node->function_name << ")";
+    });
+
+}  // namespace ethosu
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/contrib/ethosu/utils.h b/src/relay/backend/contrib/ethosu/utils.h
new file mode 100644
index 0000000..5e9e337
--- /dev/null
+++ b/src/relay/backend/contrib/ethosu/utils.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/contrib/ethosu/utils.h
+ * \brief Utilities for microNPU codegen
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_
+#define TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace ethosu {
+
+/*!
+ * \brief Captures all the binary artifacts required to create
+ * the C-source runtime module
+ */
+struct CompilationArtifactNode : public Object {
+  /*! \brief The binary command stream (CS) in hex format */
+  String command_stream;
+  /*! \brief The encoded biases and weights in hex format */
+  String encoded_constants;
+  /*! \brief The size in bytes of the intermediate scratch area required to execute the CS */
+  Integer scratch_size;
+  /*! \brief The size of the input tensor in bytes */
+  Integer input_size;
+  /*! \brief The size of the output tensor in bytes */
+  Integer output_size;
+  /*! \brief The name of the function */
+  String function_name;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("command_stream", &command_stream);
+    v->Visit("encoded_constants", &encoded_constants);
+    v->Visit("scratch_size", &scratch_size);
+    v->Visit("input_size", &input_size);
+    v->Visit("output_size", &output_size);
+    v->Visit("function_name", &function_name);
+  }
+
+  bool SEqualReduce(const CompilationArtifactNode* other, SEqualReducer equal) const {
+    return equal(command_stream, other->command_stream) &&
+           equal(encoded_constants, other->encoded_constants) &&
+           equal(scratch_size, other->scratch_size) && equal(input_size, other->input_size) &&
+           equal(output_size, other->output_size) && equal(function_name, other->function_name);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(command_stream);
+    hash_reduce(encoded_constants);
+    hash_reduce(scratch_size);
+    hash_reduce(input_size);
+    hash_reduce(output_size);
+    hash_reduce(function_name);
+  }
+
+  static constexpr const char* _type_key = "relay.ext.ethos-u.CompilationArtifact";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CompilationArtifactNode, Object);
+};
+
+class CompilationArtifact : public ObjectRef {
+ public:
+  TVM_DLL CompilationArtifact(String command_stream, String encoded_constants, Integer scratch_size,
+                              Integer input_size, Integer output_size, String function_name);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CompilationArtifact, ObjectRef, CompilationArtifactNode);
+};
+
+}  // namespace ethosu
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_
diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc
index 75b161a..6f1914e 100644
--- a/src/relay/backend/contrib/example_target_hooks/target.cc
+++ b/src/relay/backend/contrib/example_target_hooks/target.cc
@@ -33,6 +33,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
 }  // namespace relay
 
 TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
+    .set_attr<Bool>("use_device_api", Bool(true))
     .set_attr<FTVMRelayToTIR>("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR())
     .set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime);
 
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 32c8966..b47bc40 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -228,6 +228,9 @@ class TECompilerImpl : public TECompilerNode {
   }
 
   Map<GlobalVar, String> GetDeviceContexts() { return device_contexts_; }
+  void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) {
+    device_contexts_ = device_contexts;
+  }
 
   void Clear() final { cache_.clear(); }
 
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index d109412..60dd5fe 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -126,6 +126,7 @@ class TECompilerNode : public Object {
    * annotated)
    */
   virtual Map<GlobalVar, String> GetDeviceContexts() = 0;
+  virtual void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) = 0;
 
   virtual Map<String, Integer> GetOpWeights() const = 0;
 
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 8246d61..5540c35 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -402,7 +402,6 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU)  // line break
 
 TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option<Array<Target>>("devices");
 
-TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU).set_attr<Bool>("use_device_api", Bool(true));
 
 /**********  Registry  **********/
 
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 3343e10..a5ecf4b 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -117,6 +117,11 @@ class BuiltinLower : public StmtExprMutator {
     // and less than runtime::kMaxStackAlloca heuristic
     // they are not serviced with TVMBackendWorkspaceAlloc calls
     // to be placed on stack.
+    if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) {
+      if (Downcast<Bool>(op->annotations[transform::kDisableLowerTVMBuiltin])) {
+        return stmt;
+      }
+    }
     if (device_type_.defined()) {
       if (const auto* dev_type = device_type_.as<IntImmNode>()) {
         auto storage_scope = Downcast<PointerType>(op->buffer_var->type_annotation)->storage_scope;
diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc
index 6e8793f..169983a 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -64,31 +64,21 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
   // Collect variables and buffers to map between
   Array<Var> args;
   std::vector<std::pair<Var, Var>> var_def;
-  std::vector<std::pair<Var, Buffer>> buffer_def;
+  bool buffer_map_found = false;
 
   for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
     Var param = func_ptr->params[i];
-    Var v_arg = Var("arg" + std::to_string(i), param->dtype);
 
     auto it = func_ptr->buffer_map.find(param);
     if (it != func_ptr->buffer_map.end()) {
-      buffer_def.emplace_back(v_arg, (*it).second);
+      args.push_back((*it).second->data);
+      buffer_map_found = true;
     } else {
-      var_def.emplace_back(v_arg, param);
+      args.push_back(param);
     }
-
-    args.push_back(v_arg);
-  }
-
-  // Bind variables then bind buffers to them to ensure correct ordering
-  for (const auto& kv : var_def) {
-    binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
-  }
-  for (const auto& kv : buffer_def) {
-    binder.Bind(kv.second->data, kv.first, kv.first->name_hint, true);
   }
 
-  if (buffer_def.size()) {
+  if (buffer_map_found) {
     device_init.push_back(AttrStmt(node, attr::device_id, device_id, nop));
     device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop));
   }
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index e20ab41..f4393d4 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -154,14 +154,14 @@ def test_ethosu_conv2d(accel_type):
         )
 
         # Assumes only two runtime.Modules are created -- i.e. single offload module
-        imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-        assert len(imported_modules) == 2
-        ethosu_module = imported_modules[0]
+        ethosu_module = (
+            compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+        )
 
         # Verify generated C source
-        get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-        cmms = get_cs(ethosu_module)
-        cmms = bytes.fromhex(cmms)
+        get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+        compilation_artifacts = get_artifacts(ethosu_module)
+        cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
         infra.print_payload(cmms)
         infra.verify_source(compiled_models, accel_type)
 
@@ -241,15 +241,12 @@ def test_tflite_depthwise_conv2d(
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
@@ -328,15 +325,12 @@ def test_ethosu_pooling(
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
@@ -423,15 +417,12 @@ def test_ethosu_binary_elementwise(
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
@@ -501,15 +492,12 @@ def test_binary_add_with_non_4d_shapes(
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
@@ -551,15 +539,12 @@ def test_binary_add_from_constant_scalar(accel_type):
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
@@ -608,15 +593,12 @@ def test_ethosu_left_shift_binary_elemwise(
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
@@ -705,18 +687,16 @@ def test_ethosu_right_shift_binary_elemwise(
         [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)]
     ).astype(ofm_dtype)
 
-    compiled_model = infra.build_source(mod, input_data, [output_data], accel_type)
-    imported_modules = compiled_model[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    compiled_models = infra.build_source(mod, input_data, [output_data], accel_type)
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
-    infra.verify_source(compiled_model, accel_type)
+    infra.verify_source(compiled_models, accel_type)
 
 
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -738,15 +718,13 @@ def test_ethosu_identity_codegen(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp
         mod, {"ifm": in_data}, [out_data], accel_type, output_tolerance=1
     )
 
-    imported_modules = compiled_model[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_model, accel_type)
 
@@ -786,15 +764,13 @@ def test_relay_reshape_codegen(ifm_shape, new_shape, accel_type):
         accel_type,
     )
 
-    imported_modules = compiled_model[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_model, accel_type)
 
@@ -831,15 +807,13 @@ def test_relay_strided_slice_codegen(ifm_shape, begin, end, accel_type):
         accel_type,
     )
 
-    imported_modules = compiled_model[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_model, accel_type)
 
@@ -907,15 +881,12 @@ def test_ethosu_unary_elementwise(
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
@@ -957,16 +928,18 @@ def test_ethosu_section_name():
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
     source = ethosu_module.get_source()
     assert (
-        '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t cms_data_data' in source
+        '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_cms_data_data'
+        in source
+    )
+    assert (
+        '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_weights'
+        in source
     )
-    assert '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t weights' in source
 
 
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -990,15 +963,13 @@ def test_ethosu_clz(accel_type):
 
     compiled_model = infra.build_source(mod, {"ifm": in_data}, [out_data], accel_type)
 
-    imported_modules = compiled_model[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_model, accel_type)
 
@@ -1057,15 +1028,12 @@ def test_tflite_tanh(accel_type):
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
     infra.verify_source(compiled_models, accel_type)
 
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 91cee81..de8a7f9 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -48,8 +48,8 @@ class WeightStreamOnly:
         ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1)
         buffer_7 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         # body
-        placeholder_global = T.allocate([128], "uint8", "global")
-        placeholder_d_global = T.allocate([32], "uint8", "global")
+        placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": True})
+        placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
@@ -122,7 +122,7 @@ class DirectReadOnly:
         buffer_2 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         buffer_3 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         # body
-        ethosu_write_2 = T.allocate([4096], "int8", "global")
+        ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 592, 12, T.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 160, 12, T.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
     __tvm_meta__ = None
@@ -190,9 +190,9 @@ class MixedRead:
         buffer_8 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         buffer_9 = T.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         # body
-        ethosu_write_2 = T.allocate([4096], "int8", "global")
-        placeholder_global = T.allocate([80], "uint8", "global")
-        placeholder_d_global = T.allocate([32], "uint8", "global")
+        ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True})
+        placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True})
+        placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_6.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
@@ -312,6 +312,10 @@ def test_constant_as_input():
 
     # More generally, check compiles successfully to make sure
     # nothing else was overrwritten.
+    # With Target Hooks the TIR module needs a target attached
+    # and lowered via make unpacked API.
+    tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
+    tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
     tir_to_cs_translator.translate(tir_mod, params)
 
 
diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py
index d32b441..9485b4f 100644
--- a/tests/python/contrib/test_ethosu/test_lookup_table.py
+++ b/tests/python/contrib/test_ethosu/test_lookup_table.py
@@ -103,16 +103,13 @@ def test_tflite_lut_activations(accel_type):
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
-
     infra.verify_source(compiled_models, accel_type)
 
 
@@ -162,16 +159,13 @@ def test_random_lut(accel_type):
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
 
     # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
     infra.print_payload(cmms)
-
     infra.verify_source(compiled_models, accel_type)
 
 
diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
index 1d3afec..2f2cd7a 100644
--- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
@@ -255,7 +255,7 @@ class Conv2dDoubleCascade1:
         ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1)
         buffer_3 = T.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         # body
-        ethosu_write_2 = T.allocate([1024], "int8", "global")
+        ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
@@ -276,7 +276,7 @@ class Conv2dDoubleCascade2:
         placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
         ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1)
         # body
-        ethosu_write_2 = T.allocate([1536], "int8", "global")
+        ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
@@ -297,7 +297,7 @@ class Conv2dDoubleCascade3:
         buffer_3 = T.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
         # body
-        ethosu_write_2 = T.allocate([2560], "int8", "global")
+        ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
@@ -320,7 +320,7 @@ class Conv2dDoubleCascade4:
         buffer_2 = T.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         buffer_3 = T.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         # body
-        ethosu_write_2 = T.allocate([2304], "int8", "global")
+        ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py
index b1f923d..cce414c 100644
--- a/tests/python/contrib/test_ethosu/test_replace_copy.py
+++ b/tests/python/contrib/test_ethosu/test_replace_copy.py
@@ -39,8 +39,8 @@ class ReferenceModule:
         buffer_1 = T.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
         ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1)
         # body
-        placeholder_global = T.allocate([304], "uint8", "global")
-        placeholder_d_global = T.allocate([80], "uint8", "global")
+        placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True})
+        placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
@@ -87,8 +87,8 @@ class WeightStream:
         buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8")
         buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8")
         # body
-        placeholder_global = T.allocate([416], "uint8", "global")
-        placeholder_d_global = T.allocate([112], "uint8", "global")
+        placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True})
+        placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True})
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle"))
         T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index 94c8f0d..59b7b2c 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -233,9 +233,12 @@ def test_buffer_info_extraction():
         },
     ]
     for test_case in test_cases:
-        buffer_info = tir_to_cs_translator.extract_buffer_info(
-            test_case["tir_module"], test_case["param_dict"]
-        )
+        # With Target Hooks the TIR module needs a target attached
+        # and lowered via make unpacked API.
+        tir_mod = test_case["tir_module"]
+        tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
+        tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
+        buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"])
         for buffer_var, info in buffer_info.items():
             buffer_name = buffer_var.name
             if buffer_name in test_case["constants"].keys():
@@ -247,8 +250,6 @@ def test_buffer_info_extraction():
                 )
                 info.btype == tir_to_cs_translator.BufferType.constant
             else:
-                assert list(info.shape) == test_case["data_buffers"][buffer_name][0]
-                assert info.dtype == test_case["data_buffers"][buffer_name][1]
                 assert info.btype == test_case["data_buffers"][buffer_name][2]
 
 
@@ -831,10 +832,11 @@ def test_assign_addresses():
                     )
 
     for test_case in test_cases:
-        buffer_info = tir_to_cs_translator.extract_buffer_info(
-            test_case["tir_module"], test_case["param_dict"]
-        )
-        extern_calls = extract_call_extern_list(test_case["tir_module"])
+        tir_mod = test_case["tir_module"]
+        tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
+        tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
+        buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"])
+        extern_calls = extract_call_extern_list(tir_mod)
         _npu_ops = list()
         for extern_call in extern_calls:
             _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call))
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 95defff..73f3a0f 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -483,8 +483,8 @@ def test_compile_tflite_module_with_external_codegen_ethosu(
         tvmc.compiler.compile_model(
             tvmc_model,
             target=f"ethos-u -accelerator_config={accel_type}, c -mcpu=cortex-m55",
-            runtime=Runtime("crt", {"system-lib": True}),
-            executor=Executor("aot"),
+            runtime=Runtime("crt"),
+            executor=Executor("aot", {"unpacked-api": True}),
             output_format="mlf",
             package_path=output_file_name,
             pass_context_configs=["tir.disable_vectorize=true"],
diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py
index 3de4fec..473b8d5 100644
--- a/tests/python/relay/aot/test_c_device_api.py
+++ b/tests/python/relay/aot/test_c_device_api.py
@@ -92,7 +92,7 @@ def device_api_main_func():
             workspace_byte_alignment=16,
             pass_config=test_runner.pass_config,
         )
-        main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0]
+        main_ir_module = compiled_models[0].executor_factory.lowered_ir_mods.items()[1][1]
         main_func = main_ir_module["run_model"]
         return main_func
 
@@ -177,6 +177,9 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
     )
 
 
+@pytest.mark.skip(
+    "Skipping this test as this is incorrectly using Arm(R) Ethos(TM)-U NPU with packed calling convention which is not supported by the NPU codegen's TIR to Runtime Hook. We need to use a different target to test this feature"
+)
 def test_device_api_hooks_packed_api(device_api_main_func):
     """Check for Device API hooks with packed internal calls"""
     main_func = device_api_main_func(interface_api="packed", use_unpacked_api=False)
diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
index 9d91746..e5f41e7 100644
--- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py
+++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
@@ -58,7 +58,7 @@ def test_device_setup(mod, target, dev):
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod)
     f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
     assert len(f.params) == 1
-    assert f.params[0].name == "arg0"
+    assert f.params[0].name == "A"
     assert f.body.node == "default"
     assert f.body.attr_key == "device_id"
     assert f.body.value == 0
@@ -77,16 +77,13 @@ def test_no_buffers_no_device_setup():
 
     f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
     assert len(f.params) == 1
-    assert f.body.var.name == "A"
-    assert f.body.value.name == "arg0"
+    assert f.params[0].name == "A"
 
 
 def test_argument_mapping(mod):
     f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
     assert len(f.params) == 1
-    assert f.params[0].name == "arg0"
-    assert f.body.body.body.var.name == "A"
-    assert f.body.body.body.value.name == "arg0"
+    assert f.params[0].name == "A"  # param now named "A" directly; the old "arg0" + let-binding to "A" is no longer expected
 
 
 def test_argument_mapping_multiple():
@@ -101,12 +98,8 @@ def test_argument_mapping_multiple():
 
     f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
     assert len(f.params) == 2
-    assert f.params[0].name == "arg0"
-    assert f.params[1].name == "arg1"
-    assert f.body.body.body.var.name == "A"
-    assert f.body.body.body.value.name == "arg0"
-    assert f.body.body.body.body.var.name == "B"
-    assert f.body.body.body.body.value.name == "arg1"
+    assert f.params[0].name == "A"
+    assert f.params[1].name == "B"
 
 
 def test_argument_mapping_multiple_matching():
@@ -120,12 +113,8 @@ def test_argument_mapping_multiple_matching():
 
     f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
     assert len(f.params) == 2
-    assert f.params[0].name == "arg0"
-    assert f.params[1].name == "arg1"
-    assert f.body.body.body.var.name == "A"
-    assert f.body.body.body.value.name == "arg0"
-    assert f.body.body.body.body.condition.a.name == "A"
-    assert f.body.body.body.body.condition.b.name == "arg1"
+    assert f.params[0].name == "A"
+    assert f.params[1].name == "A"
 
 
 def test_body():
@@ -140,15 +129,9 @@ def test_body():
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
     f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
     assert len(f.params) == 3
-    assert f.params[0].name == "arg0"
-    assert f.params[1].name == "arg1"
-    assert f.params[2].name == "arg2"
-    assert f.body.body.body.var.name == "A"
-    assert f.body.body.body.value.name == "arg2"
-    assert f.body.body.body.body.var.name == "B"
-    assert f.body.body.body.body.value.name == "arg1"
-    assert f.body.body.body.body.body.condition.a.name == "A"
-    assert f.body.body.body.body.body.condition.b.name == "arg0"
+    assert f.params[0].name == "A"
+    assert f.params[1].name == "B"
+    assert f.params[2].name == "A"
 
 
 if __name__ == "__main__":