You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/03/27 19:26:18 UTC

[tvm] branch main updated: [TVMC][microNPU] tvmc option for printing which operators are offloaded to Ethos-U (#13212)

This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new da8335378a [TVMC][microNPU] tvmc option for printing which operators are offloaded to Ethos-U (#13212)
da8335378a is described below

commit da8335378af0a8454bb23b1aa5638a520fc5cb94
Author: sergio-grovety <89...@users.noreply.github.com>
AuthorDate: Mon Mar 27 22:26:10 2023 +0300

    [TVMC][microNPU] tvmc option for printing which operators are offloaded to Ethos-U (#13212)
    
    Added an option to tvmc and Ethos-U for printing, to the console or to a file, which operators from the initial graph are offloaded to Ethos-U and which are not. It produces a line-by-line output of the initial model IR, indicating which operations are ported to Ethos-U.
    
    The compiler option "--target-ethos-u-dump_npu_functions_coverage" has been replaced by the more generic "--dump-offloads" option, which has the same meaning.
    
    
    ## Usage
    ```
    # output to console:
    tvmc compile --target=ethos-u,cmsis-nn,c \
        --dump-offloads=- \
        ........
    
    # output to file:
    tvmc compile --target=ethos-u,cmsis-nn,c \
        --dump-offloads=<file path> \
        ........
    ```
    
    ## Example output:
    
    
    ...
    Total number of operators and distribution by targets
    Total: 211
    target1: 198
    target2: 10
    generic: 3
    
    'target1        <-     target1.qnn_conv2d'
    'target1        <-          %0 = qnn.conv2d(%tfl.quantize, %v_param_1, ...'
    'target1        <-          %1 = nn.bias_add(%0, %v_param_2, axis=3);'
    'target1        <-          %2 = qnn.requantize(%1, meta[relay.Constant]...'
    'target2        <-     target2.reshape'
    'target2        <-          %3 = reshape(%2, newshape=[1, 1001]);'
    'generic        <-     %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 1]...'
    ...
---
 python/tvm/driver/tvmc/compiler.py                 | 173 ++++++++++
 .../tvm/relay/analysis/operations_distribution.py  | 102 ++++++
 python/tvm/relay/transform/suffixes.py             | 105 ++++++
 .../backend/contrib/cmsisnn/extract_constants.cc   |   1 +
 src/relay/backend/contrib/cmsisnn/fuse_pads.cc     |   3 +-
 .../backend/contrib/cmsisnn/generate_constants.cc  |  12 +-
 .../contrib/cmsisnn/scalar_to_tensor_constant.cc   |   5 +-
 src/relay/transforms/annotate_target.cc            |   1 +
 tests/python/contrib/test_ethosu/infra.py          |  11 +-
 .../test_pass_operations_distribution.py           | 173 ++++++++++
 tests/python/driver/tvmc/test_compiler.py          | 351 +++++++++++++++++++++
 11 files changed, 927 insertions(+), 10 deletions(-)

diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 6e61e762ee..c42974593a 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -20,8 +20,12 @@ Provides support to compile networks both AOT and JIT.
 """
 import logging
 import os.path
+import re
+import itertools
+from copy import deepcopy
 from typing import Any, Optional, Dict, List, Union, Callable, Sequence
 from pathlib import Path
+from collections import defaultdict
 
 import tvm
 from tvm import autotvm, auto_scheduler
@@ -31,6 +35,8 @@ from tvm.ir.instrument import PassInstrument
 from tvm.ir.memory_pools import WorkspaceMemoryPools
 from tvm.target import Target
 from tvm.relay.backend import Executor, Runtime
+from tvm.relay.analysis.operations_distribution import analyze_operations_distribution
+from tvm.relay.transform.suffixes import tag_suffixes
 
 from . import composite_target, frontends, TVMCException
 from .model import TVMCModel, TVMCPackage
@@ -69,6 +75,16 @@ def add_compile_parser(subparsers, _, json_params):
         default="",
         help="comma separated list of formats to export the input model, e.g. 'asm,ll,tir,relay'.",
     )
+    parser.add_argument(
+        "--dump-offloads",
+        default="",
+        help="output a mapping of which operations of the initial Relay "
+        "will be transferred to which backend, indicating the composite "
+        "that includes those operations, "
+        "e.g. '--dump-offloads -' to dump to the console, "
+        "e.g. '--dump-offloads <path_to_file>' to dump to the file. "
+        "If not presented, no output is done. ",
+    )
     parser.add_argument(
         "--model-format",
         choices=frontends.get_frontend_names(),
@@ -171,6 +187,8 @@ def drive_compile(args):
 
     dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None
 
+    dump_offloads = args.dump_offloads if args.dump_offloads else ""
+
     additional_targets = reconstruct_target_args(args)
     workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets)
     transform_args = parse_graph_transform_args(args)
@@ -187,6 +205,7 @@ def drive_compile(args):
         cross_options=args.cross_compiler_options,
         output_format=args.output_format,
         dump_code=dump_code,
+        dump_offloads=dump_offloads,
         target_host=None,
         disabled_pass=args.disabled_pass,
         pass_context_configs=args.pass_config,
@@ -213,6 +232,7 @@ def compile_model(
     cross_options: Optional[str] = None,
     output_format: str = "so",
     dump_code: Optional[List[str]] = None,
+    dump_offloads: str = "",
     target_host: Optional[str] = None,
     disabled_pass: Optional[str] = None,
     pass_context_configs: Optional[List[str]] = None,
@@ -259,6 +279,10 @@ def compile_model(
     dump_code : list[str], optional
         Dump the generated code for the specified source types, on
         the requested target. Choose from: ["asm", "ll", "tir", "relay"].
+    dump_offloads : str
+        Dump the information about the partition of input model's layers by external codegen.
+        Can be '' to not dump at all, '-' to dump to the console
+        or '<path_to_file>' to dump to the specified file.
     target_host : str, optional
         The target of the host machine if host-side code
         needs to be generated.
@@ -313,6 +337,13 @@ def compile_model(
     if "tir" in dump_code:
         config, dumps = add_tir_to_dumps(config, dumps)
 
+    initial_relay = None
+    if dump_offloads != "":
+        # add suffixes to the span field for calls in Relay
+        mod = tag_suffixes(mod)
+        # remember initial Relay
+        initial_relay = deepcopy(mod)
+
     tvm_target, extra_targets = target_from_cli(target, additional_target_options)
     tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host)
 
@@ -337,6 +368,10 @@ def compile_model(
         for partition_function, opts in zip(partition_functions, partition_opts):
             mod = partition_function(mod, params, mod_name=mod_name, **opts)
 
+        if initial_relay:
+            # dump which operations are offloaded to which backend
+            dump_operation_offloads(mod, initial_relay, dump_offloads)
+
         if tuning_records and os.path.exists(tuning_records):
             logger.debug("tuning records file provided: %s", tuning_records)
 
@@ -496,3 +531,141 @@ def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."):
         dump_name = module_name + "." + dump_format
         with open(Path(dump_root, dump_name), "w") as f:
             f.write(dumps[dump_format])
+
+
+def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule, dump_path: str):
+    """This helper function forms a line-by-line output of the initial Relay lines,
+    indicating which operations are ported to which target,
+    and indicating the composite that includes those operations;
+    the 'generic' target refers to operations uploaded to the host, e.g
+    'target1        <-     target1.qnn_conv2d'
+    'target1        <-          %0 = qnn.conv2d(%tfl.quantize, %v_param_1, ...'
+    'target1        <-          %1 = nn.bias_add(%0, %v_param_2, axis=3);'
+    'target1        <-          %2 = qnn.requantize(%1, meta[relay.Constant]...'
+    'target2        <-     target2.reshape'
+    'target2        <-          %3 = reshape(%2, newshape=[1, 1001]);'
+    'generic        <-     %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 1]...'
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The partitioned IRModule with external global functions.
+    initial_mod : tvm.ir.IRModule
+        The initial IRModule that gets generated from a relay frontend.
+    dump_path: str
+        Value of the "dump_offloads" compiler atribute.
+        Could be dash ("-") or file path or empty string for
+        printing to console, file or doing nothing respectively.
+    """
+    print_to_console = dump_path == "-"
+    save_to_file = all([dump_path != "-", dump_path != ""])
+
+    if print_to_console or save_to_file:
+
+        operations_distribution = analyze_operations_distribution(mod)
+
+        def annotate_f(x):
+            ret = ""
+            if isinstance(x, relay.Call):
+                # if there is no x.span.source_name.name in operations_distribution,
+                # this could mean that the span was not copied during the application of passes
+                # to the Relay, in which case we can not associate the initial Relay string
+                # with the resulting Relay call
+                source_name = x.span.source_name.name
+                if source_name in operations_distribution:
+                    compiler_name, op_name, func_id = operations_distribution[source_name]
+                    ret = (
+                        f", compiler_name: {compiler_name}, op_name: {op_name}, "
+                        f"func_id: {func_id}"
+                    )
+                else:
+                    ret = ", compiler_name: unknown, op_name: unknown, func_id: unknown"
+            return ret
+
+        initial_relay_astext = initial_mod.astext(show_meta_data=False, annotate=annotate_f).split(
+            "\n"
+        )
+
+        # funcs_list is a list of internal composite/function IDs
+        # generated by analyze_operations_distribution().
+        # funcs_list helps keep the order of lines from the initial Relay.
+        funcs_list = []
+
+        # target_statistic is a mapping of the target name to the
+        # number of initial Relay calls offloaded on the target
+        target_statistic = defaultdict(int)
+
+        # funcs_dict is a mapping of the generated analyze_operations_distribution
+        # internal composite/function IDs to a list, where:
+        # 1st element is
+        #   (1a): target name - it could be "generic" or "unknown" or
+        #   (1b): specific target name, like "ethos-u" or "cmsis-nn"
+        # 2nd element is
+        #   (2a): corresponding initial Relay line for the case (1a) or
+        #   (2b): the name of the target composite functon in the other case (1b)
+        # 3rd element or subsequent ones are presented only for the case (2b)
+        # and are the initial Relay lines included in the corresponding
+        # target composite functon
+        funcs_dict = {}
+
+        # Here we group together initial Relay lines from the one composite
+        counter = itertools.count()
+        for s in initial_relay_astext:
+            result = re.search(
+                r"(compiler_name: )(.*)(, op_name: )(.*)(, func_id: )((.*)(?=;)|(.*))", s
+            )
+            if result:
+                target_name = result.group(2)
+                op_name = result.group(4)
+                func_id = result.group(6)
+                s = re.sub(r", compiler_name: (.*)", "", s).lstrip()
+                target_statistic[target_name] += 1
+
+                # create an identifier for each "unknown" case to keep the lines order
+                if func_id == "unknown":
+                    func_id = str(next(counter) * -1)
+
+                if func_id not in funcs_dict:
+                    funcs_list.append(func_id)
+                    funcs_dict[func_id] = [target_name]
+                    if target_name not in ["unknown", "generic"]:
+                        funcs_dict[func_id].append(op_name)
+
+                funcs_dict[func_id].append(s)
+
+        # Here we prepare the output for printing.
+        # The output in most cases keeps the original order of the Relay lines
+        # but some lines are moved to be in the corresponding composite group
+        output = []
+        total = 0
+        output.append("Total number of operators and distribution by targets")
+        output.append("Total:")
+        for target, statistic in target_statistic.items():
+            total += statistic
+            output.append(f"{target}: {statistic}")
+        output[1] += f" {total}"
+        output[len(target_statistic) + 1] += "\n"
+
+        for func_id in funcs_list:
+            _list = funcs_dict[func_id]
+            output.append(f"{_list[0]:10}     <-     {_list[1]}")
+            if _list[0] == "unknown":
+                output.append(
+                    "Warning: The above line means that some pass(es) \
+                              in Relay partitioning"
+                )
+                output.append("do not copy the span when the call is recreated")
+                output.append(
+                    "and a line from initial Relay could not be associated \
+                              with the resulting Relay"
+                )
+            for el in _list[2:]:
+                output.append(f"{_list[0]:10}     <-          {el}")
+
+        if print_to_console:
+            print("\n" + "\n".join(output))
+        if save_to_file:
+            file_path = os.path.abspath(dump_path)
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            with open(file_path, "w") as f:
+                f.write("\n".join(output))
diff --git a/python/tvm/relay/analysis/operations_distribution.py b/python/tvm/relay/analysis/operations_distribution.py
new file mode 100644
index 0000000000..fc983c8e7e
--- /dev/null
+++ b/python/tvm/relay/analysis/operations_distribution.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities that enable analyze Relay and get mappings for
+the unique identifier of the Relay line to the tuple of
+compiler name, composite name and composite/function identifier."""
+import tvm
+from tvm import relay
+from tvm.relay.expr_functor import ExprVisitor
+
+
+class AnalyzeOperationsDistribution(ExprVisitor):
+    """A visitor pass that maintains the dictionary unique_op_ids where
+    the tuple (compiler name, composite name, composite/function identifier)
+    corresponds to the unique identifier of the Relay line.
+    TVMC compiler adds a unique Relay line identifier as a suffix
+    to the call span field using the tag_suffixes pass
+    if the --dump-offloads option is specified.
+
+    Attributes
+    ----------
+    unique_op_ids : Dict[str, str, int]
+        Mapping the unique identifier of the Relay line obtained from
+        the "span" field of the Call and the tuple of compiler name,
+        composite name and internal composite/function identifier.
+    func_name : str
+        The name of the composite name in the partitioned Relay or
+        'generic' in case the Call has not been included in any composite.
+    func_id : int
+        Internal(inside unique_op_ids) composite/function identifier.
+    compiler_name : str
+        A name of the compiler (e.g. 'ethos-u' or 'cmsis-nn') or 'generic'
+        in case the Call has not been included in any composite.
+    """
+
+    def __init__(self):
+        self.unique_op_ids = {}
+        self.func_name = ""
+        self.func_id = 1
+        self.compiler_name = ""
+        super().__init__()
+
+    def extract(self, call: relay.Call):
+        self.compiler_name = "generic"
+        self.func_name = "generic"
+        if "Compiler" in call.attrs:
+            self.compiler_name = call.attrs["Compiler"]
+        self.visit(call)
+
+    def visit_call(self, call: relay.Call):
+        if isinstance(call.op, tvm.ir.Op):
+            if call.span:
+                src = call.span.source_name.name
+                self.unique_op_ids[src] = [self.compiler_name, self.func_name, self.func_id]
+                if self.func_name == "generic":
+                    self.func_id += 1
+        if isinstance(call.op, relay.Function):
+            self.func_name = call.op.attrs["Composite"]
+            self.func_id += 1
+        super().visit_call(call)
+
+
+def analyze_operations_distribution(mod):
+    """Traverses the partitioned graph to get the unique identifier
+    of the Relay line from the Call's span field.
+    The result is maintained in the dictionary unique_op_ids where
+    the unique indicator obtained from the op's span corresponds to
+    the tuple (compiler name, composite name, composite/function identifier).
+    With this information we can annotate the textual representation
+    of the initial Relay by indicating into which target composite
+    and function the operators are converted
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The partitioned Relay graph usually obtained with
+        partition_for_<target> function
+
+    Returns
+    -------
+    unique_op_ids : Dict[str, str, int]
+        Mapping from the unique identifier of the Relay line to the tuple of
+        compiler name, composite name, internal composite/function
+        identifier.
+    """
+    analyze = AnalyzeOperationsDistribution()
+    for _, func in mod.functions.items():
+        analyze.extract(func)
+    return analyze.unique_op_ids
diff --git a/python/tvm/relay/transform/suffixes.py b/python/tvm/relay/transform/suffixes.py
new file mode 100644
index 0000000000..e2f7a3c224
--- /dev/null
+++ b/python/tvm/relay/transform/suffixes.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"Add suffix to the relay.Call's span fields"
+from collections import defaultdict
+
+import tvm
+
+from ..expr_functor import ExprMutator
+from .. import expr as _expr
+
+
+class _SuffixTagger(ExprMutator):
+    """A pass to traverse the Relay graph to add suffix to the call's span fields.
+    This making span an unique indicator of a Relay line and we can use it to
+    obtain the mapping between the Relay that gets generated from a relay frontend
+    and the Relay after partitioning.
+    """
+
+    def __init__(self):
+        ExprMutator.__init__(self)
+        # key: span or source name, value: counter, indexed from 0
+        self.lookup = defaultdict(int)
+        self.suffix = "_PART_"
+        # a set to record hashes of an expressions which spans have been already rewritten
+        self.hashes = set()
+
+    def _tag_suffix(self, span):
+        # To avoid error once we introduce the SequentialSpan in the future
+        """https://discuss.tvm.apache.org/
+        t/pre-rfc-tvm-explorer-infrastructure/13457#pass-source-information-builder-6
+        """
+        # Don't need this if currently
+        if isinstance(span, tvm.relay.Span):
+            ori_name = span.source_name.name
+            new_name = ori_name + self.suffix + str(self.lookup[ori_name])
+            self.lookup[ori_name] += 1
+            return tvm.relay.Span(
+                tvm.relay.SourceName(new_name),
+                span.line,
+                span.end_line,
+                span.column,
+                span.end_column,
+            )
+        return span
+
+    def visit(self, expr):
+        if hasattr(expr, "span"):
+            return super().visit(expr)
+        return expr
+
+    def visit_call(self, call):
+        new_args = [self.visit(arg) for arg in call.args]
+        new_op = self.visit(call.op)
+        if tvm.ir.structural_hash(call) not in self.hashes:
+            self.hashes.add(tvm.ir.structural_hash(call))
+            expr__ = _expr.CallWithFields(
+                call,
+                new_op,
+                new_args,
+                call.attrs,
+                call.type_args,
+                None,
+                self._tag_suffix(call.span),
+            )
+        else:
+            expr__ = _expr.CallWithFields(
+                call, new_op, new_args, call.attrs, call.type_args, None, call.span
+            )
+        return expr__
+
+
+def tag_suffixes(mod):
+    """Traverses the Relay graph to add suffix to the call's span fields.
+    That making span as an unique indicator of a Relay call and we can use it to
+    obtain the mapping between the offloaded result and the frontend operators.
+
+    Parameters
+    ----------
+    tvm.ir.IRModule
+        The IRModule that gets generated from a relay frontend.
+
+    Returns
+    -------
+    tvm.ir.IRModule
+        The IRModule with call's span fields tagged with suffixes.
+    """
+    tagger = _SuffixTagger()
+    for global_var, func in mod.functions.items():
+        func = tagger.visit(func)
+        mod.update_func(global_var, func)
+    return mod
diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
index c6ed7af9ff..f82014d5d1 100644
--- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc
+++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -206,6 +206,7 @@ class ExtractConstantsMutator : public MixedModeMutator {
       final_call = Call(new_func, new_args);
     }
 
+    final_call->span = call->span;
     return final_call;
   }
 
diff --git a/src/relay/backend/contrib/cmsisnn/fuse_pads.cc b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc
index 71c31c3035..0ef7091fc2 100644
--- a/src/relay/backend/contrib/cmsisnn/fuse_pads.cc
+++ b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc
@@ -138,7 +138,7 @@ class FusePadsMutator : public MixedModeMutator {
     auto new_conv2d_args = conv2d_call->args;
     new_conv2d_args.erase(new_conv2d_args.begin());
     new_conv2d_args.insert(new_conv2d_args.begin(), new_conv2d_input);
-    Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {});
+    Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {}, conv2d_call->span);
     return std::move(ret_call);
   }
 
