You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/10/11 15:38:21 UTC

[tvm] branch main updated: Arm(R) Ethos(TM)-U NPU Depthwise2d operator support (#9209)

This is an automated email from the ASF dual-hosted git repository.

mbaret pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ba0451  Arm(R) Ethos(TM)-U NPU Depthwise2d operator support (#9209)
8ba0451 is described below

commit 8ba0451d3bff4d6486908d0f9dcf064fe916dd36
Author: Elen Kalda <el...@arm.com>
AuthorDate: Mon Oct 11 16:37:54 2021 +0100

    Arm(R) Ethos(TM)-U NPU Depthwise2d operator support (#9209)
    
    * Arm(R) Ethos(TM)-U NPU Depthwise2d operator support
    
    This commit adds support for Depthwise2d primitive operator throughout
    the TVM stack including Relay legalization pass, operator definition,
    TE, TIR passes and translation into the command stream.
    
    Change-Id: If82b85f5d3b23cd214fe38babd724451bf95ef5b
    
    * Change depthwise2d to depthwise_conv2d
    
    And respond to other review comments.
    
    Change-Id: I58a9f28723750970d386b4d0ba62fa399c5c6181
    
    * Make a line shorter and add a comment
    
    Change-Id: Idf4c078bf65e7ed31fe82a92bf334295a82b6ead
    
    * Change the order of imports
    
    Change-Id: Ic6c77af30a5b9cb68dcc0c173b95490965359481
    
    * Whitespace change
    
    Change-Id: I7318bd8cfa5985b33fc7d020cc19057cc9498197
---
 .../tvm/relay/backend/contrib/ethosu/legalize.py   |  94 +++++++++
 .../relay/backend/contrib/ethosu/op/__init__.py    |   1 +
 .../relay/backend/contrib/ethosu/op/depthwise.py   | 205 ++++++++++++++++++++
 .../relay/backend/contrib/ethosu/te/__init__.py    |   1 +
 .../relay/backend/contrib/ethosu/te/depthwise.py   | 148 ++++++++++++++
 .../relay/backend/contrib/ethosu/tir/depthwise.py  | 116 +++++++++++
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py |   2 +
 .../tvm/relay/backend/contrib/ethosu/tir/spec.py   |   2 +-
 .../tvm/relay/backend/contrib/ethosu/tir/utils.py  |   6 +-
 .../backend/contrib/ethosu/tir_to_cs_translator.py |  49 +++++
 .../tvm/relay/backend/contrib/ethosu/vela_api.py   |  15 +-
 python/tvm/relay/op/contrib/ethosu.py              |  82 +++++++-
 src/relay/op/contrib/ethosu/depthwise.cc           | 212 +++++++++++++++++++++
 tests/python/contrib/test_ethosu/infra.py          | 159 ++++++++++++++--
 tests/python/contrib/test_ethosu/test_codegen.py   |  94 ++++++++-
 tests/python/contrib/test_ethosu/test_legalize.py  | 139 +++++++++++++-
 .../test_ethosu/test_replace_depthwise_conv2d.py   | 178 +++++++++++++++++
 .../test_ethosu/test_tir_to_cs_translator.py       |  75 ++++++++
 .../contrib/test_ethosu/test_type_inference.py     |  96 ++++++++++
 tests/python/driver/tvmc/test_compiler.py          |   6 +-
 20 files changed, 1631 insertions(+), 49 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index fd58da8..b970aec 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -208,6 +208,99 @@ class LegalizeEthosUConv2D:
         pass
 
 
+class EthosuDepthwiseConv2DRewriter(DFPatternCallback):
+    """Convert ethosu.qnn_depthwise_conv2d composite functions to ethosu_depthwise_conv2d
+    operators"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr(
+                {"Composite": ethosu_patterns.QnnDepthwiseConv2DParams.composite_name}
+            )
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.QnnDepthwiseConv2DParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+        channels_map = {
+            "NHWC": 3,
+        }
+        if str(params.ofm.layout) not in channels_map.keys():
+            raise UnsupportedLayout(str(params.ofm.layout))
+        kernel_shape_map = {
+            "HWOI": params.weights.shape[0:2],
+        }
+        if str(params.weights.layout) not in kernel_shape_map.keys():
+            raise UnsupportedLayout(str(params.weights.layout))
+
+        weights_values = params.weights.values
+        weights_values_ohwi = np.moveaxis(weights_values, [0, 1, 2, 3], [1, 2, 0, 3])
+
+        activation = "NONE"
+        # Activations requiring a LUT are currently not supported, so set it to an empty list
+        lut = relay.const([], "int8")
+        clip_min = 0
+        clip_max = 0
+        if params.activation:
+            activation = ethosu_patterns.QnnDepthwiseConv2DParams.activation_map[
+                params.activation.op.name
+            ]
+            if activation == "CLIP":
+                clip_min = int(params.activation.attrs.a_min)
+                clip_max = int(params.activation.attrs.a_max)
+        scale_bias = vela_api.pack_biases(
+            biases=params.biases.tensor.data.asnumpy(),
+            ifm_scale=params.ifm.q_params.scale_f32,
+            ifm_dtype=np.dtype(params.ifm.dtype),
+            weight_scales=params.weights.q_params.scale_f32,
+            ofm_scale=params.ofm.q_params.scale_f32,
+            is_activation_tanh_or_sigmoid=activation in ["TANH", "SIGMOID"],
+        )
+
+        ethosu_depthwise_conv2d = ethosu_ops.ethosu_depthwise_conv2d(
+            post.args[0],  # IFM
+            relay.const(weights_values_ohwi, params.weights.values.dtype),
+            relay.const(scale_bias, "uint8"),
+            lut,
+            float(params.ifm.q_params.scale_f32),
+            int(params.ifm.q_params.zero_point),
+            int(params.weights.q_params.zero_point),
+            float(params.ofm.q_params.scale_f32),
+            int(params.ofm.q_params.zero_point),
+            kernel_shape_map[str(params.weights.layout)],
+            params.ofm.shape[channels_map[str(params.ofm.layout)]],
+            strides=params.strides,
+            padding=params.padding,
+            dilation=params.dilation,
+            activation=activation,
+            clip_min=clip_min,
+            clip_max=clip_max,
+            upscale="NONE",
+            ifm_layout=str(params.ifm.layout),
+            ofm_layout=str(params.ofm.layout),
+        )
+        return ethosu_depthwise_conv2d
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeEthosUDepthwiseConv2D:
+    """This is the pass that wraps the EthosuDepthwiseConv2DRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(EthosuDepthwiseConv2DRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -220,6 +313,7 @@ class LegalizeEthosU:
     ) -> tvm.ir.IRModule:
         mod = LegalizeSplit()(mod)
         mod = LegalizeEthosUConv2D()(mod)
+        mod = LegalizeEthosUDepthwiseConv2D()(mod)
         return mod
 
     def __call__(self, *args, **kwargs):
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
index 0406298..1063db6 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
@@ -17,3 +17,4 @@
 "Relay operators for the Arm(R) Ethos(TM)-U NPU"
 
 from .convolution import ethosu_conv2d
+from .depthwise import ethosu_depthwise_conv2d
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
new file mode 100644
index 0000000..abcddf9
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
@@ -0,0 +1,205 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Relay operator for depthwise convolution"""
+from typing import Tuple
+
+import tvm
+from tvm.relay.op import _make
+from tvm.topi.generic import schedule_injective
+from tvm.relay.op.op import OpStrategy
+from tvm.relay.op import strategy as _strategy
+
+from ..te import depthwise_conv2d_compute
+
+
+def _extract_ethosu_depthwise_conv2d_params(attrs, args):
+    """Get the parameters necessary to construct an ethosu_depthwise_conv2d compute TE
+    from an ethosu_depthwise_conv2d Relay call."""
+    ifm = args[0]
+    weight = args[1]
+    scale_bias = args[2]
+    lut = args[3]
+    ifm_scale = attrs.ifm_scale
+    ifm_zero_point = attrs.ifm_zero_point
+    weight_zero_point = attrs.weight_zero_point
+    ofm_scale = attrs.ofm_scale
+    ofm_zero_point = attrs.ofm_zero_point
+    strides = attrs.strides
+    padding = attrs.padding
+    dilation = attrs.dilation
+    activation = attrs.activation
+    clip_min = attrs.clip_min
+    clip_max = attrs.clip_max
+    upscale = attrs.upscale
+    ifm_layout = attrs.ifm_layout
+    ofm_layout = attrs.ofm_layout
+
+    return (
+        ifm,
+        weight,
+        scale_bias,
+        lut,
+        ifm_scale,
+        ifm_zero_point,
+        weight_zero_point,
+        ofm_scale,
+        ofm_zero_point,
+        strides,
+        padding,
+        dilation,
+        activation,
+        clip_min,
+        clip_max,
+        upscale,
+        ifm_layout,
+        ofm_layout,
+    )
+
+
+@tvm.ir.register_op_attr("contrib.ethosu.depthwise_conv2d", "FTVMCompute")
+def create_ethosu_depthwise_conv2d_compute(attrs, args, out_type):
+    """Create an ethosu_depthwise_conv2d compute op."""
+    params = _extract_ethosu_depthwise_conv2d_params(attrs, args)
+    op = depthwise_conv2d_compute(*params)
+    return [op]
+
+
+@tvm.ir.register_op_attr("contrib.ethosu.depthwise_conv2d", "FTVMStrategy")
+def depthwise_conv2d_strategy_ethosu(attrs, inputs, out_type, target):
+    strategy = OpStrategy()
+    strategy.add_implementation(
+        create_ethosu_depthwise_conv2d_compute,
+        _strategy.wrap_topi_schedule(schedule_injective),
+        name="ethosu_depthwise_conv2d",
+    )
+    return strategy
+
+
+def ethosu_depthwise_conv2d(
+    ifm: tvm.relay.Expr,
+    weight: tvm.relay.Expr,
+    scale_bias: tvm.relay.Expr,
+    lut: tvm.relay.Expr,
+    ifm_scale: float,
+    ifm_zero_point: int,
+    weight_zero_point: int,
+    ofm_scale: float,
+    ofm_zero_point: int,
+    kernel_shape: Tuple[int, int],
+    ofm_channels: int,
+    strides: Tuple[int, int] = (1, 1),
+    padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
+    dilation: Tuple[int, int] = (1, 1),
+    activation: str = "NONE",
+    clip_min: int = 0,
+    clip_max: int = 0,
+    upscale: str = "NONE",
+    ifm_layout: str = "NHWC",
+    ofm_layout: str = "NHWC",
+) -> tvm.relay.Call:
+    """This is a quantized 2D depthwise convolution operation as supported
+    by the NPU. It accepts either NHWC or NHCWB16 format
+    for the input data and OHWI format for the kernel weights.
+
+    Reference: https://developer.arm.com/documentation/102420/0200/
+
+    Note that the per-channel weight scale and bias tensor must be
+    packed together into a combined tensor of uint80s. This is represented
+    in TVM by a (channels, 10) tensor of type uint8. For more detail,
+    refer to the Technical Reference Manual linked above.
+
+    Parameters
+    ----------
+    ifm : tvm.relay.Expr
+        The Input Feature Map tensor (IFM).
+    weight : tvm.relay.Expr
+        The weight tensor.
+    scale_bias : tvm.relay.Expr
+        The packed per-channel weight scale and bias tensor.
+    lut : tvm.relay.Expr
+        The look-up table values to use if activation = "LUT"
+    ifm_scale : float
+        The quantization scale for the Input Feature Map tensor.
+    ifm_zero_point : int
+        The quantization zero point for the Input Feature Map tensor.
+    weight_zero_point : int
+        The quantization zero point for the weight tensor.
+    ofm_scale : float
+        The quantization scale for the Output Feature Map tensor.
+    ofm_zero_point : int
+        The quantization zero point for the Output Feature Map tensor.
+    kernel_shape : tuple of int
+        The 2 dimensional kernel shape as (kernel_height, kernel_width).
+    ofm_channels : int
+        The number of OFM channels.
+    strides : tuple of int, optional
+        The 2 dimensional strides as (stride_height, stride_width).
+    padding : tuple of int, optional
+        The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right).
+    dilation : tuple of int, optional
+        The 2 dimensional dilation as (dilation_height, dilation_width).
+    activation : str, optional
+        The activation function to use.
+            "NONE" - no activation function.
+            "CLIP" - clip the output between clip_min and clip_max.
+            "TANH" - tanh activation function.
+            "SIGMOID" - sigmoid activation function.
+            "LUT" - use a look-up table to perform
+                the activation function.
+    clip_min : int, optional
+        The minimum clipping value if activation = "CLIP"
+    clip_max : int, optional
+        The maximum clipping value if activation = "CLIP"
+    upscale : str, optional
+        The 2x2 upscaling mode to apply to the Input Feature Map tensor.
+            "NONE" - no upscaling.
+            "NEAREST" - upscale using nearest neighbour.
+            "ZEROS" - upscale using zeros.
+    ifm_layout : str, optional
+        The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ofm_layout : str, optional
+        The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
+
+    Returns
+    -------
+    out : tvm.relay.Call
+        A call to the ethosu_depthwise_conv2d op.
+
+    """
+    return _make.ethosu_depthwise_conv2d(
+        ifm,
+        weight,
+        scale_bias,
+        lut,
+        ifm_scale,
+        ifm_zero_point,
+        weight_zero_point,
+        ofm_scale,
+        ofm_zero_point,
+        kernel_shape,
+        ofm_channels,
+        strides,
+        padding,
+        dilation,
+        activation,
+        clip_min,
+        clip_max,
+        upscale,
+        ifm_layout,
+        ofm_layout,
+    )
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
index 7ca5de3..5dcdd4d 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
@@ -17,3 +17,4 @@
 """Tensor Expressions for the NPU"""
 
 from .convolution import *
+from .depthwise import *
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
new file mode 100644
index 0000000..35ae7f9
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Tensor Expressions for depthwise convolutions"""
+from typing import Tuple, Union, List
+
+from tvm import te
+from .dma import dma_ofm_compute, dma_ifm_compute
+
+
+def depthwise_conv2d_compute(
+    ifm: te.Tensor,
+    weight: te.Tensor,
+    scale_bias: te.Tensor,
+    lut: te.Tensor,
+    ifm_scale: float,
+    ifm_zero_point: int,
+    weight_zero_point: int,
+    ofm_scale: float,
+    ofm_zero_point: int,
+    strides: Tuple[int, int],
+    padding: Tuple[int, int, int, int],
+    dilation: Union[Tuple[int, int], List[int]],
+    activation: str,
+    clip_min: int,
+    clip_max: int,
+    upscale: str,
+    ifm_layout: str,
+    ofm_layout: str,
+) -> te.Tensor:
+    """A compute operator representing the capabilities of 2D depthwise convolution for the NPU.
+
+    Parameters
+    ----------
+    ifm : te.Tensor
+        The Input Feature Map tensor (IFM).
+    weight : te.Tensor
+        The weight tensor.
+    scale_bias : te.Tensor
+        The packed per-channel weight scale and bias tensor.
+    lut : te.Tensor
+        The look-up table values to use if activation = "LUT".
+    ifm_scale : float
+        The quantization scale for the Input Feature Map tensor.
+    ifm_zero_point : int
+        The quantization zero point for the Input Feature Map tensor.
+    weight_zero_point : int
+        The quantization zero point for the weight tensor.
+    ofm_scale : float
+        The quantization scale for the Output Feature Map tensor.
+    ofm_zero_point : int
+        The quantization zero point for the Output Feature Map tensor.
+    strides : tuple
+        The 2 dimensional strides as (stride_height, stride_width).
+    padding : tuple
+        The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right).
+    dilation : Union[tuple, list]
+        The 2 dimensional dilation as (dilation_height, dilation_width).
+    activation : str
+        The activation function to use.
+            "NONE" - no activation function.
+            "CLIP" - clip the output between clip_min and clip_max.
+            "TANH" - tanh activation function.
+            "SIGMOID" - sigmoid activation function.
+            "LUT" - use a look-up table to perform the activation function.
+    clip_min : int
+        The minimum clipping value if activation = "CLIP".
+    clip_max : int
+        The maximum clipping value if activation = "CLIP".
+    upscale : str
+        The 2x2 upscaling mode to apply to the Input Feature Map tensor.
+            "NONE" - no upscaling.
+            "NEAREST" - upscale using nearest neighbour.
+            "ZEROS" - upscale using zeros.
+    ifm_layout : str
+        The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ofm_layout : str
+        The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
+
+    Returns
+    -------
+    te.Tensor
+        The OFM tensor.
+
+    """
+    assert ifm.shape[0] == 1, f"Only batch size 1 is supported"
+    assert ifm_layout in {"NHWC", "NHCWB16"}
+    assert ofm_layout in {"NHWC", "NHCWB16"}
+
+    stride_h, stride_w = strides
+    dilation_h, dilation_w = dilation
+    channels, kernel_h, kernel_w, _ = weight.shape
+
+    # Compute operation for the IFM DMA pipeline
+    dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, channels, padding)
+
+    # 2D Depthwise Convolution compute operation
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    ofm_height = (dmaed_ifm.shape[1] - dilated_kernel_h) // stride_h + 1
+    ofm_width = (dmaed_ifm.shape[2] - dilated_kernel_w) // stride_w + 1
+    rh = te.reduce_axis((0, kernel_h), name="ry")
+    rw = te.reduce_axis((0, kernel_w), name="rx")
+
+    depthwise_conv2d_attrs = {
+        "op": "ethosu_depthwise_conv2d",
+        "weight_zero_point": weight_zero_point,
+        "activation": activation,
+        "upscale": upscale,
+        "clip_min": clip_min,
+        "clip_max": clip_max,
+        "stride_h": stride_h,
+        "stride_w": stride_w,
+        "dilation_h": dilation_h,
+        "dilation_w": dilation_w,
+    }
+
+    depthwise = te.compute(
+        (1, ofm_height, ofm_width, channels),
+        lambda nn, hh, ww, cc: te.sum(
+            dmaed_ifm(
+                nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, cc
+            ).astype(ifm.dtype)
+            * weight[cc, rh, rw, 0].astype(ifm.dtype)
+            # This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+            + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
+            axis=[rh, rw],
+        ),
+        name="ethosu_depthwise_conv2d",
+        attrs=depthwise_conv2d_attrs,
+    )
+
+    # Compute operation for the OFM DMA pipeline
+    return dma_ofm_compute(depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
new file mode 100644
index 0000000..27111a9
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Extract information from the depthwise convolution operators in TIR."""
+from typing import Dict, Tuple
+import tvm
+from ..vela_api import SCALE_BIAS_LENGTH
+from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
+from .dma import get_ifm_params, get_ofm_params
+from .spec import (
+    SerialKernel,
+    SerialAddressRange,
+    SerialActivation,
+    Serial2DDepthwise,
+)
+
+
+def get_depthwise_conv2d_params(
+    stmt: tvm.tir.AttrStmt,
+    producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
+    consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
+) -> Tuple[Serial2DDepthwise, tvm.tir.Var, tvm.tir.Var]:
+    """Get the parameters necessary to construct a call_extern for a depthwise_conv2d.
+
+    Parameters
+    ----------
+    stmt : tvm.tir.AttrStmt
+        The outermost attribute statement of a depthwise loop nest.
+    producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
+        A dictionary to associate pointers with the loop nest
+        that produces their values.
+    consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
+        A dictionary to associate pointers with the loop nest
+        that consumes their values.
+
+    Returns
+    -------
+    Serial2DDepthwise
+        The parameters needed to construct a 2D depthwise.
+    output_pointer : tvm.tir.Var
+        The output pointer of the convolution operation.
+    replace_pointer : tvm.tir.Var
+        The output pointer of the DMA write operation, which is to replace
+        the convolution output pointer.
+
+    """
+    attrs, body = get_op_attrs(stmt)
+    _, _, _, _, _, inner = get_outer_loops(body, "NHWC")
+    rh = inner
+    rw = rh.body
+    # loads = [output, input, weights, scale_bias, scale_bias]
+    loads = get_loads(rw.body)
+    # stores = [output]
+    stores = get_stores(rw.body)
+    input_pointer = loads[1].buffer_var
+    output_pointer = stores[0].buffer_var
+    # Get feature map info
+    serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
+    serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
+    # Get kernel info
+    serial_kernel = SerialKernel(
+        width=int(rw.extent),
+        height=int(rh.extent),
+        stride_w=int(attrs["stride_w"]),
+        stride_h=int(attrs["stride_h"]),
+        dilation_w=int(attrs["dilation_w"]),
+        dilation_h=int(attrs["dilation_h"]),
+    )
+    # Get scale_bias info
+    scale_bias_load = loads[3]
+    scale_bias_base = get_base_address(scale_bias_load.index)
+    serial_scale_bias = SerialAddressRange(
+        address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
+        length=SCALE_BIAS_LENGTH * serial_ofm[3],
+    )
+    # Get weight info
+    weight_load = loads[2]
+    weight_base = get_base_address(weight_load.index)
+    serial_weight = SerialAddressRange(
+        address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
+        length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1],
+    )
+    # Get activation info
+    serial_activation = SerialActivation(
+        op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
+    )
+
+    return (
+        Serial2DDepthwise(
+            ifm=serial_ifm,
+            ofm=serial_ofm,
+            kernel=serial_kernel,
+            weight=serial_weight,
+            weight_zero_point=attrs["weight_zero_point"],
+            scale_bias=serial_scale_bias,
+            padding=serial_padding,
+            activation=serial_activation,
+            upscale="NONE",
+        ),
+        output_pointer,
+        replace_pointer,
+    )
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 1af4496..8bb410e 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -21,6 +21,7 @@ import numpy as np  # type: ignore
 import tvm
 from tvm.relay.backend.contrib.ethosu import vela_api
 from .convolution import get_conv2d_params
+from .depthwise import get_depthwise_conv2d_params
 from .transform import get_copy_params
 from .utils import get_weights_pointer, get_scale_bias_pointer
 
@@ -52,6 +53,7 @@ def ReplaceOperators():
     op_map = {
         "ethosu_conv2d": get_conv2d_params,
         "ethosu_copy": get_copy_params,
+        "ethosu_depthwise_conv2d": get_depthwise_conv2d_params,
     }
     pointer_to_producer = {}
     pointer_to_consumer = {}
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
index 3ecbcd5..ff019c7 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
@@ -203,7 +203,7 @@ class Serial2DConvolution(SerializableFormat):
 
 class Serial2DDepthwise(SerializableFormat):
     """Specialization class to retrieve arguments of
-    a ethosu.depthwise2d tir extern call on a predefined ordering"""
+    an ethosu.depthwise_conv2d TIR extern call on a predefined ordering"""
 
     def __init__(
         self,
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py
index 7d6fd3b..ccfc2df 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py
@@ -23,7 +23,8 @@ from tvm import arith
 # TODO(@mbaret): Formalise this with a specification
 def get_weights_pointer(tir_extern_call):
     """Get the weights pointer from a NPU extern call if it exists"""
-    if tir_extern_call.args[0] == "ethosu_conv2d":
+    supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"]
+    if tir_extern_call.args[0] in supported_ops:
         return tir_extern_call.args[41].buffer_var
     return None
 
@@ -31,7 +32,8 @@ def get_weights_pointer(tir_extern_call):
 # TODO(@mbaret): Formalise this with a specification
 def get_scale_bias_pointer(tir_extern_call):
     """Get the scale_bias pointer from a NPU extern call if it exists"""
-    if tir_extern_call.args[0] == "ethosu_conv2d":
+    supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"]
+    if tir_extern_call.args[0] in supported_ops:
         return tir_extern_call.args[44].buffer_var
     return None
 
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 4b28dc5..408eab6 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -299,6 +299,7 @@ def translate_ethosu_tir_extern_call(tir_extern_call):
     supported_extern_calls = {
         "ethosu_conv2d": translate_ethosu_conv2d,
         "ethosu_copy": translate_ethosu_copy,
+        "ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d,
     }
     ext_call_type = tir_extern_call.args[0].value
     assert ext_call_type in supported_extern_calls.keys(), f"{ext_call_type} is not yet supported"
@@ -408,6 +409,54 @@ def _create_npu_op_conv2d(serial_2d_convolution):
     return npu_conv2d_op, weights_zero_point
 
 
+def translate_ethosu_depthwise_conv2d(tir_extern_call):
+    """This function will translate a tir extern_call
+    as produced by Relay to TIR compilation.
+
+    Parameters
+    ----------
+    tir_extern_call : tvm.tir.Call
+        This should be a tir external call that has an agreed upon ordering
+        for NPU TIR Compiler. See Serial2DDepthwise in
+        tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering.
+
+    Returns
+    -------
+    ethosu.vela.api.NpuConvDepthWiseOperation
+        The vela object containing the params of ethosu_depthwise_conv2d
+    weights_zero_point : int
+        The zero point of the weights
+    """
+    serial_object = spec.create_serial_object(spec.Serial2DDepthwise, tir_extern_call.args[1:])
+    return _create_npu_op_depthwise_conv2d(serial_object)
+
+
+def _create_npu_op_depthwise_conv2d(serial_2d_depthwise):
+    npu_depthwise_conv2d_op = vapi.NpuConvDepthWiseOperation()
+
+    npu_depthwise_conv2d_op.ifm = _create_npu_feature_map(serial_2d_depthwise.ifm)
+    npu_depthwise_conv2d_op.ofm = _create_npu_feature_map(serial_2d_depthwise.ofm)
+    npu_depthwise_conv2d_op.kernel = _create_npu_kernel(serial_2d_depthwise.kernel)
+    npu_depthwise_conv2d_op.weights = [_create_npu_address_range(serial_2d_depthwise.weight)]
+    weights_zero_point = np.int64(serial_2d_depthwise.weight_zero_point.value)
+    npu_depthwise_conv2d_op.biases = [_create_npu_address_range(serial_2d_depthwise.scale_bias)]
+    npu_depthwise_conv2d_op.padding = _create_npu_padding(serial_2d_depthwise.padding)
+
+    npu_depthwise_conv2d_op.activation = _create_npu_activation(serial_2d_depthwise.activation)
+    if (
+        npu_depthwise_conv2d_op.activation
+        and npu_depthwise_conv2d_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
+    ):
+        _convert_clip_bounds(npu_depthwise_conv2d_op)
+
+    npu_depthwise_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_depthwise.upscale)
+    target_accel_type = vela_api.get_target_accel_type()
+    block_config = vela_api.get_optimal_block_config(npu_depthwise_conv2d_op, target_accel_type)
+    npu_depthwise_conv2d_op.block_config = block_config
+
+    return npu_depthwise_conv2d_op, weights_zero_point
+
+
 def _create_npu_feature_map(serial_feature_map):
     """This is a helper function to capture a list
     of arguments to create Vela NpuFeatureMap object
diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py
index 5009c31..6523352 100644
--- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py
+++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py
@@ -130,18 +130,22 @@ def encode_weights(tir_extern_call, values, accel_type):
     bytearray
         Compressed weights
     """
-    supported_ops = ["ethosu_conv2d"]
+    supported_ops = {
+        "ethosu_conv2d": tirtocs.translate_ethosu_conv2d,
+        "ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d,
+    }
     op = str(tir_extern_call.args[0].value)
-    assert op in supported_ops
-    npu_op, weights_zero_point = tirtocs.translate_ethosu_conv2d(tir_extern_call)
+    assert op in supported_ops.keys()
+    npu_op, weights_zero_point = supported_ops[op](tir_extern_call)
     block_config = get_optimal_block_config(npu_op, accel_type)
     # The weight layout is assumed to be flat OHWI, always.
     assert len(values.shape) == 1
+    is_depthwise = op == "ethosu_depthwise_conv2d"
     shape_ohwi = (
         npu_op.ofm.shape.depth,
         npu_op.kernel.height,
         npu_op.kernel.width,
-        npu_op.ifm.shape.depth,
+        1 if is_depthwise else npu_op.ifm.shape.depth,
     )
     assert values.size == np.prod(shape_ohwi)
     values = np.reshape(values, shape_ohwi)
@@ -154,8 +158,7 @@ def encode_weights(tir_extern_call, values, accel_type):
         block_depth=block_config.depth,
         dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y),
         accel_type=accel_type,
-        # TODO(@manupa-arm): change this when we support depthwise
-        is_depthwise=False,
+        is_depthwise=is_depthwise,
     )
 
 
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 85ddfd9..4369376 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -192,11 +192,11 @@ class QnnConv2DParams:
         bias_add = requantize_op.args[0]
         qnn_conv2d = bias_add.args[0]
         data_layout = qnn_conv2d.attrs.data_layout
