Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/25 09:08:47 UTC

[tvm] branch main updated: [microNPU] Integrate the cascader (#10862)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d2db9cb0d8 [microNPU] Integrate the cascader (#10862)
d2db9cb0d8 is described below

commit d2db9cb0d839e32778f461b77e59f6418282a511
Author: Elen Kalda <el...@arm.com>
AuthorDate: Mon Apr 25 10:08:38 2022 +0100

    [microNPU] Integrate the cascader (#10862)
    
    * [microNPU] Integrate the cascader
    
    Integrate the cascader into the codegen and optionally enable it
    with the enable_cascader flag. Includes placeholder MemoryRegions until
    integration with the PoolInfos provided by a user.
    
    Co-authored-by: Matthew Barrett <ma...@arm.com>
    
    * Fix linting and a docstring
    
    * Plumbing and testing improvements
    
    Plumb the workspace memory pools into the cascader and make the
    tests check for the memory reduction.
    
    * enable_cascader() -> is_cascader_enabled()
    
    * Check for the exact value of workspace size
    
    * Remove unused ACCEL_TYPES
    
    * Linting...
    
    Change-Id: If2d92846f05a7e8b21be767163841084538805a9
    
    * Rebasing...
    
    Co-authored-by: Matthew Barrett <ma...@arm.com>
---
 python/tvm/contrib/ethosu/cascader/__init__.py     |   2 +-
 python/tvm/contrib/ethosu/cascader/scheduler.py    |  36 +++-
 python/tvm/relay/backend/contrib/ethosu/codegen.py |  85 +++++++-
 .../relay/backend/contrib/ethosu/tir/compiler.py   |   2 +-
 .../relay/backend/contrib/ethosu/tir/scheduler.py  |   7 +-
 python/tvm/relay/backend/contrib/ethosu/util.py    |   6 +
 src/relay/backend/contrib/ethosu/compiler_attrs.cc |   4 +
 .../test_ethosu/cascader/test_memory_reduction.py  | 223 +++++++++++++++++++++
 tests/python/contrib/test_ethosu/infra.py          |  36 +++-
 9 files changed, 376 insertions(+), 25 deletions(-)
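
For context on the flag this commit adds: the cascader is opt-in through the
new "enable_cascader" option under "relay.ext.ethos-u.options", and when it is
enabled exactly one workspace memory pool must be passed to relay.build. A
minimal sketch of opting in, mirroring the new test added below (the
accelerator, pass options and the mod/params/pools placeholders are
illustrative):

    import tvm
    from tvm import relay
    from tvm.relay.backend import Executor, Runtime

    pass_config = {
        "tir.disable_vectorize": True,
        "relay.ext.ethos-u.options": {
            "accelerator_config": "ethos-u55-128",
            "enable_cascader": True,
        },
    }
    with tvm.transform.PassContext(opt_level=3, config=pass_config):
        # mod/params: a module already partitioned for the NPU and its
        # parameters; pools: a WorkspaceMemoryPools holding exactly one SRAM
        # PoolInfo (see the new test file below for a full construction).
        lib = relay.build(
            mod,
            tvm.target.Target("c"),
            executor=Executor("aot", {"interface-api": "c", "unpacked-api": True}),
            runtime=Runtime("crt"),
            workspace_memory_pools=pools,
            params=params,
        )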

diff --git a/python/tvm/contrib/ethosu/cascader/__init__.py b/python/tvm/contrib/ethosu/cascader/__init__.py
index 3ee350d008..51f5e58a47 100644
--- a/python/tvm/contrib/ethosu/cascader/__init__.py
+++ b/python/tvm/contrib/ethosu/cascader/__init__.py
@@ -36,5 +36,5 @@ from .parts import InlinePart, EthosuPart
 from .device_config import EthosuDeviceConfig
 from .tensor_config import TensorConfigState, MemoryRegion, TensorConfig
 from .plan import Plan
-from .scheduler import apply_proposal, cascade
+from .scheduler import apply_proposal, cascade, extract_memory_info
 from .cascader_options import CascaderOptions
diff --git a/python/tvm/contrib/ethosu/cascader/scheduler.py b/python/tvm/contrib/ethosu/cascader/scheduler.py
index 4198193c11..63d48a19af 100644
--- a/python/tvm/contrib/ethosu/cascader/scheduler.py
+++ b/python/tvm/contrib/ethosu/cascader/scheduler.py
@@ -22,6 +22,7 @@ import numpy as np
 
 from tvm import te
 from tvm import tir
+from tvm import PoolInfo
 from .cascader_options import CascaderOptions
 from .graph import CascaderGraph, Part, Tensor, TESubgraph
 from .parts import EthosuPart
@@ -44,7 +45,7 @@ def tile_nd(
     tensor : te.Tensor
         The tensor to apply the tiling to.
     tile : Tuple[int, ...]
-        The N-dimensional tile size.
+        The N-dimensional tile size
 
     Returns
     -------
@@ -78,8 +79,8 @@ def stripe_part(
         include_inputs=False,
     )
     g.compute_at(sch[te_output_tensor], outer_indices[-1])
-    for ax in outer_indices:
-        sch[te_output_tensor].unroll(ax)
+    for axis in outer_indices:
+        sch[te_output_tensor].unroll(axis)
 
     return sch[te_output_tensor], outer_indices[-1]
 
@@ -198,6 +199,35 @@ def choose_proposal(proposals: List[Proposal], cascade_region: MemoryRegion):
     return proposal_choice
 
 
+def extract_memory_info(memory_pool: PoolInfo) -> MemoryRegion:
+    "Create a MemoryRegion based on the info in the memory pool"
+    size = int(memory_pool.size_hint_bytes)
+    read_bandwidth = int(memory_pool.read_bandwidth_bytes_per_cycle)
+    write_bandwidth = int(memory_pool.write_bandwidth_bytes_per_cycle)
+
+    for param in (size, read_bandwidth, write_bandwidth):
+        assert param != -1, f"{param} needs to be specified for the cascader."
+
+    name_to_burst_length = {
+        target.kind.name: burst for target, burst in memory_pool.target_burst_bytes.items()
+    }
+
+    try:
+        burst_length = int(name_to_burst_length["ethos-u"])
+    except KeyError:
+        burst_length = 1
+
+    return MemoryRegion(
+        name=memory_pool.pool_name,
+        size=size,
+        read_bandwidth=read_bandwidth,
+        write_bandwidth=write_bandwidth,
+        read_latency=int(memory_pool.read_latency_cycles),
+        write_latency=int(memory_pool.write_latency_cycles),
+        burst_length=burst_length,
+    )
+
+
 def cascade(
     sch: te.Schedule,
     te_graph: TESubgraph,
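
The new extract_memory_info above converts a user-facing PoolInfo into the
MemoryRegion the cascader plans against: the size hint and both bandwidths
must be set (asserted not to be -1), and the burst length falls back to 1 when
target_burst_bytes has no "ethos-u" entry. A small sketch (the pool values are
illustrative):

    import tvm
    from tvm import PoolInfo
    from tvm.contrib.ethosu.cascader import extract_memory_info

    ethosu_target = tvm.target.Target("ethos-u")
    sram_pool = PoolInfo(
        "SRAM",
        {ethosu_target: PoolInfo.READ_WRITE_ACCESS},
        size_hint_bytes=600000,
        read_bandwidth_bytes_per_cycle=16,
        write_bandwidth_bytes_per_cycle=16,
        target_burst_bytes={ethosu_target: 1},
    )
    sram = extract_memory_info(sram_pool)  # -> MemoryRegion named "SRAM"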
diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index cbc9bc2a8d..19272ed6f7 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -17,13 +17,20 @@
 """Codegen for Arm(R) Ethos(TM)-U NPU"""
 from collections import defaultdict
 
+from typing import List, Callable
 import tvm
 from tvm import relay
 from tvm.relay.backend.contrib.ethosu.tir.compiler import LowerToTIR
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
+from tvm.contrib.ethosu.cascader import (
+    cascade,
+    EthosuDeviceConfig,
+    CascaderOptions,
+    MemoryRegion,
+    extract_memory_info,
+)
 from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
-from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
-from tvm.relay.backend.contrib.ethosu import util
+from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 # pylint: disable=unused-import
@@ -32,12 +39,6 @@ from tvm.relay.backend.contrib.ethosu import op
 
 from . import _ffi_api
 
-# We are currently using copy_constants scheduler In the long run,
-# this should be a single intelligent and a composite scheduler
-# that can perform scheduling based on user inputs such as
-# scratch memory size.
-SCHEDULER = copy_constants
-
 
 class OptimizeLUTs(ExprMutator):
     """A pass to merge an identity operator with a LUT based activation function with
@@ -334,6 +335,49 @@ def constant_updater(expr, symbol):  # pylint: disable=unused-argument
     return dict()
 
 
+def _create_cascader(
+    options: CascaderOptions,
+    io_region: MemoryRegion,
+    constant_region: MemoryRegion,
+    working_regions: List[MemoryRegion],
+    device_config: EthosuDeviceConfig,
+) -> Callable:
+    def _cascader(te_graph, const_dict, sch):
+        cascade(
+            sch,
+            te_graph,
+            const_dict,
+            options,
+            io_region,
+            constant_region,
+            working_regions,
+            device_config,
+        )
+
+    return _cascader
+
+
+def _ethos_u55_cascader(sram) -> Callable:
+    # TODO(ekalda): Extract the flash info from ConstantPools once it is implemented
+    flash = MemoryRegion(name="FLASH", size=10**7, read_bandwidth=4, write_bandwidth=4)
+
+    device_config = EthosuDeviceConfig(util.get_accelerator_config())
+    cascader_options = CascaderOptions(
+        cascade_region=sram,
+        max_proposals=64,
+        stripe_factors=5,
+        max_plan_size=10,
+        always_copy_size=1024,
+    )
+    return _create_cascader(
+        options=cascader_options,
+        io_region=sram,
+        constant_region=flash,
+        working_regions=[sram],
+        device_config=device_config,
+    )
+
+
 @tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir")
 def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
     """
@@ -362,9 +406,30 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
         gv: "ethos-u" for gv, _ in filter(lambda x: util.is_npu_func(x[1]), mod.functions.items())
     }
     mod = mod.with_attr("device_contexts", device_contexts)
-    mod = LowerToTIR(SCHEDULER)(mod)
 
-    return mod
+    # Use the cascader if it is enabled for the U55 accelerator, otherwise use
+    # the copy_constants scheduler
+    if util.is_cascader_enabled():
+        assert (
+            util.get_accelerator_config() != "ethos-u65-256"
+        ), "Cascading is not supported for the U65 accelerator"
+
+        workspace_memory_pools = mod.attrs["workspace_memory_pools"]
+
+        assert (
+            workspace_memory_pools
+        ), "Workspace memory pool needs to be provided for the U55 cascader"
+
+        assert (
+            len(workspace_memory_pools.pools) == 1
+        ), "Exactly one workspace pool needs to be provided for the U55 cascader"
+
+        sram = extract_memory_info(workspace_memory_pools.pools[0])
+        tir_mod = LowerToTIR(_ethos_u55_cascader(sram))(mod)
+    else:
+        tir_mod = LowerToTIR(copy_constants())(mod)
+
+    return tir_mod
 
 
 @tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")
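
Note that the closure returned by _create_cascader takes (te_graph,
const_dict, sch), the same shape as the planner returned by copy_constants(),
which is why relay_to_tir can hand either one to LowerToTIR. A hypothetical
scheduler honouring that contract (my_scheduler and mod are illustrative):

    from tvm.relay.backend.contrib.ethosu.tir.compiler import LowerToTIR

    def my_scheduler(te_graph, const_dict, sch):
        # te_graph: the TE subgraph of the NPU function, const_dict: its
        # constant parameters, sch: the te.Schedule to mutate in place.
        pass  # a real scheduler would apply tiling/caching decisions here

    tir_mod = LowerToTIR(my_scheduler)(mod)  # mod: an NPU-partitioned IRModule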
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index 545e0a41d8..f2c294cfed 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -194,7 +194,7 @@ class LowerToTIR:
     def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
         """Lower NPU functions to TIR."""
 
-        tir_mod, const_dict = _lower_to_tir(func, self.scheduler())
+        tir_mod, const_dict = _lower_to_tir(func, self.scheduler)
 
         for param in const_dict.keys():
             const_dict[param] = tvm.nd.array(const_dict[param])
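
A note on the one-line compiler.py change above: LowerToTIR previously held a
scheduler factory and invoked it per function (self.scheduler()); it now holds
the scheduler callable itself, so call sites apply the factory once up front.
Sketched with the names used in this patch:

    # Before this patch the factory itself was passed and invoked internally:
    #     LowerToTIR(copy_constants)(mod)
    # After this patch the caller applies the factory (or cascader builder)
    # and passes the resulting callable:
    tir_mod = LowerToTIR(copy_constants())(mod)
    tir_mod = LowerToTIR(_ethos_u55_cascader(sram))(mod)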
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
index 5e66a07c31..827a58055d 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
@@ -260,9 +260,10 @@ def schedule_cache_reads(sch):
         return False
 
     for stage in sch.stages:
-        if _detect_cache_read(stage):
-            fax = stage.fuse(*stage.op.axis)
-            stage.pragma(fax, "op", "ethosu_copy")
+        if stage.attach_type != 2:  # Not inlined
+            if _detect_cache_read(stage):
+                fax = stage.fuse(*stage.op.axis)
+                stage.pragma(fax, "op", "ethosu_copy")
 
 
 def inline_no_ops(cached_func, sch):
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index 64c561ec7f..cc9cc15410 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -241,6 +241,12 @@ def get_accelerator_config():
     return compiler_attrs.accelerator_config
 
 
+def is_cascader_enabled():
+    """Determine whether the cascader is enabled"""
+    compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
+    return compiler_attrs.enable_cascader
+
+
 def get_arg_count(func):
     """Helper function to get the number of
     arguments in a python function"""
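
The new is_cascader_enabled() mirrors get_accelerator_config(): both read the
EthosUCompilerConfig registered in compiler_attrs.cc, which is populated from
the current PassContext. A minimal sketch (the query must run inside the
PassContext that sets the option):

    import tvm
    from tvm.relay.backend.contrib.ethosu import util

    config = {"relay.ext.ethos-u.options": {"enable_cascader": True}}
    with tvm.transform.PassContext(opt_level=3, config=config):
        assert util.is_cascader_enabled()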
diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
index 5795db29b4..8cada6c3a3 100644
--- a/src/relay/backend/contrib/ethosu/compiler_attrs.cc
+++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
@@ -39,6 +39,7 @@ namespace ethosu {
 /*! \brief Attributes to store the compiler options for Arm(R) Ethos(TM)-U NPU. */
 struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode> {
   String accelerator_config;
+  bool enable_cascader;
 
   TVM_DECLARE_ATTRS(EthosUCompilerConfigNode, "ext.attrs.EthosUCompilerConfigNode") {
     TVM_ATTR_FIELD(accelerator_config)
@@ -46,6 +47,9 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
             "The class of Arm(R) Ethos(TM)-U NPU; possible values = {ethos-u55-32, ethos-u55-64, "
             "ethos-u55-128, ethos-u55-256}")
         .set_default("ethos-u55-256");
+    TVM_ATTR_FIELD(enable_cascader)
+        .describe("Whether the cascader should be enabled")
+        .set_default(false);
   }
 };
 
diff --git a/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py
new file mode 100644
index 0000000000..26a69033c5
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py
@@ -0,0 +1,223 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import numpy as np
+import tensorflow as tf
+import tflite.Model
+from tvm import relay
+from tvm.relay.backend import Executor, Runtime
+from tvm.micro import model_library_format as mlf
+from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+import tvm
+from tvm import WorkspaceMemoryPools, PoolInfo
+
+from .. import infra
+
+
+def _get_ethosu_workspace_size(mod, params, accel_type, pool_size, enable_cascader):
+    enable_usmp = True
+
+    target = tvm.target.Target("c")
+    ethosu_target = tvm.target.Target("ethos-u")
+    runtime = Runtime("crt")
+
+    executor = Executor(
+        "aot",
+        {
+            "workspace-byte-alignment": 16,
+            "interface-api": "c",
+            "unpacked-api": True,
+        },
+    )
+    pass_config = {
+        "tir.disable_vectorize": True,
+        "relay.ext.ethos-u.options": {
+            "accelerator_config": accel_type,
+            "enable_cascader": enable_cascader,
+        },
+        "tir.usmp.enable": enable_usmp,
+        "tir.usmp.algorithm": "hill_climb",
+        "tir.disable_storage_rewrite": enable_usmp,
+    }
+
+    workspace_memory_pools = WorkspaceMemoryPools(
+        [
+            PoolInfo(
+                "SRAM",
+                {target: PoolInfo.READ_WRITE_ACCESS, ethosu_target: PoolInfo.READ_WRITE_ACCESS},
+                size_hint_bytes=pool_size,
+                read_bandwidth_bytes_per_cycle=16,
+                write_bandwidth_bytes_per_cycle=16,
+                target_burst_bytes={ethosu_target: 1},
+            ),
+        ]
+    )
+
+    with tvm.transform.PassContext(opt_level=3, config=pass_config):
+        lib = tvm.relay.build(
+            mod,
+            target,
+            executor=executor,
+            runtime=runtime,
+            workspace_memory_pools=workspace_memory_pools,
+            params=params,
+        )
+
+    mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata)
+    return mlf_memory_map["main"][0]["workspace_size_bytes"]
+
+
+@pytest.mark.parametrize(
+    "accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader",
+    [
+        ("ethos-u55-256", 1067408, 14096),
+        ("ethos-u55-128", 1067408, 3968),
+        ("ethos-u55-64", 1067408, 2272),
+        ("ethos-u55-32", 1067392, 2256),
+    ],
+)
+def test_double_conv2d(
+    accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader
+):
+    np.random.seed(1)
+    ifm_shape = (1, 321, 212, 6)
+
+    @tf.function
+    def tf_graph(x):
+        ofm_channels = 10
+        conv2d = tf.nn.conv2d(
+            x,
+            filters=tf.constant(
+                np.random.uniform(size=[3, 2, ifm_shape[3], ofm_channels]),  # HWIO
+                dtype=tf.float32,
+            ),
+            strides=(1, 1),
+            padding="VALID",
+            dilations=(2, 1),
+        )
+        conv2d = tf.nn.conv2d(
+            conv2d,
+            filters=tf.constant(
+                np.random.uniform(size=(1, 1, ofm_channels, 3)),  # HWIO
+                dtype=tf.float32,
+            ),
+            strides=(3, 2),
+            padding="SAME",
+            dilations=(1, 1),
+        )
+
+        return conv2d
+
+    _, tflite_graph = infra.get_tflite_graph(tf_graph, [ifm_shape])
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(tflite_model)
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Run the graph without the cascader, with lots of memory
+    pool_size = 2000000
+    workspace_size_cascader_disabled = _get_ethosu_workspace_size(
+        mod, params, accel_type, pool_size, enable_cascader=False
+    )
+
+    # Run the same graph with the cascader, giving it less memory to persuade it to cascade
+    pool_size = 600000
+    workspace_size_cascader_enabled = _get_ethosu_workspace_size(
+        mod, params, accel_type, pool_size, enable_cascader=True
+    )
+
+    assert workspace_size_cascader_disabled == expected_ws_size_without_cascader
+    assert workspace_size_cascader_enabled == expected_ws_size_with_cascader
+
+
+# TODO(ekalda): Fix a bug in the block config selection that selects block config that is too large
+# for the smaller accelerators
+@pytest.mark.parametrize(
+    "accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader",
+    [
+        ("ethos-u55-256", 180096, 5024),
+        ("ethos-u55-128", 180096, 4832),
+        pytest.param("ethos-u55-64", 180096, 4832, marks=pytest.mark.xfail),
+        pytest.param("ethos-u55-32", 180096, 4832, marks=pytest.mark.xfail),
+    ],
+)
+def test_depthwise2d_conv2d_pooling(
+    accel_type, expected_ws_size_without_cascader, expected_ws_size_with_cascader
+):
+    np.random.seed(2)
+    ifm_shape = (1, 80, 75, 3)
+
+    @tf.function
+    def tf_graph(x):
+        # This graph will execute as one cascade
+        ofm_channels = 7
+        conv2d = tf.nn.conv2d(
+            x,
+            filters=tf.constant(
+                np.random.uniform(size=[3, 2, ifm_shape[3], ofm_channels]),  # HWIO
+                dtype=tf.float32,
+            ),
+            strides=(1, 1),
+            padding="VALID",
+            dilations=(1, 1),
+        )
+        depthwise2d = tf.nn.depthwise_conv2d(
+            conv2d,
+            tf.constant(np.random.uniform(size=(3, 3, ofm_channels, 1)), dtype=tf.float32),  # HWC1
+            strides=(1, 1, 1, 1),
+            padding="VALID",
+            dilations=(1, 1),
+        )
+        relu = tf.nn.relu(depthwise2d)
+        conv2d = tf.nn.conv2d(
+            relu,
+            filters=tf.constant(
+                np.random.uniform(size=[3, 2, ofm_channels, 2]),  # HWIO
+                dtype=tf.float32,
+            ),
+            strides=(1, 1),
+            padding="SAME",
+            dilations=(1, 1),
+        )
+        max_pool = tf.nn.max_pool(conv2d, (3, 3), (1, 1), "SAME")
+
+        return max_pool
+
+    _, tflite_graph = infra.get_tflite_graph(tf_graph, [ifm_shape])
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(tflite_model)
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Run the graph without the cascader, with lots of memory
+    pool_size = 10**6
+    workspace_size_cascader_disabled = _get_ethosu_workspace_size(
+        mod, params, accel_type, pool_size, enable_cascader=False
+    )
+
+    # Run the same graph with the cascader, giving it less memory to persuade it to cascade
+    pool_size = 40000
+    workspace_size_cascader_enabled = _get_ethosu_workspace_size(
+        mod, params, accel_type, pool_size, enable_cascader=True
+    )
+
+    assert workspace_size_cascader_disabled == expected_ws_size_without_cascader
+    assert workspace_size_cascader_enabled == expected_ws_size_with_cascader
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index 4d22414e24..0c42b024f2 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -109,7 +109,7 @@ def deserialize_command_stream(blob):
     return cmms
 
 
-def create_test_runner(accel="ethos-u55-256", enable_usmp=True):
+def create_test_runner(accel="ethos-u55-256", enable_usmp=True, enable_cascader=False):
     file_dir = os.path.dirname(os.path.abspath(__file__))
     test_root = os.path.join(file_dir, "reference_system")
     _, ethosu_variant, ethosu_macs = accel.split("-")
@@ -134,6 +134,7 @@ def create_test_runner(accel="ethos-u55-256", enable_usmp=True):
         pass_config={
             "relay.ext.ethos-u.options": {
                 "accelerator_config": accel,
+                "enable_cascader": enable_cascader,
             },
             "tir.usmp.enable": enable_usmp,
             "tir.usmp.algorithm": "hill_climb",
@@ -143,9 +144,15 @@ def create_test_runner(accel="ethos-u55-256", enable_usmp=True):
 
 
 def build_source(
-    module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0, enable_usmp=True
+    module,
+    inputs,
+    outputs,
+    accel="ethos-u55-256",
+    output_tolerance=0,
+    enable_usmp=True,
+    enable_cascader=False,
 ):
-    test_runner = create_test_runner(accel, enable_usmp)
+    test_runner = create_test_runner(accel, enable_usmp, enable_cascader)
     return compile_models(
         models=AOTTestModel(
             module=module,
@@ -165,12 +172,13 @@ def verify_source(
     models: List[AOTCompiledTestModel],
     accel="ethos-u55-256",
     enable_usmp=True,
+    enable_cascader=False,
 ):
     """
     This method verifies the generated source from an NPU module by building it and running on an FVP.
     """
     interface_api = "c"
-    test_runner = create_test_runner(accel, enable_usmp)
+    test_runner = create_test_runner(accel, enable_usmp, enable_cascader)
     run_and_check(
         models,
         test_runner,
@@ -284,7 +292,13 @@ def get_tflite_graph(tf_func, shapes, ranges=None):
 
 
 def compare_ethosu_with_reference(
-    mod, input_data, output_data, accel_type, output_tolerance=0, print_cmm=False
+    mod,
+    input_data,
+    output_data,
+    accel_type,
+    output_tolerance=0,
+    print_cmm=False,
+    enable_cascader=False,
 ):
     compiled_models = build_source(
         mod,
@@ -292,6 +306,7 @@ def compare_ethosu_with_reference(
         output_data,
         accel_type,
         output_tolerance=output_tolerance,
+        enable_cascader=enable_cascader,
     )
 
     # Assumes only two runtime.Modules are created -- i.e. single offload module
@@ -304,11 +319,17 @@ def compare_ethosu_with_reference(
         cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
         print_payload(cmms)
 
-    verify_source(compiled_models, accel_type)
+    verify_source(compiled_models, accel_type, enable_cascader=enable_cascader)
 
 
 def compare_tvm_with_tflite(
-    tf_func, shapes, accel_type, ranges=None, output_tolerance=0, print_cmm=False
+    tf_func,
+    shapes,
+    accel_type,
+    ranges=None,
+    output_tolerance=0,
+    print_cmm=False,
+    enable_cascader=False,
 ):
     mod, tflite_graph = get_tflite_graph(tf_func, shapes, ranges)
 
@@ -322,6 +343,7 @@ def compare_tvm_with_tflite(
         accel_type,
         output_tolerance=output_tolerance,
         print_cmm=print_cmm,
+        enable_cascader=enable_cascader,
     )
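
With enable_cascader threaded through create_test_runner, build_source,
verify_source and the two compare_* helpers, an end-to-end test can opt in
with a single flag. A hedged usage sketch, written as it would appear inside a
test module that already does "from .. import infra" (tf_graph is a
placeholder tf.function):

    infra.compare_tvm_with_tflite(
        tf_graph,              # a tf.function producing the model under test
        [(1, 80, 75, 3)],      # input shapes
        "ethos-u55-128",
        enable_cascader=True,  # new flag, threaded into the pass config
    )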