@@ -162,6 +162,7 @@ class FusePadsMutator : public MixedModeMutator {
         Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
                                      FreeTypeVars(new_body, mod_), func->attrs);
         ret_call = Call(new_func, post_call->args);
+        ret_call->span = call->span;
       }
     }
 
diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc
index e08b61c457..3bdbb5d057 100644
--- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc
+++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc
@@ -153,16 +153,17 @@ class GenerateConstantsMutator : public MixedModeMutator {
     // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc
     Array<Expr> conv2d_args = {conv2d_call->args[0], conv2d_kernel,        conv2d_call->args[2],
                                multiplier_const,     conv2d_call->args[4], weight_scale};
-    Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {});
+    Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {}, conv2d_call->span);
     if (bias_add_call) {
-      ret_call =
-          Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs, {});
+      ret_call = Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs,
+                      {}, bias_add_call->span);
     }
     Array<Expr> requantize_args = {ret_call, req_inp_scale, shift_const, requantize_call->args[3],
                                    requantize_call->args[4]};
-    ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {});
+    ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {},
+                    requantize_call->span);
     if (clip_call) {
-      ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {});
+      ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}, clip_call->span);
     }
     return std::move(ret_call);
   }
@@ -198,6 +199,7 @@ class GenerateConstantsMutator : public MixedModeMutator {
       }
     }
 