-        kernel_layout = qnn_conv2d.attrs.kernel_layout
+        self.kernel_layout = qnn_conv2d.attrs.kernel_layout
         # We consider the weights & biases as params as it should be a Constant
         self.weights = TensorParams(
             qnn_conv2d.args[QConv2DArgs.WEIGHTS.value],
-            kernel_layout,
+            self.kernel_layout,
             qnn_conv2d.args[QConv2DArgs.WEIGHTS_SCALE.value],
             qnn_conv2d.args[QConv2DArgs.WEIGHTS_ZERO_POINT.value],
         )
@@ -219,16 +219,18 @@ class QnnConv2DParams:
             requantize_op.args[RequantArgs.OFM_SCALE.value],
             requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
         )
-        self.padding = qnn_conv2d.attrs.padding
-        self.strides = qnn_conv2d.attrs.strides
-        self.dilation = qnn_conv2d.attrs.dilation
+        attrs = qnn_conv2d.attrs
+        self.padding = attrs.padding
+        self.strides = attrs.strides
+        self.dilation = attrs.dilation
         self.activation = activation
+        self.channels = attrs.channels
 
         # If groups are equal to channel, its a depthwise_conv2d
-        self.groups = qnn_conv2d.attrs.groups
+        self.groups = attrs.groups
         self.is_depthwise = False
         channels_axis = {"HWIO": 3, "HWOI": 2}