+    final_call->span = call->span;
     return final_call;
   }
 
diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
index 0e2036505b..f64f485bfd 100644
--- a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
+++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
@@ -83,6 +83,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
                                      FreeTypeVars(new_body, mod_), func->attrs);
       mod_->Update(global_var, new_func);
       final_call = Call(global_var, call->args);
+      final_call->span = call->span;
     }
 
     // Substitute scalar constant with tensor constant in the call to composite function.
@@ -140,7 +141,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
       String arg_name = scalar_arg.as<VarNode>()->name_hint();
       new_args.Set(i, Var(arg_name, tensor_arg->checked_type_));
     }
-    return Call(call->op, new_args, call->attrs, {});
+    return Call(call->op, new_args, call->attrs, {}, call->span);
   }
 
   // Replaces scalar constant with a tensor constant with same shape as that of the neighbouring
@@ -187,7 +188,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     if (new_args[0].same_as(new_args[1])) {
       new_args.erase(new_args.begin());
     }
-    return Call(new_func, new_args);
+    return Call(new_func, new_args, Attrs(), {}, call->span);
   }
 
  private:
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index 3f1985b7dd..eb6f9ec004 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -258,6 +258,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
     Array<Expr> compiler_begins = std::get<1>(target_n_args);
     Call new_call = Call(post_call->op, compiler_begins, post_call->attrs);
     new_call->checked_type_ = pre->checked_type_;
+    new_call->span = pre->span;
 
     // Update the target map.
     op_expr_to_target_[new_call] = target;
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index 844d08c66e..e6ebec6ac4 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -28,7 +28,6 @@ from typing import List
 import os
 import struct
 import numpy as np
-import tflite.Model
 import math
 from enum import IntEnum
 import tensorflow as tf
@@ -311,7 +310,15 @@ def get_tflite_graph(tf_func, shapes, ranges=None):
     converter.inference_output_type = tf.int8
     tflite_graph = converter.convert()
 
-    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    # Get TFLite model from buffer
+    try:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
 
     relay_module, params = relay.frontend.from_tflite(tflite_model)
     mod = partition_for_ethosu(relay_module, params)
diff --git a/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py b/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py
new file mode 100644
index 0000000000..2a9d88e412
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+
+from tvm import relay
+from tests.python.contrib.test_ethosu.infra import get_tflite_graph
+from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+from tvm.relay.analysis.operations_distribution import analyze_operations_distribution
+from tvm.relay.transform.suffixes import tag_suffixes
+
+
+def test_operations_distribution_ethos():
+    """Check target distribution of a conv2d followed by three pads, partitioned for Ethos-U."""
+    pytest.importorskip("tflite")
+    pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+    # NOTE(review): infra is imported at module level and already needs tensorflow; confirm
+    import tensorflow as tf
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return tf.pad(
+            op,
+            [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]],
+            "CONSTANT",
+        )
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+
+    # Get TFLite model from buffer
+    try:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(tflite_model)
+
+    mod = tag_suffixes(mod)
+    mod = partition_for_ethosu(mod, params)
+    operations_distribution = analyze_operations_distribution(mod)
+
+    expected = {
+        "Pad_PART_0": ["generic", "generic", 1],
+        "Conv2D2_PART_2": ["ethos-u", "ethos-u.qnn_conv2d", 3],
+        "Conv2D2_PART_1": ["ethos-u", "ethos-u.qnn_conv2d", 3],
+        "Conv2D2_PART_0": ["ethos-u", "ethos-u.qnn_conv2d", 3],
+        "Identity_PART_0": ["ethos-u", "ethos-u.pad2d", 4],
+        "Pad_1_PART_0": ["ethos-u", "ethos-u.pad2d", 5],
+    }
+
+    assert operations_distribution == expected
+
+
+def test_operations_distribution_generic():
+    """Check that a conv2d with dilation 32 followed by three pads is all assigned to generic."""
+    pytest.importorskip("tflite")
+    pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+
+    import tensorflow as tf
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+    dilations_out = 32
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=dilations_out,
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+
+    # Get TFLite model from buffer
+    try:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(tflite_model)
+
+    mod = tag_suffixes(mod)
+    mod = partition_for_ethosu(mod, params)
+    operations_distribution = analyze_operations_distribution(mod)
+
+    expected = {
+        "Identity_PART_0": ["generic", "generic", 1],
+        "Pad_1_PART_0": ["generic", "generic", 2],
+        "Pad_PART_0": ["generic", "generic", 3],
+        "Conv2D2_PART_2": ["generic", "generic", 4],
+        "Conv2D2_PART_1": ["generic", "generic", 5],
+        "Conv2D2_PART_0": ["generic", "generic", 6],
+    }
+
+    assert operations_distribution == expected
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__]))  # run every test defined in this file
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 61b1828aad..f624984481 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -16,6 +16,7 @@
 # under the License.
 import os
 import re