-        if qnn_conv2d.attrs.groups == self.weights.shape[channels_axis[kernel_layout]]:
+        if self.groups == self.weights.shape[channels_axis[self.kernel_layout]]:
             self.is_depthwise = True
 
     def is_valid(self) -> bool:
@@ -253,10 +255,52 @@ class QnnConv2DParams:
         legal_groups = [1, self.ofm.shape[3]]
         if self.groups not in legal_groups:
             return False
-        # This should be a valid QnnDepthwise2DParams, not QnnConv2DParams
+        # This should be a valid QnnDepthwiseConv2DParams, not QnnConv2DParams
         return not self.is_depthwise
 
 
+class QnnDepthwiseConv2DParams(QnnConv2DParams):
+    """
+    Parses a call to an ethosu.depthwise_conv2d composite function and
+    exposes the extracted parameter information.
+    """
+
+    composite_name = "ethosu.depthwise_conv2d"
+    # The hardware only supports padding up to the following values
+    padding_bounds = [31, 31, 32, 32]
+
+    def __init__(self, func_body: tvm.relay.expr.Call):
+        super().__init__(func_body)
+
+    def is_valid(self):
+        """
+        Returns True only if the QnnDepthwiseConv2D + activation function has
+        attributes that are compatible with the HW.
+        """
+        # Evaluated left to right with short-circuiting, so a later check is
+        # only reached once all the earlier ones have passed.
+        unsupported = (
+            not check_valid_dtypes([self.weights, self.ifm, self.ofm])
+            or not check_weights(self.weights, self.dilation)
+            or not check_bias(self.biases)
+            or not check_strides(self.strides)
+            or not check_batch_size(self.ifm)
+            or not check_dilation(self.dilation)
+            or not check_padding(self.padding, self.padding_bounds)
+            or self.weights.layout != "HWOI"
+            # only depth multiplier of size 1 is supported
+            or self.weights.shape[3] != 1
+            or not self.is_depthwise
+        )
+        return not unsupported
+
+
 def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     This function creates the pattern for qnn.conv2D with optional fused RELU activation.