+import numpy as np
 import shutil
 import tarfile
 from os import path
@@ -29,6 +30,7 @@ from tvm.target import Target
 import tvm.testing
 from tvm.relay.op.contrib.ethosn import ethosn_available
 from tvm.relay.backend import Runtime, Executor
+from tvm import relay
 
 from tvm.contrib.target.vitis_ai import vitis_ai_available
 
@@ -49,6 +51,355 @@ def test_save_dumps(tmpdir_factory):
     assert path.exists("{}/{}".format(tmpdir, "fake_module.relay"))
 
 
+def test_save_dump_offloads_ethosu(tmp_path_factory):
+    """Check the --dump-offloads report when compiling for ethos-u,cmsis-nn,c."""
+    pytest.importorskip("tflite")
+    tf = pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+    # the test_ethosu infra module imports tensorflow, so import it only after the skips
+    import tflite.Model
+    from tests.python.contrib.test_ethosu.infra import get_tflite_graph
+    from tvm.driver.tvmc.model import TVMCModel
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        weight_shape[2] = 3
+        weights1 = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        weights2 = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op1 = tf.nn.conv2d(
+            op,
+            filters=weights1,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op2 = tf.nn.conv2d(
+            op,
+            filters=weights2,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op = tf.math.add(op1, op2)
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[1]], [padding_out[2], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return op
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+    try:  # tflite.Model is either the Model class or its module, depending on the version
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    mod, params = relay.frontend.from_tflite(tflite_model)
+    tvmc_model = TVMCModel(mod, params)
+
+    output_dir = tmp_path_factory.mktemp("tmp")
+    output_file_name = os.path.join(str(output_dir), "list.txt")
+
+    tvmc.compiler.compile_model(
+        tvmc_model,
+        target="ethos-u,cmsis-nn,c",
+        runtime=Runtime("crt"),
+        tuning_records="",
+        package_path=os.path.join(str(output_dir), "module.tar"),
+        executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}),
+        cross="",
+        cross_options="",
+        output_format="mlf",
+        dump_offloads=output_file_name,
+        disabled_pass=[""],
+        pass_context_configs=[
+            "tir.disable_vectorize=1",
+            "tir.usmp.enable=1",
+            "tir.usmp.algorithm=hill_climb",
+            "tir.disable_storage_rewrite=1",
+            "relay.frontend.fill_span=1",
+        ],
+        additional_target_options={
+            "c": {"mcpu": "cortex-m55"},
+            "cmsis-nn": {"mcpu": "cortex-m55"},
+            "ethos-u": {
+                "accelerator_config": "ethos-u55-256",
+            },
+        },
+    )
+
+    expected = [
+        r"Total number of operators and distribution by targets",
+        r"Total: 11",
+        r"ethos-u: 10",
+        r"generic: 1",
+        r"",
+        r"ethos-u        <-     ethos-u.qnn_conv2d",
+        r'ethos-u        <-          %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392157f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")',
+        r"ethos-u        <-          %1 = nn.bias_add(%0, %v_param_2, axis=3)",
+        r'ethos-u        <-          %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.11364f, -128, axis=3, out_dtype="int8")',
+        r"ethos-u        <-     ethos-u.qnn_conv2d",
+        r'ethos-u        <-          %3 = qnn.conv2d(%2, %v_param_3, -128, 0, 0.11364f, meta[relay.Constant][2], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")',
+        r"ethos-u        <-          %4 = nn.bias_add(%3, %v_param_4, axis=3)",
+        r'ethos-u        <-          %7 = qnn.requantize(%4, meta[relay.Constant][3], 0, 1.56803f, -128, axis=3, out_dtype="int8")',
+        r"ethos-u        <-     ethos-u.qnn_conv2d",
+        r'ethos-u        <-          %5 = qnn.conv2d(%2, %v_param_5, -128, 0, 0.11364f, meta[relay.Constant][4], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")',
+        r"ethos-u        <-          %6 = nn.bias_add(%5, %v_param_6, axis=3)",
+        r'ethos-u        <-          %8 = qnn.requantize(%6, meta[relay.Constant][5], 0, 1.20538f, -128, axis=3, out_dtype="int8")',
+        r"ethos-u        <-     ethos-u.add",
+        r"ethos-u        <-          %9 = qnn.add(%7, %8, 1.56803f, -128, 1.20538f, -128, 2.77341f, -128)",
+        r"generic        <-     nn.pad(%9, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])",
+    ]
+
+    file_path = os.path.abspath(output_file_name)
+    # check that file file_path was created
+    assert os.path.exists(file_path)
+    with open(file_path, "r") as f:
+        output = [line.strip() for line in f]
+
+    def signature(line):
+        # targets and operator/composite names up to the last "(" (arguments such as
+        # quantization scales vary with the random weights); other lines are kept whole
+        match = re.search(r"(.*)\(", line, re.DOTALL)
+        return match.group(0) if match else line
+
+    assert [signature(line) for line in output] == [signature(line) for line in expected]
+
+
+def test_save_dump_offloads_cmsis(tmp_path_factory):
+    """Check the --dump-offloads report when compiling for cmsis-nn,c (no Ethos-U)."""
+    pytest.importorskip("tflite")
+    tf = pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+    import tflite.Model
+    from tests.python.contrib.test_ethosu.infra import get_tflite_graph
+    from tvm.driver.tvmc.model import TVMCModel
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op = tf.nn.relu(op)
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return tf.pad(
+            op,
+            [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]],
+            "CONSTANT",
+        )
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+    try:  # tflite.Model is either the Model class or its module, depending on the version
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    mod, params = relay.frontend.from_tflite(tflite_model)
+    tvmc_model = TVMCModel(mod, params)
+
+    output_dir = tmp_path_factory.mktemp("tmp")
+    output_file_name = os.path.join(str(output_dir), "list.txt")
+
+    tvmc.compiler.compile_model(
+        tvmc_model,
+        target="cmsis-nn,c",
+        runtime=Runtime("crt"),
+        tuning_records="",
+        package_path=os.path.join(str(output_dir), "module.tar"),
+        executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}),
+        cross="",
+        cross_options="",
+        output_format="mlf",
+        dump_offloads=output_file_name,
+        disabled_pass=[""],
+        pass_context_configs=[
+            "tir.disable_vectorize=1",
+            "tir.usmp.enable=1",
+            "tir.usmp.algorithm=hill_climb",
+            "tir.disable_storage_rewrite=1",
+            "relay.frontend.fill_span=1",
+        ],
+        additional_target_options={
+            "c": {"mcpu": "cortex-m55"},
+            "cmsis-nn": {"mcpu": "cortex-m55"},
+        },
+    )
+
+    expected = [
+        r"Total number of operators and distribution by targets",
+        r"Total: 7",
+        r"cmsis-nn: 4",
+        r"generic: 3",
+        r"",
+        r"cmsis-nn       <-     cmsis-nn.qnn_conv2d",
+        r'cmsis-nn       <-          %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392157f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")',
+        r"cmsis-nn       <-          %1 = nn.bias_add(%0, %v_param_2, axis=3)",
+        r'cmsis-nn       <-          %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.113405f, -128, axis=3, out_dtype="int8")',
+        r"cmsis-nn       <-          %3 = clip(%2, a_min=-128f, a_max=127f)",
+        r"generic        <-     %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])",
+        r"generic        <-     %5 = nn.pad(%4, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+        r"generic        <-     nn.pad(%5, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+    ]
+
+    file_path = os.path.abspath(output_file_name)
+    # check that file file_path was created
+    assert os.path.exists(file_path)
+    with open(file_path, "r") as f:
+        output = [line.strip() for line in f]
+
+    def signature(line):
+        # targets and operator/composite names up to the last "(" (arguments such as
+        # quantization scales vary with the random weights); other lines are kept whole
+        match = re.search(r"(.*)\(", line, re.DOTALL)
+        return match.group(0) if match else line
+
+    assert [signature(line) for line in output] == [signature(line) for line in expected]
+
+
+def test_save_dump_offloads_generic(tmp_path_factory):
+    """Check the --dump-offloads report when compiling for the c target only."""
+    pytest.importorskip("tflite")
+    tf = pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+    import tflite.Model
+    from tests.python.contrib.test_ethosu.infra import get_tflite_graph
+    from tvm.driver.tvmc.model import TVMCModel
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return tf.pad(
+            op,
+            [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]],
+            "CONSTANT",
+        )
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+    try:  # tflite.Model is either the Model class or its module, depending on the version
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    mod, params = relay.frontend.from_tflite(tflite_model)
+    tvmc_model = TVMCModel(mod, params)
+
+    output_dir = tmp_path_factory.mktemp("tmp")
+    output_file_name = os.path.join(str(output_dir), "list.txt")
+
+    tvmc.compiler.compile_model(
+        tvmc_model,
+        target="c",
+        runtime=Runtime("crt"),
+        tuning_records="",
+        package_path=os.path.join(str(output_dir), "module.tar"),
+        executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}),
+        cross="",
+        cross_options="",
+        output_format="mlf",
+        dump_offloads=output_file_name,
+        disabled_pass=[""],
+        pass_context_configs=[
+            "tir.disable_vectorize=1",
+            "tir.usmp.enable=1",
+            "tir.usmp.algorithm=hill_climb",
+            "tir.disable_storage_rewrite=1",
+            "relay.frontend.fill_span=1",
+        ],
+        additional_target_options={
+            "c": {"mcpu": "cortex-m55"},
+        },
+    )
+
+    expected = [
+        r"Total number of operators and distribution by targets",
+        r"Total: 6",
+        r"generic: 6",
+        r"",
+        r'generic        <-     %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392156f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")',
+        r"generic        <-     %1 = nn.bias_add(%0, %v_param_2, axis=3)",
+        r'generic        <-     %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.103975f, -128, axis=3, out_dtype="int8")',
+        r"generic        <-     %3 = nn.pad(%2, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])",
+        r"generic        <-     %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+        r"generic        <-     nn.pad(%4, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+    ]
+
+    file_path = os.path.abspath(output_file_name)
+    # check that file file_path was created
+    assert os.path.exists(file_path)
+    with open(file_path, "r") as f:
+        output = [line.strip() for line in f]
+
+    def signature(line):
+        # targets and operator/composite names up to the last "(" (arguments such as
+        # quantization scales vary with the random weights); other lines are kept whole
+        match = re.search(r"(.*)\(", line, re.DOTALL)
+        return match.group(0) if match else line
+
+    assert [signature(line) for line in output] == [signature(line) for line in expected]
+
+
 # End to end tests for compilation