@@ -272,6 +316,21 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return clip_or_req
 
 
+def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    Builds the pattern for a qnn.conv2d with HWOI kernel layout followed by nn.bias_add
+    and qnn.requantize, optionally ending in a clip (the fused RELU activation).
+    """
+    # One free input followed by five constant arguments
+    conv_inputs = [wildcard()] + [is_constant() for _ in range(5)]
+    depthwise_conv = is_op("qnn.conv2d")(*conv_inputs).has_attr({"kernel_layout": "HWOI"})
+    biased = is_op("nn.bias_add")(depthwise_conv, is_constant())
+    requantized = is_op("qnn.requantize")(biased, *[is_constant() for _ in range(4)])
+    return requantized.optional(is_op("clip"))
+
+
 @register_pattern_table("ethosu")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -279,7 +338,12 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             QnnConv2DParams.composite_name,
             qnn_conv2d_pattern(),
             lambda pat: QnnConv2DParams(pat).is_valid(),
-        )
+        ),
+        (
+            QnnDepthwiseConv2DParams.composite_name,
+            qnn_depthwise_conv2d_pattern(),
+            lambda pat: QnnDepthwiseConv2DParams(pat).is_valid(),
+        ),
     ]
 
 
diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc
new file mode 100644
index 0000000..fa73645
--- /dev/null
+++ b/src/relay/op/contrib/ethosu/depthwise.cc
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/contrib/ethosu/depthwise.cc
+ * \brief Depthwise convolution 2D operator definition for the Arm(R) Ethos(TM)-U NPU
+ */
+#include <tvm/relay/base.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../../../qnn/utils.h"
+#include "../../nn/convolution.h"
+#include "common.h"
+
+namespace tvm {
+namespace relay {
+namespace op {
+namespace contrib {
+namespace ethosu {
+
+/*! \brief Attributes used by the Ethos(TM)-U NPU depthwise operator */
+/*!
+ * \brief Attributes used by the Ethos(TM)-U NPU depthwise operator.
+ *
+ * Holds the quantization parameters of the IFM, weight and OFM tensors, the
+ * kernel geometry (shape, strides, padding, dilation), the activation and
+ * upscaling modes and the IFM/OFM layouts. Defaults and per-field
+ * descriptions are declared in the TVM_DECLARE_ATTRS block below.
+ */
+struct EthosuDepthwiseConv2DAttrs : public tvm::AttrsNode<EthosuDepthwiseConv2DAttrs> {
+  double ifm_scale;
+  int ifm_zero_point;
+  int weight_zero_point;
+  double ofm_scale;
+  int ofm_zero_point;
+  Array<IndexExpr> kernel_shape;
+  IndexExpr ofm_channels;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> dilation;
+  String activation;
+  int clip_min;
+  int clip_max;
+  String upscale;
+  String ifm_layout;
+  String ofm_layout;
+
+  TVM_DECLARE_ATTRS(EthosuDepthwiseConv2DAttrs, "relay.attrs.EthosuDepthwiseConv2DAttrs") {
+    TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor.");
+    TVM_ATTR_FIELD(ifm_zero_point)
+        .describe("The quantization zero point for the Input Feature Map tensor.");
+    TVM_ATTR_FIELD(weight_zero_point)
+        .describe("The quantization zero point for the weight tensor.");
+    TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor.");
+    TVM_ATTR_FIELD(ofm_zero_point)
+        .describe("The quantization zero point for the Output Feature Map tensor.");
+    TVM_ATTR_FIELD(kernel_shape)
+        .describe("The 2 dimensional kernel shape as (kernel_height, kernel_width).")
+        .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(ofm_channels)
+        .describe("The number of OFM channels.")
+        .set_default(NullValue<IndexExpr>());
+    TVM_ATTR_FIELD(strides)
+        .describe("The 2 dimensional strides as (stride_height, stride_width).")
+        .set_default(Array<IndexExpr>({1, 1}));
+    TVM_ATTR_FIELD(padding)
+        .describe("The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right)")
+        .set_default(Array<IndexExpr>({0, 0, 0, 0}));
+    TVM_ATTR_FIELD(dilation)
+        .describe("The 2 dimensional dilation as (dilation_height, dilation_width).")
+        .set_default(Array<IndexExpr>({1, 1}));
+    // The fragments below are concatenated into one string, so each one ends
+    // with a space to keep the options separated in the rendered description.
+    TVM_ATTR_FIELD(activation)
+        .describe(
+            "Description: The activation function to use. "
+            "'NONE' - no activation function. "
+            "'CLIP' - clip the output between clip_min and clip_max. "
+            "'TANH' - tanh activation function. "
+            "'SIGMOID' - sigmoid activation function. "
+            "'LUT' - use a look-up table to perform the activation function.")
+        .set_default("NONE");
+    TVM_ATTR_FIELD(clip_min)
+        .describe("The minimum clipping value if activation = CLIP.")
+        .set_default(0);
+    TVM_ATTR_FIELD(clip_max)
+        .describe("The maximum clipping value if activation = CLIP.")
+        .set_default(0);
+    TVM_ATTR_FIELD(upscale)
+        .describe(
+            "The 2x2 upscaling mode to apply to the Input Feature Map tensor. "
+            "'NONE' - no upscaling. "
+            "'NEAREST' - upscale using nearest neighbour. "
+            "'ZEROS' - upscale using zeros.")
+        .set_default("NONE");
+    TVM_ATTR_FIELD(ifm_layout)
+        .set_default("NHWC")
+        .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.");
+    TVM_ATTR_FIELD(ofm_layout)
+        .set_default("NHWC")
+        .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.");
+  }
+};
+
+TVM_REGISTER_NODE_TYPE(EthosuDepthwiseConv2DAttrs);
+
+/*!
+ * \brief Type relation for the Ethos(TM)-U NPU depthwise convolution operator.
+ *
+ * Checks the dtypes of the ifm, weight and scale_bias tensors, then assigns
+ * the weight type and the inferred OFM type.
+ *
+ * \param types The five types of the call: ifm, weight, scale_bias, the
+ *        fourth input (index 3, not inspected here) and the output (index 4).
+ * \param num_inputs Unused.
+ * \param attrs The operator attributes; must be EthosuDepthwiseConv2DAttrs.
+ * \param reporter The reporter used to assign the weight and OFM types.
+ * \return false while the ifm, weight or scale_bias type is not yet a
+ *         TensorType, true once the weight and OFM types have been assigned.
+ */
+bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                              const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 5);
+  const auto* ifm = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  const auto* scale_bias = types[2].as<TensorTypeNode>();
+  // scale_bias is dereferenced below, so it has to be resolved as well
+  if (ifm == nullptr || weight == nullptr || scale_bias == nullptr) return false;
+
+  const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
+  ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";
+  ICHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
+      << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for ifm but was "
+      << ifm->dtype;
+  // Both alternatives must test the weight dtype, not the ifm dtype
+  ICHECK(weight->dtype == DataType::UInt(8) || weight->dtype == DataType::Int(8))
+      << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for weight but was "
+      << weight->dtype;
+  ICHECK(scale_bias->dtype == DataType::UInt(8))
+      << "Expected ethosu_depthwise_conv2d type(uint8) for scale_bias but was "
+      << scale_bias->dtype;
+
+  // Assign weight type {ofm_channels, kernel_height, kernel_width, weight->shape[3]};
+  // the last dimension is carried over from the incoming weight type.
+  reporter->Assign(types[1], TensorType({param->ofm_channels, param->kernel_shape[0],
+                                         param->kernel_shape[1], weight->shape[3]},
+                                        weight->dtype));
+
+  // Assign ofm type; the OFM takes the dtype of the IFM
+  auto ofm_shape =
+      EthosuInferKernelOutput(ifm->shape, param->ifm_layout, param->ofm_layout, param->kernel_shape,
+                              param->ofm_channels, param->dilation, param->strides, param->padding);
+
+  reporter->Assign(types[4], TensorType(ofm_shape, ifm->dtype));
+
+  return true;
+}
+
+/*!
+ * \brief Create a call to the contrib.ethosu.depthwise_conv2d operator.
+ *
+ * Every non-tensor argument is stored in a freshly created
+ * EthosuDepthwiseConv2DAttrs node; the four tensor expressions become the
+ * arguments of the returned Call.
+ *
+ * \return The Call expression of the depthwise convolution.
+ */
+Expr MakeEthosuDepthwiseConv2D(Expr ifm, Expr weight, Expr scale_bias, Expr lut, double ifm_scale,
+                               int ifm_zero_point, int weight_zero_point, double ofm_scale,
+                               int ofm_zero_point, Array<IndexExpr> kernel_shape,
+                               IndexExpr ofm_channels, Array<IndexExpr> strides,
+                               Array<IndexExpr> padding, Array<IndexExpr> dilation,
+                               String activation, int clip_min, int clip_max, String upscale,
+                               String ifm_layout, String ofm_layout) {
+  auto depthwise_attrs = make_object<EthosuDepthwiseConv2DAttrs>();
+
+  // Quantization parameters
+  depthwise_attrs->ifm_scale = ifm_scale;
+  depthwise_attrs->ifm_zero_point = ifm_zero_point;
+  depthwise_attrs->weight_zero_point = weight_zero_point;
+  depthwise_attrs->ofm_scale = ofm_scale;
+  depthwise_attrs->ofm_zero_point = ofm_zero_point;
+
+  // Kernel geometry
+  depthwise_attrs->kernel_shape = std::move(kernel_shape);
+  depthwise_attrs->ofm_channels = std::move(ofm_channels);
+  depthwise_attrs->strides = std::move(strides);
+  depthwise_attrs->padding = std::move(padding);
+  depthwise_attrs->dilation = std::move(dilation);
+
+  // Activation and clipping range
+  depthwise_attrs->activation = std::move(activation);
+  depthwise_attrs->clip_min = clip_min;
+  depthwise_attrs->clip_max = clip_max;
+
+  // Upscaling mode and tensor layouts
+  depthwise_attrs->upscale = std::move(upscale);
+  depthwise_attrs->ifm_layout = std::move(ifm_layout);
+  depthwise_attrs->ofm_layout = std::move(ofm_layout);
+
+  static const Op& depthwise_op = Op::Get("contrib.ethosu.depthwise_conv2d");
+  return Call(depthwise_op, {ifm, weight, scale_bias, lut}, Attrs(depthwise_attrs), {});
+}
+
+// Register MakeEthosuDepthwiseConv2D as the global function
+// "relay.op._make.ethosu_depthwise_conv2d".
+TVM_REGISTER_GLOBAL("relay.op._make.ethosu_depthwise_conv2d")
+    .set_body_typed(MakeEthosuDepthwiseConv2D);
+
+// Register the operator itself: four tensor inputs (ifm, weight, scale_bias,
+// lut), EthosuDepthwiseConv2DAttrs as its attribute type and
+// EthosuDepthwiseConv2DRel as its type relation.
+RELAY_REGISTER_OP("contrib.ethosu.depthwise_conv2d")
+    .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized depthwise operator.
+
+This Relay operator corresponds to the hardware-implemented quantized
+depthwise operation found on Ethos(TM)-U NPUs. It accepts either NHWC or NHCWB16 format
+for the input data (input feature map, or IFM) and OHWI format for the kernel weights.
+
+- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
+           NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
+- **weight**: (ofm_channels, kernel_shape[0], kernel_shape[1], 1 (depth multiplier))
+- **scale_bias**: (ofm_channels, 10)
+- **ofm**: (1, ofm_height, ofm_width, ofm_channels)
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<EthosuDepthwiseConv2DAttrs>()
+    .set_num_inputs(4)
+    .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).")
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.")
+    .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'")
+    .set_support_level(11)
+    .add_type_rel("EthosuDepthwiseConv2D", EthosuDepthwiseConv2DRel);
+
+}  // namespace ethosu
+}  // namespace contrib
+}  // namespace op
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index 8b0d306..01a7ceb 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -29,7 +29,9 @@ from typing import List
 import os
 import struct
 import numpy
+import math
 from enum import IntEnum
+import tensorflow as tf
 
 from ethosu.vela.register_command_stream_generator import CmdMode
 from ethosu.vela.register_command_stream_generator import cmd0
@@ -66,26 +68,6 @@ class VelaArtifacts:
         self.npu_ops = set()
 
 
-def parse_relay_tflite_model(tflite_model, input_tensor, input_shape, input_dtype):
-    mod_, params_ = relay.frontend.from_tflite(
-        tflite_model,
-        shape_dict={input_tensor: input_shape},
-        dtype_dict={input_tensor: input_dtype},
-    )
-    return mod_, params_
-
-
-def parse_tflite_model(model_file):
-    try:
-        import tflite
-
-        return tflite.Model.GetRootAsModel(model_file, 0)
-    except AttributeError:
-        import tflite.Model
-
-        return tflite.Model.Model.GetRootAsModel(model_file, 0)
-
-
 def print_payload(payload):
     cmds = deserialize_command_stream(payload)
     for cmd_val in cmds:
@@ -270,6 +252,58 @@ def flatten_numpy_data(data):
     return reshaped_data
 
 
+class InputGenerator:
+    """Generates random input tensors from a caller-supplied random state.
+
+    Parameters
+    ----------
+    random_state : numpy.random.RandomState
+        Source of randomness; sharing one seeded instance makes the generated
+        data reproducible.
+    """
+
+    def __init__(self, random_state):
+        self._random_state = random_state
+
+    def generate(self, size, dtype):
+        """Return a random array of shape `size` and type `dtype`.
+
+        float32 data is drawn uniformly from [-1, 1); any other dtype is
+        treated as an integer type and drawn from its full value range.
+        """
+        if dtype == numpy.float32:
+            print("random float32")
+            return self._random_state.uniform(-1, 1, size).astype(dtype)
+
+        info = numpy.iinfo(dtype)
+        # print() does not interpolate printf-style placeholders, so the
+        # message has to be formatted before it is passed on.
+        print("random (u)int min=%d max=%d" % (info.min, info.max))
+        # randint excludes the upper bound, hence the +1 to make max reachable
+        return self._random_state.randint(info.min, info.max + 1, size, dtype)
+
+
+def generate_ref_data_tflite(model):
+    """
+    This method generates reference data by running the specified model on tflite with random input data.
+    The random input data and generated output data are returned.
+
+    Parameters
+    ----------
+    model
+        The serialized TFLite model, passed to the interpreter as `model_content`.
+
+    Returns
+    -------
+    input_data : dict
+        Maps each input tensor name to the random array that was fed to it.
+    expected_output_data : list
+        The output tensors, in the order reported by `get_output_details()`.
+    """
+    interpreter = tf.lite.Interpreter(model_content=model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    # Initialize random generators with a fixed seed to get deterministic results
+    seed = 0
+    random_state = numpy.random.RandomState(seed)
+
+    inputgen = InputGenerator(random_state)
+
+    # Generate input data and feed it to the interpreter
+    input_data = {}
+    for input_detail in input_details:
+        value = inputgen.generate(input_detail["shape"], input_detail["dtype"])
+        input_data[input_detail["name"]] = value
+        # set_tensor expects the interpreter's tensor index, which is not
+        # necessarily the position of the input in input_details.
+        interpreter.set_tensor(input_detail["index"], value)
+    interpreter.invoke()
+
+    expected_output_data = [
+        interpreter.get_tensor(output_detail["index"]) for output_detail in output_details
+    ]
+
+    return input_data, expected_output_data
+
+
 def generate_weights_data(shape, dtype):
     size = 1
     for dim in shape:
@@ -278,7 +312,7 @@ def generate_weights_data(shape, dtype):
 
 
 def get_convolutional_args(call, include_buffers=False, remove_constants=False):
-    """A method to extract the arguments from conv2d or depthwise2d extern call."""
+    """A method to extract the arguments from conv2d or depthwise_conv2d extern call."""
     args = call.args
     conv_args = []
     remove_indices = [0]
@@ -299,6 +333,44 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False):
     return conv_args
 
 
+def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=(1, 1)):
+    """Compute the output feature map shape of a convolution.
+
+    Parameters
+    ----------
+    ifm_shape : sequence of int
+        Input shape; index 1 is used as height, index 2 as width, and
+        indices 0 and 3 are copied to the output unchanged.
+    padding : str
+        "valid" or "same" (case insensitive).
+    kernel_shape : sequence of int
+        (kernel_height, kernel_width).
+    strides : sequence of int
+        (stride_height, stride_width).
+    dilation : sequence of int
+        (dilation_height, dilation_width).
+
+    Returns
+    -------
+    list of int
+        [ifm_shape[0], ofm_height, ofm_width, ifm_shape[3]]
+
+    Raises
+    ------
+    ValueError
+        If `padding` is neither "valid" nor "same".
+    """
+    assert len(strides) == 2
+    assert len(dilation) == 2
+    assert len(kernel_shape) == 2
+    mode = padding.lower()
+    if mode == "valid":
+        h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0])
+        w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1])
+    elif mode == "same":
+        h = math.ceil(ifm_shape[1] / strides[0])
+        w = math.ceil(ifm_shape[2] / strides[1])
+    else:
+        # Previously this fell through and failed on the unbound h and w
+        raise ValueError("Unsupported padding mode: %s" % padding)
+    return [ifm_shape[0], h, w, ifm_shape[3]]
+
+
+def compute_padding_shape(ifm_shape, ofm_shape, padding, kernel_shape, strides, dilation=(1, 1)):
+    """Compute the explicit padding of a convolution.
+
+    Parameters
+    ----------
+    ifm_shape : sequence of int
+        Input shape; index 1 is used as height and index 2 as width.
+    ofm_shape : sequence of int
+        Output shape, indexed the same way as `ifm_shape`.
+    padding : str
+        "valid" or "same" (case insensitive).
+    kernel_shape : sequence of int
+        (kernel_height, kernel_width).
+    strides : sequence of int
+        (stride_height, stride_width).
+    dilation : sequence of int
+        (dilation_height, dilation_width).
+
+    Returns
+    -------
+    list of int
+        [pad_top, pad_left, pad_bottom, pad_right]
+
+    Raises
+    ------
+    ValueError
+        If `padding` is neither "valid" nor "same".
+    """
+    assert len(strides) == 2
+    assert len(dilation) == 2
+    assert len(kernel_shape) == 2
+    mode = padding.lower()
+    if mode == "valid":
+        return [0, 0, 0, 0]
+    if mode != "same":
+        # Previously this fell off the end of the function and returned None
+        raise ValueError("Unsupported padding mode: %s" % padding)
+
+    effective_kernel_shape = [
+        dilation[0] * (kernel_shape[0] - 1) + 1,
+        dilation[1] * (kernel_shape[1] - 1) + 1,
+    ]
+    pad_along_height = max(
+        (ofm_shape[1] - 1) * strides[0] + effective_kernel_shape[0] - ifm_shape[1], 0
+    )
+    pad_along_width = max(
+        (ofm_shape[2] - 1) * strides[1] + effective_kernel_shape[1] - ifm_shape[2], 0
+    )
+    # When the total is odd, the extra row/column goes to the bottom/right
+    pad_top = pad_along_height // 2
+    pad_bottom = pad_along_height - pad_top
+    pad_left = pad_along_width // 2
+    pad_right = pad_along_width - pad_left
+    return [pad_top, pad_left, pad_bottom, pad_right]
+
+
 def make_ethosu_conv2d(
     ifm,
     ifm_channels,
@@ -343,3 +415,48 @@ def make_ethosu_conv2d(
         ofm_layout=ofm_layout,
     )
     return conv
+
+
+def make_ethosu_depthwise_conv2d(
+    ifm,
+    channels,
+    kernel_shape,
+    padding,
+    strides,
+    dilation,
+    activation="NONE",
+    ifm_layout="NHWC",
+    ofm_layout="NHWC",
+    weight_dtype="int8",
+):
+    """Build an ethosu_depthwise_conv2d call with generated weights and fixed quantization.
+
+    The weight constant has shape (channels, kernel_height, kernel_width, 1) and
+    the scale_bias constant has shape (channels, 10); the quantization
+    parameters are hard-coded test values. Returns the resulting expression.
+    """
+    weight_shape = (channels, kernel_shape[0], kernel_shape[1], 1)
+    pad_tuple = get_pad_tuple(padding, kernel_shape)
+
+    # Constants: scale_bias data is generated before the weight data
+    scale_bias = relay.const(generate_weights_data((weight_shape[0], 10), "uint8"), dtype="uint8")
+    weight = relay.const(generate_weights_data(weight_shape, weight_dtype), dtype=weight_dtype)
+
+    # Clip bounds are only non-zero for the CLIP activation
+    uses_clip = activation == "CLIP"
+    clip_min = 15 if uses_clip else 0
+    clip_max = 105 if uses_clip else 0
+
+    return ethosu_ops.ethosu_depthwise_conv2d(
+        ifm,
+        weight,
+        scale_bias,
+        lut=relay.const([], dtype="int8"),
+        ifm_scale=0.6,
+        ifm_zero_point=11,
+        weight_zero_point=13,
+        ofm_scale=0.26,
+        ofm_zero_point=15,
+        kernel_shape=kernel_shape,
+        ofm_channels=channels,
+        strides=strides,
+        padding=pad_tuple,
+        dilation=dilation,
+        activation=activation,
+        clip_min=clip_min,
+        clip_max=clip_max,
+        upscale="NONE",
+        ifm_layout=ifm_layout,
+        ofm_layout=ofm_layout,
+    )
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 1944de5..4949d68 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -18,14 +18,12 @@
 import pytest
 
 pytest.importorskip("ethosu.vela")
-import os
 import numpy as np
-import pathlib
+import tflite.Model
 
 import tvm
-import tvm.micro as micro
+import tensorflow as tf
 from tvm import relay
-from tvm.relay.backend.contrib import ethosu
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.op.contrib.ethosu import partition_for_ethosu
 from tests.python.relay.aot.aot_test_utils import generate_ref_data
@@ -168,5 +166,93 @@ def test_ethosu_conv2d(accel_type):
         infra.verify_source(compiled_models, accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
+@pytest.mark.parametrize(
+    "kernel_shape, activation",
+    [((3, 3), "relu"), ((1, 2), None)],
+)
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))])
+def test_tflite_depthwise_conv2d(
+    accel_type,
+    ifm_shape,
+    kernel_shape,
+    padding,
+    strides,
+    dilation,
+    activation,
+):
+    """Build an int8-quantized TFLite depthwise conv2d, partition it for the NPU, compile it
+    together with TFLite-generated reference data and pass the result to infra.verify_source.
+    """
+    dtype = "int8"
+
+    def create_tflite_graph():
+        # Returns the serialized, int8-quantized TFLite model of a single
+        # depthwise conv2d (followed by a ReLU when `activation` is set).
+        class Model(tf.Module):
+            @tf.function
+            def depthwise_conv2d(self, x):
+                # Last weight dimension is 1, i.e. a depth multiplier of 1
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                # The input strides to the TensorFlow API needs to be of shape 1x4
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.nn.depthwise_conv2d(
+                    x, weight, strides=tf_strides, padding=padding, dilations=dilation
+                )
+                if activation:
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.depthwise_conv2d.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            # Random float32 samples handed to the converter as its representative dataset
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    # Import the model into Relay and partition it for the NPU
+    relay_module, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+    assert len(imported_modules) == 2
+    ethosu_module = imported_modules[0]
+
+    # Verify generated C source
+    # The command stream is fetched as a hex string and decoded to bytes for printing
+    get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+    cmms = get_cs(ethosu_module)
+    cmms = bytes.fromhex(cmms)
+
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 911a0e6..b9a588d 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -20,15 +20,33 @@ import pytest
 
 pytest.importorskip("ethosu.vela")
 import numpy as np
+import tensorflow as tf
+import tflite.Model
 
 import tvm
 from tvm import relay
-from tvm.relay.backend.contrib import ethosu
 from tvm.relay.backend.contrib.ethosu import legalize, preprocess
-from tvm.relay.dataflow_pattern import *
-from tvm.relay.op.contrib.ethosu import *
+from tvm.relay import dataflow_pattern
+from tvm.relay.op.contrib import ethosu
+from tvm.relay.build_module import bind_params_by_name
 
 from . import relay_ir_builder
+from . import infra
+
+
+def partition_ethosu_by_table(mod, pattern_table):
+    """Partition `mod` for the "ethosu" target using an explicitly supplied pattern table.
+
+    When only the legalization part of an operator is supported, its pattern is
+    deliberately kept out of the registered pattern table so that the compiler
+    does not try to offload an operator lacking full stack support. This helper
+    lets a test pass such a pattern directly."""
+    pipeline = [
+        relay.transform.InferType(),
+        relay.transform.MergeComposite(pattern_table),
+        relay.transform.AnnotateTarget("ethosu"),
+        relay.transform.MergeCompilerRegions(),
+        relay.transform.InferType(),
+        relay.transform.PartitionGraph(),
+        relay.transform.InferType(),
+        preprocess.preprocess_ext_io(),
+    ]
+    # Apply the passes one after another, in the order listed above
+    for apply_pass in pipeline:
+        mod = apply_pass(mod)
+    return mod
 
 
 def test_split_indices_legalize():
@@ -294,7 +312,7 @@ def test_ethosu_conv2d_legalize():
     ]
     for test_case in test_cases:
         mod, conv_params = test_case[0](*test_case[1])
-        mod = partition_for_ethosu(mod)
+        mod = ethosu.partition_for_ethosu(mod)
         mod = legalize.LegalizeEthosUConv2D()(mod)
         verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params)
 
@@ -327,12 +345,123 @@ def test_ethosu_conv2d_legalize_errors():
 
     for test_case in test_cases:
         mod, conv_params = test_case[0](*test_case[1])
-        mod = partition_for_ethosu(mod)
+        mod = ethosu.partition_for_ethosu(mod)
         with pytest.raises(
             tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW"
         ):
             mod = legalize.LegalizeEthosUConv2D()(mod)
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)])
+@pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
+@pytest.mark.parametrize("activation", ["RELU", None])
+def test_tflite_depthwise_conv_2d_legalize(
+    ifm_shape, kernel_shape, padding, strides, dilation, activation
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def depthwise_conv2d(self, x):
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                # The input strides to the TensorFlow API need to be of shape 1x4
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.nn.depthwise_conv2d(
+                    x, weight, strides=tf_strides, padding=padding, dilations=dilation
+                )
+                if activation:
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.depthwise_conv2d.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 1  # only depth multiplier 1 is supported
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        expected_padding = infra.compute_padding_shape(
+            ifm_shape, expected_ofm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(op.attrs.padding) == list(expected_padding)
+        assert op.attrs.ofm_channels == ofm_channels
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+        if activation == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    depthwise_pattern_table = [
+        (
+            ethosu.QnnDepthwiseConv2DParams.composite_name,
+            ethosu.qnn_depthwise_conv2d_pattern(),
+            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, depthwise_pattern_table)
+
+    mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+        legalize.EthosuDepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"]
+    )
+    verify(mod["tvmgen_default_ethosu_main_0"])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
new file mode 100644
index 0000000..b3ce74c
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import tvm
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from .infra import make_ethosu_depthwise_conv2d, get_convolutional_args
+
+
+@pytest.mark.parametrize(
+    "trial",
+    [
+        [(1, 8, 8, 3), 3, (3, 2), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"],
+        [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC"],
+        [(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC"],
+        [(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"],
+        [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC"],
+        [(1, 8, 2, 8, 16), 18, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHCWB16", "NHWC"],
+        [(1, 7, 9, 40), 40, (3, 2), (1, 2), (2, 1), (1, 2), "CLIP", "NHWC", "NHCWB16"],
+        [(1, 4, 12, 9, 16), 182, (2, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"],
+        [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHWC"],
+        [(1, 7, 9, 41), 41, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHCWB16"],
+        [
+            (1, 13, 12, 19, 16),
+            182,
+            (1, 3),
+            (5, 3),
+            (2, 1),
+            (2, 1),
+            "CLIP",
+            "NHCWB16",
+            "NHCWB16",
+        ],
+    ],
+)
+def test_depthwise_conv2d_single(trial):
+    def _get_func(
+        ifm_shape,
+        channels,
+        kernel_shape,
+        padding,
+        strides,
+        dilation,
+        activation,
+        ifm_layout,
+        ofm_layout,
+    ):
+        ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+        depthwise = make_ethosu_depthwise_conv2d(
+            ifm,
+            channels,
+            kernel_shape,
+            padding,
+            strides,
+            dilation,
+            activation,
+            ifm_layout,
+            ofm_layout,
+        )
+        func = relay.Function(relay.analysis.free_vars(depthwise), depthwise)
+        func = run_opt_pass(func, relay.transform.InferType())
+        return func
+
+    func = _get_func(*trial)
+    mod, _ = lower_to_tir(func)
+    data = []
+
+    def _visit(stmt):
+        if isinstance(stmt, tvm.tir.Call):
+            data.append(get_convolutional_args(stmt, remove_constants=True))
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+    (
+        ifm_shape,
+        channels,
+        kernel_shape,
+        padding,
+        strides,
+        dilation,
+        activation,
+        ifm_layout,
+        ofm_layout,
+    ) = trial
+    dilated_kernel_h = (kernel_shape[0] - 1) * dilation[0] + 1
+    dilated_kernel_w = (kernel_shape[1] - 1) * dilation[1] + 1
+    if ifm_layout == "NHWC":
+        ifm_stride_c = 1
+        ifm_stride_w = ifm_shape[3]
+        ifm_stride_h = ifm_shape[2] * ifm_shape[3]
+        ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1
+        ofm_width = (ifm_shape[2] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1
+    else:
+        ifm_stride_w = 16
+        ifm_stride_c = 16 * ifm_shape[3]
+        ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3]
+        ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1
+        ofm_width = (ifm_shape[3] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1
+
+    if ofm_layout == "NHWC":
+        ofm_stride_c = 1
+        ofm_stride_w = channels if ofm_width > 1 else 1
+        ofm_stride_h = channels * ofm_width if ofm_height > 1 else 1
+    else:
+        ofm_stride_w = 16
+        ofm_stride_c = 16 * ofm_width
+        ofm_stride_h = 16 * ofm_width * ((channels - 1) // 16 + 1)
+
+    answer = [
+        "int8",
+        ifm_shape[1],
+        ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+        channels,
+        ifm_shape[1],
+        0,
+        ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+        0,
+        0,
+        0,
+        0,
+        0.6,
+        11,
+        ifm_layout,
+        ifm_stride_h,
+        ifm_stride_w,
+        ifm_stride_c,
+        "int8",
+        ofm_height,
+        ofm_width,
+        channels,
+        ofm_height,
+        0,
+        ofm_width,
+        0,
+        0,
+        0,
+        0,
+        0.26,
+        15,
+        ofm_layout,
+        ofm_stride_h,
+        ofm_stride_w,
+        ofm_stride_c,
+        kernel_shape[1],
+        kernel_shape[0],
+        strides[1],
+        strides[0],
+        dilation[1],
+        dilation[0],
+        13,
+        padding[0],
+        padding[1],
+        padding[0],
+        padding[1],
+        activation,
+        15 if activation == "CLIP" else 0,
+        105 if activation == "CLIP" else 0,
+        "NONE",
+    ]
+    assert data[0] == answer, data[0]
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index b07f3a5..8240b39 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -497,6 +497,81 @@ def test_translate_ethosu_conv2d():
             assert w_zero_point == ref["w_zero_point"]
 
 
+# fmt: off
+"""An ethosu_depthwise_conv2d TIR test case for the translator"""
+@tvm.script.ir_module
+class SingleEthosuDepthwiseConv2D:
+    @T.prim_func
+    def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_depthwise_conv2d: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_1, [3, 3, 2, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_2, [3, 10], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_3 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        ethosu_depthwise_conv2d_1 = T.match_buffer(ethosu_depthwise_conv2d, [1, 6, 7, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.6), 11, "NHWC", 24, 3, 1, "int8", 6, 7, 3, 6, 0, 7, T.load("int8", ethosu_depthwise_conv2d_1.data, 0), 0, 0, 0, T.float32(0.26), 15, "NHWC", 21, 3, 1, 2, 3, 1, 1, 1, 1, T.load("int8", placeholder_4.data, 0), 18, 13, T.load("uint8", placeholder_5.data, 0), 30, 0, 0, 0, 0, "CLIP", 15, 105, "NONE", dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_translate_ethosu_depthwise_conv2d():
+    def extract_ethosu_depthwise_conv2d_extern_call(mod):
+        # There should only be a single function
+        assert len(mod.functions.items()) == 1
+        primfunc = mod.functions.items()[0][1]
+
+        ethosu_depthwise_conv2d_calls = list()
+
+        def populate_ethosu_depthwise_conv2d_calls(stmt):
+            if (
+                isinstance(stmt, tvm.tir.Call)
+                and stmt.op.name == "tir.call_extern"
+                and stmt.args[0] == "ethosu_depthwise_conv2d"
+            ):
+                ethosu_depthwise_conv2d_calls.append(stmt)
+
+        stmt_functor.post_order_visit(primfunc.body, populate_ethosu_depthwise_conv2d_calls)
+        return ethosu_depthwise_conv2d_calls[0]
+
+    depthwise_conv2d_call = extract_ethosu_depthwise_conv2d_extern_call(SingleEthosuDepthwiseConv2D)
+    npu_op, w_zero_point = tir_to_cs_translator.translate_ethosu_depthwise_conv2d(
+        depthwise_conv2d_call
+    )
+
+    assert npu_op.ifm.data_type == vapi.NpuDataType.INT8
+    assert npu_op.ifm.shape == vapi.NpuShape3D(8, 8, 3)
+    assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).height_0
+    assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).height_1
+    assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]).width_0
+    assert npu_op.ifm.quantization == pytest.approx(vapi.NpuQuantization(0.6, 11))
+    assert npu_op.ifm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ifm.strides == vapi.NpuShape3D(24, 3, 1)
+    # Compare OFM
+    assert npu_op.ofm.data_type == vapi.NpuDataType.INT8
+    assert npu_op.ofm.shape == vapi.NpuShape3D(6, 7, 3)
+    assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(6, 0, 8, [0, 0, 0, 0]).height_0
+    assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(6, 0, 7, [0, 0, 0, 0]).height_1
+    assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(6, 0, 7, [0, 0, 0, 0]).width_0
+    assert npu_op.ofm.quantization == pytest.approx(vapi.NpuQuantization(0.26, 15))
+    assert npu_op.ofm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ofm.strides == vapi.NpuShape3D(21, 3, 1)
+    # Compare kernel and padding
+    assert (
+        npu_op.kernel.__dict__
+        == vapi.NpuKernel(w=2, h=3, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1).__dict__
+    )
+    assert npu_op.padding == vapi.NpuPadding(top=0, left=0, bottom=0, right=0)
+    # Compare activation
+    assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
+    assert npu_op.activation.min == 0
+    assert npu_op.activation.max == pytest.approx(23.4)
+    # Compare ifm upscaling
+    assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE
+    # Compare weight quantization parameters
+    assert w_zero_point == 13
+
+
 def test_translate_ethosu_copy():
     def extract_ethosu_copy_extern_calls(mod):
         """This function will obtain all ethosu_conv2d
diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py
new file mode 100644
index 0000000..47fddad
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_type_inference.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+from .infra import make_ethosu_conv2d
+from .infra import make_ethosu_depthwise_conv2d
+
+
+@pytest.mark.parametrize(
+    ["ifm_shape", "ifm_layout"], [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")]
+)
+@pytest.mark.parametrize(
+    "ofm_shape,ofm_layout", [((1, 54, 38, 122), "NHWC"), ((1, 54, 8, 38, 16), "NHCWB16")]
+)
+def test_ethosu_conv2d_type_inference(
+    ifm_shape,
+    ifm_layout,
+    ofm_shape,
+    ofm_layout,
+):
+    ifm_channels = 55
+    ofm_channels = 122
+    kernel_shape = (3, 2)
+    padding = (0, 1, 2, 3)
+    strides = (1, 2)
+    dilation = (2, 1)
+    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    conv2d = make_ethosu_conv2d(
+        ifm,
+        ifm_channels,
+        ofm_channels,
+        kernel_shape,
+        padding,
+        strides,
+        dilation,
+        ifm_layout=ifm_layout,
+        ofm_layout=ofm_layout,
+    )
+    f = relay.Function([ifm], conv2d)
+    f = run_opt_pass(f, relay.transform.InferType())
+    assert tuple(f.body.checked_type.shape) == ofm_shape
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ifm_layout", [((1, 46, 71, 55), "NHWC"), ((1, 46, 4, 71, 16), "NHCWB16")]
+)
+@pytest.mark.parametrize(
+    "ofm_shape, ofm_layout", [((1, 44, 37, 55), "NHWC"), ((1, 44, 4, 37, 16), "NHCWB16")]
+)
+def test_ethosu_depthwise_conv2d_type_inference(
+    ifm_shape,
+    ifm_layout,
+    ofm_shape,
+    ofm_layout,
+):
+    channels = 55
+    kernel_shape = (3, 2)
+    padding = (0, 1, 2, 3)
+    strides = (1, 2)
+    dilation = (2, 1)
+    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    depthwise_conv2d = make_ethosu_depthwise_conv2d(
+        ifm,
+        channels,
+        kernel_shape,
+        padding,
+        strides,
+        dilation,
+        ifm_layout=ifm_layout,
+        ofm_layout=ofm_layout,
+    )
+    f = relay.Function([ifm], depthwise_conv2d)
+    f = run_opt_pass(f, relay.transform.InferType())
+    assert tuple(f.body.checked_type.shape) == ofm_shape
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 2e4687f..6e57796 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -471,7 +471,11 @@ def test_compile_tflite_module_with_external_codegen_ethosu(
                 for name in mlf_package.getnames()
                 if re.match(r"\./codegen/host/src/\D+\d+\.c", name)
             ]
-            assert len(c_source_files) == 17
+            # The number of c_source_files depends on the number of fused subgraphs that
+            # get offloaded to the NPU, e.g. conv2d->depthwise_conv2d->conv2d gets offloaded
+            # as a single subgraph if both of these operators are supported by the NPU.
+            # Currently there are two source files for CPU execution and two offload graphs.
+            assert len(c_source_files) == 4
 
 
 @mock.patch("tvm.relay.build")