You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/11/26 09:31:25 UTC

[tvm] branch main updated: [5/10] Code generation for Depthwise Convolution via CMSIS-NN (#9409)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 52edc9a  [5/10] Code generation for Depthwise Convolution via CMSIS-NN (#9409)
52edc9a is described below

commit 52edc9a4fe230cd533e6ee3511924af21adecb07
Author: Ashutosh Parkhi <86...@users.noreply.github.com>
AuthorDate: Fri Nov 26 09:31:02 2021 +0000

    [5/10] Code generation for Depthwise Convolution via CMSIS-NN (#9409)
    
    This PR adds support for depthwise convolution via CMSIS-NN.
---
 python/tvm/relay/op/contrib/cmsisnn.py             |  11 +
 .../backend/contrib/cmsisnn/generate_constants.cc  |  21 +-
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  |  43 ++--
 .../backend/contrib/cmsisnn/tir_to_runtime.cc      | 222 +++++++++++++--------
 tests/python/contrib/test_cmsisnn/test_conv2d.py   | 147 ++++++++++++--
 5 files changed, 328 insertions(+), 116 deletions(-)

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 34efb1d..5e0ad27 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -123,6 +123,16 @@ def pattern_table():
         kernel_zp = conv2d.args[3].data.numpy()
         kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp
 
+        # check if depthwise Conv2D
+        kernel_layout = conv2d.attrs.kernel_layout
+        pos_o = kernel_layout.index("O")
+        groups = conv2d.attrs.groups
+        is_depthwise = False
+        if groups == int(conv2d_input.checked_type.shape[3]) and groups == int(
+            conv2d_weight.checked_type.shape[pos_o]
+        ):
+            is_depthwise = True
+
         return (
             conv2d.attrs.out_dtype == "int32"
             and conv2d.attrs.padding[2] == 0
@@ -132,6 +142,7 @@ def pattern_table():
             and pattern.checked_type.dtype == "int8"
             and bias_dtype == "int32"
             and all([zp == 0 for zp in kernel_zp])
+            and (not is_depthwise or bias_add is not None)
         )
 
     def binary_op_pattern(op):
diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc
index 0231e8b..2e12697 100644
--- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc
+++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc
@@ -105,11 +105,20 @@ class GenerateConstantsMutator : public MixedModeMutator {
       conv2d_call = requantize_input;
     }
 
-    // Transpose weights: HWIO -> OHWI
     auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
-    tvm::Attrs new_conv2d_attrs;
-    Expr transposed_kernel =
-        ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs);
+    tvm::Attrs new_conv2d_attrs = conv2d_call->attrs;
+    Expr conv2d_kernel = conv2d_call->args[1];
+
+    Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
+    Array<PrimExpr> kernel_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
+    std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
+    int kernel_pos_o = kernel_layout.find("O");
+    int groups = conv2d_attrs->groups;
+    if (groups != qnn::get_const_int(input_shape[3]) ||
+        groups != qnn::get_const_int(kernel_shape[kernel_pos_o])) {
+      // Transpose weights: HWIO -> OHWI for Conv2D
+      conv2d_kernel = ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs);
+    }
 
     // Obtain input and output scales from Relay's Requantization
     int64_t out_channels = conv2d_attrs->channels.as<IntImmNode>()->value;
@@ -153,11 +162,11 @@ class GenerateConstantsMutator : public MixedModeMutator {
       req_inp_scale = Constant(req_inp_scale_nda);
     }
 
-    // Replace existing weights (HWIO) with the transposed ones (OHWI)
+    // Replace existing weights (HWIO) with the transposed ones (OHWI) for Conv2D
     // Substitute Conv2D weight_zero_point with the CMSIS-NN multiplier
     // Substitute Requantize input_zero_point with CMSIS-NN shift
     // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc
-    Array<Expr> conv2d_args = {conv2d_call->args[0], transposed_kernel,    conv2d_call->args[2],
+    Array<Expr> conv2d_args = {conv2d_call->args[0], conv2d_kernel,        conv2d_call->args[2],
                                multiplier_const,     conv2d_call->args[4], weight_scale};
     Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {});
     if (bias_add_call) {
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 1b639dd..6683527 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -146,6 +146,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
     int32_t padding_h = qnn::get_const_int(conv2d_attrs->padding[0]);
     int32_t dilation_w = qnn::get_const_int(conv2d_attrs->dilation[1]);
     int32_t dilation_h = qnn::get_const_int(conv2d_attrs->dilation[0]);
+    int32_t out_channels = qnn::get_const_int(conv2d_attrs->channels);
+    int32_t groups = conv2d_attrs->groups;
+    std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
     int32_t clip_min, clip_max;
     if (clip_call) {
       const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
@@ -156,14 +159,6 @@ class RelayToTIRVisitor : public MixedModeMutator {
       clip_max = 127;
     }
 
-    tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_convolve_wrapper_s8"), input, filter,
-                                          multiplier};
-    if (bias_add_call) {
-      call_ext_args.push_back(bias);
-    }
-    call_ext_args.push_back(shift);
-    call_ext_args.push_back(output);
-
     tvm::Array<PrimExpr> scalar_args = {ToArg(input_offset), ToArg(output_offset), ToArg(stride_w),
                                         ToArg(stride_h),     ToArg(padding_w),     ToArg(padding_h),
                                         ToArg(dilation_w),   ToArg(dilation_h),    ToArg(clip_min),
@@ -173,18 +168,42 @@ class RelayToTIRVisitor : public MixedModeMutator {
     Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
     Array<PrimExpr> input_dims = CMSISNNDimensions(input_shape);
 
-    // cmsis_nn_dims *filter_dims (OHWI)
+    // cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise)
     Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
     Array<PrimExpr> filter_dims = CMSISNNDimensions(filter_shape);
 
-    // cmsis_nn_dims *bias_dims (1,1,1,output_channels)
-    Array<PrimExpr> bias_shape{1, 1, 1, filter_shape[0]};
+    // cmsis_nn_dims *bias_dims
+    Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
     Array<PrimExpr> bias_dims = CMSISNNDimensions(bias_shape);
 
-    // cmsis_nn_dims *output_dims (NHWC)
+    // cmsis_nn_dims *output_dims (same order as input_dims)
     Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;
     Array<PrimExpr> output_dims = CMSISNNDimensions(output_shape);
 
+    int32_t depth_multiplier = -1;
+    int kernel_pos_o = kernel_layout.find("O");
+    if (groups == qnn::get_const_int(input_shape[3]) &&
+        groups == qnn::get_const_int(filter_shape[kernel_pos_o])) {
+      int kernel_pos_i = kernel_layout.find("I");
+      depth_multiplier = qnn::get_const_int(filter_shape[kernel_pos_i]);
+    }
+    scalar_args.push_back(ToArg(depth_multiplier));
+
+    // original filter_layout for depthwise is HWOI
+    std::string cmsisnn_api = "arm_convolve_wrapper_s8";
+    if (depth_multiplier != -1) {
+      cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
+      Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels};
+      filter_dims = CMSISNNDimensions(depthwise_filter_shape);
+    }
+
+    tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
+    if (bias_add_call) {
+      call_ext_args.push_back(bias);
+    }
+    call_ext_args.push_back(shift);
+    call_ext_args.push_back(output);
+
     // https://github.com/ARM-software/CMSIS_5/blob/d788fd583984388553391de18afd8b4d2a146868/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c#L367
     std::string context_buffer_name = "NULL";
     size_t context_buffer_size =
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index b243af6..85923b3 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -39,7 +39,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     decl_stream << "#include <stdio.h>\n";
     decl_stream << "#include <stdlib.h>\n";
     decl_stream << "#include <dlpack/dlpack.h>\n";
-    decl_stream << "#include <tvm/runtime/crt/module.h>\n";
     decl_stream << "#include <arm_nnfunctions.h>\n";
     decl_stream << "#include <arm_nn_types.h>\n";
     CodeGenCHost::Init(output_ssa, emit_asserts, target_str);
@@ -53,6 +52,35 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
   void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
 
  private:
+  /*!  * \brief CMSIS-NN context buffer info */
+  struct CMSISNNContextBuffer {
+    std::string name;
+    int size;
+  };
+
+  /*!  * \brief CMSIS-NN buffer dimensions */
+  struct CMSISNNDims {
+    int n;
+    int h;
+    int w;
+    int c;
+  };
+
+  /*!  * \brief CMSIS-NN Conv2D and Depthwise parameters */
+  struct Conv2DParams {
+    int input_offset;
+    int output_offset;
+    int stride_w;
+    int stride_h;
+    int padding_w;
+    int padding_h;
+    int dilation_w;
+    int dilation_h;
+    int clip_min;
+    int clip_max;
+    int depth_multiplier;
+  };
+
   /*!  * \brief Emit the CMSIS-NN context buffer */
   void VisitStmt_(const AllocateNode* op) {
     context_buffer_name_ = op->buffer_var->name_hint;
@@ -70,38 +98,46 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     if (cmsis_func_name == "arm_softmax_s8" || cmsis_func_name == "arm_elementwise_mul_s8" ||
         cmsis_func_name == "arm_elementwise_add_s8") {
       CodeGenC::VisitExpr_(op, os);
-    } else if (cmsis_func_name == "arm_convolve_wrapper_s8") {
+    } else if (cmsis_func_name == "arm_convolve_wrapper_s8" ||
+               cmsis_func_name == "arm_depthwise_conv_wrapper_s8") {
       EmitConv2D(op);
     }
     return;
   }
 
   /*!  * \brief Emits cmsis_nn_context struct */
-  std::string EmitCMSISNNContext(std::ostream& os, std::string buf_name, int buf_size) {
+  std::string EmitCMSISNNContext(std::ostream& os, CMSISNNContextBuffer context_buffer) {
     std::string struct_name = "context";
     PrintIndent();
-    os << "cmsis_nn_context " << struct_name << "= {" << buf_name << "," << buf_size << "};\n";
+    os << "cmsis_nn_context " << struct_name << "= {" << context_buffer.name << ","
+       << context_buffer.size << "};\n";
     return struct_name;
   }
 
   /*!  * \brief Emits cmsis_nn_conv_params struct */
-  std::string EmitCMSISNNConvParams(std::ostream& os, int32_t input_offset, int32_t output_offset,
-                                    int32_t stride_w, int32_t stride_h, int32_t padding_w,
-                                    int32_t padding_h, int32_t dilation_w, int32_t dilation_h,
-                                    int32_t clip_min, int32_t clip_max) {
-    std::string struct_name = "conv_params";
+  std::string EmitCMSISNNConvParams(std::ostream& os, Conv2DParams params) {
+    std::string struct_name = "cmsis_nn_conv_params";
+    std::string instance_name = "conv_params";
+    if (params.depth_multiplier != -1) {
+      struct_name = "cmsis_nn_dw_conv_params";
+    }
     PrintIndent();
-    os << "cmsis_nn_tile stride = {" << stride_w << "," << stride_h << "};\n";
+    os << "cmsis_nn_tile stride = {" << params.stride_w << "," << params.stride_h << "};\n";
     PrintIndent();
-    os << "cmsis_nn_tile padding = {" << padding_w << "," << padding_h << "};\n";
+    os << "cmsis_nn_tile padding = {" << params.padding_w << "," << params.padding_h << "};\n";
     PrintIndent();
-    os << "cmsis_nn_tile dilation = {" << dilation_w << "," << dilation_h << "};\n";
+    os << "cmsis_nn_tile dilation = {" << params.dilation_w << "," << params.dilation_h << "};\n";
     PrintIndent();
-    os << "cmsis_nn_activation activation = {" << clip_min << "," << clip_max << "};\n";
+    os << "cmsis_nn_activation activation = {" << params.clip_min << "," << params.clip_max
+       << "};\n";
     PrintIndent();
-    os << "cmsis_nn_conv_params " << struct_name << " = {" << input_offset << ", " << output_offset
-       << ", stride, padding, dilation, activation};\n";
-    return struct_name;
+    os << struct_name << " " << instance_name << " = {" << params.input_offset << ", "
+       << params.output_offset;
+    if (params.depth_multiplier != -1) {
+      os << ", " << params.depth_multiplier;
+    }
+    os << ", stride, padding, dilation, activation};\n";
+    return instance_name;
   }
 
   /*!  * \brief Emits cmsis_nn_per_channel_quant_params struct */
@@ -115,83 +151,109 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
   }
 
   /*!  * \brief Emits cmsis_nn_dims struct */
-  std::string EmitCMSISNNDims(std::ostream& os, std::string tensor_type, int32_t n, int32_t h,
-                              int32_t w, int32_t c) {
+  std::string EmitCMSISNNDims(std::ostream& os, std::string tensor_type, CMSISNNDims dims) {
     std::string struct_name = tensor_type + "_dims";
     PrintIndent();
-    os << "cmsis_nn_dims " << struct_name << " = {" << n << "," << h << "," << w << "," << c
-       << "};\n";
+    os << "cmsis_nn_dims " << struct_name << " = {" << dims.n << "," << dims.h << "," << dims.w
+       << "," << dims.c << "};\n";
     return struct_name;
   }
 
+  /*!  * \brief Deduces variable name from call_extern argument resting at id */
+  std::string VarNameFromArg(const CallNode* op, int id) {
+    return op->args[id].as<VarNode>()->name_hint.c_str();
+  }
+
+  /*!  * \brief Deduces value from call_extern argument resting at id */
+  int ValueFromArg(const CallNode* op, int id) { return op->args[id].as<IntImmNode>()->value; }
+
+  /*!  * \brief extracts CMSIS-NN context buffer information */
+  CMSISNNContextBuffer extract_context_buffer_info(const CallNode* op, int base_pos) {
+    CMSISNNContextBuffer context_buffer;
+    context_buffer.name = op->args[base_pos].as<StringImmNode>()->value;
+    context_buffer.size = ValueFromArg(op, base_pos + 1);
+    return context_buffer;
+  }
+
+  /*!  * \brief extracts CMSIS-NN conv2d parameters from call_extern */
+  Conv2DParams extract_conv2d_params(const CallNode* op, int base_pos) {
+    Conv2DParams conv2d_params;
+    conv2d_params.input_offset = ValueFromArg(op, base_pos);
+    conv2d_params.output_offset = ValueFromArg(op, ++base_pos);
+    conv2d_params.stride_w = ValueFromArg(op, ++base_pos);
+    conv2d_params.stride_h = ValueFromArg(op, ++base_pos);
+    conv2d_params.padding_w = ValueFromArg(op, ++base_pos);
+    conv2d_params.padding_h = ValueFromArg(op, ++base_pos);
+    conv2d_params.dilation_w = ValueFromArg(op, ++base_pos);
+    conv2d_params.dilation_h = ValueFromArg(op, ++base_pos);
+    conv2d_params.clip_min = ValueFromArg(op, ++base_pos);
+    conv2d_params.clip_max = ValueFromArg(op, ++base_pos);
+    conv2d_params.depth_multiplier = ValueFromArg(op, ++base_pos);
+    return conv2d_params;
+  }
+
+  /*!  * \brief extracts CMSIS-NN buffer dimensions from call_extern */
+  CMSISNNDims extract_buffer_dims(const CallNode* op, int base_pos) {
+    CMSISNNDims dims;
+    dims.n = ValueFromArg(op, base_pos);
+    dims.h = ValueFromArg(op, ++base_pos);
+    dims.w = ValueFromArg(op, ++base_pos);
+    dims.c = ValueFromArg(op, ++base_pos);
+    return dims;
+  }
+
   /*!  * \brief Emits CMSIS-NN APIs for every call_extern */
   void EmitConv2D(const CallNode* op) {
-    static const int max_num_args = 35;
-    std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
+    // Position of various arguments relative to buffers in the call_extern
+    enum CallExternArgPos {
+      CONTEXT_BUFFER_POS = 1,
+      CONV2D_PARAMS_POS = 3,
+      INPUT_DIM_POS = 14,
+      FILTER_DIM_POS = 18,
+      BIAS_DIM_POS = 22,
+      OUTPUT_DIM_POS = 26,
+      MAX_NUM_ARGS = 36
+    };
 
-    bool bias_enabled = false;
-    if (op->args.size() == max_num_args) {
-      bias_enabled = true;
-    }
+    std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
 
-    auto get_var_name = [](const CallNode* op, int id) {
-      return op->args[id].as<VarNode>()->name_hint.c_str();
-    };
-    auto get_arg_value = [](const CallNode* op, int id) {
-      return op->args[id].as<IntImmNode>()->value;
-    };
+    // extract buffer names from call_extern
     int arg_id = 0;
-    std::string input_data = get_var_name(op, ++arg_id);
-    std::string filter_data = get_var_name(op, ++arg_id);
-    std::string multiplier = get_var_name(op, ++arg_id);
-    std::string bias_data("0x0");
-    if (bias_enabled) {
-      bias_data = get_var_name(op, ++arg_id);
+    std::string input_data = VarNameFromArg(op, ++arg_id);
+    std::string filter_data = VarNameFromArg(op, ++arg_id);
+    std::string multiplier = VarNameFromArg(op, ++arg_id);
+    std::string bias_data("NULL");
+    if (op->args.size() == CallExternArgPos::MAX_NUM_ARGS) {
+      bias_data = VarNameFromArg(op, ++arg_id);
     }
-    std::string shift = get_var_name(op, ++arg_id);
-    std::string output_data = get_var_name(op, ++arg_id);
-
-    std::string context_buffer_name = op->args[++arg_id].as<StringImmNode>()->value;
-    int context_buffer_size = get_arg_value(op, ++arg_id);
-    int input_offset = get_arg_value(op, ++arg_id);
-    int output_offset = get_arg_value(op, ++arg_id);
-    int stride_w = get_arg_value(op, ++arg_id);
-    int stride_h = get_arg_value(op, ++arg_id);
-    int padding_w = get_arg_value(op, ++arg_id);
-    int padding_h = get_arg_value(op, ++arg_id);
-    int dilation_w = get_arg_value(op, ++arg_id);
-    int dilation_h = get_arg_value(op, ++arg_id);
-    int clip_min = get_arg_value(op, ++arg_id);
-    int clip_max = get_arg_value(op, ++arg_id);
-    int input_n = get_arg_value(op, ++arg_id);
-    int input_h = get_arg_value(op, ++arg_id);
-    int input_w = get_arg_value(op, ++arg_id);
-    int input_c = get_arg_value(op, ++arg_id);
-    int filter_n = get_arg_value(op, ++arg_id);
-    int filter_h = get_arg_value(op, ++arg_id);
-    int filter_w = get_arg_value(op, ++arg_id);
-    int filter_c = get_arg_value(op, ++arg_id);
-    int bias_n = get_arg_value(op, ++arg_id);
-    int bias_h = get_arg_value(op, ++arg_id);
-    int bias_w = get_arg_value(op, ++arg_id);
-    int bias_c = get_arg_value(op, ++arg_id);
-    int output_n = get_arg_value(op, ++arg_id);
-    int output_h = get_arg_value(op, ++arg_id);
-    int output_w = get_arg_value(op, ++arg_id);
-    int output_c = get_arg_value(op, ++arg_id);
-
-    std::string context = EmitCMSISNNContext(stream, context_buffer_name, context_buffer_size);
-    std::string conv_params =
-        EmitCMSISNNConvParams(stream, input_offset, output_offset, stride_w, stride_h, padding_w,
-                              padding_h, dilation_w, dilation_h, clip_min, clip_max);
+    std::string shift = VarNameFromArg(op, ++arg_id);
+    std::string output_data = VarNameFromArg(op, ++arg_id);
+
+    // extract CMSIS-NN API parameters
+    int context_buffer_pos = arg_id + CallExternArgPos::CONTEXT_BUFFER_POS;
+    int conv2d_params_pos = arg_id + CallExternArgPos::CONV2D_PARAMS_POS;
+    int input_dim_pos = arg_id + CallExternArgPos::INPUT_DIM_POS;
+    int filter_dim_pos = arg_id + CallExternArgPos::FILTER_DIM_POS;
+    int bias_dim_pos = arg_id + CallExternArgPos::BIAS_DIM_POS;
+    int output_dim_pos = arg_id + CallExternArgPos::OUTPUT_DIM_POS;
+
+    CMSISNNContextBuffer context_buffer = extract_context_buffer_info(op, context_buffer_pos);
+    Conv2DParams conv2d_params = extract_conv2d_params(op, conv2d_params_pos);
+    CMSISNNDims input_dims = extract_buffer_dims(op, input_dim_pos);
+    CMSISNNDims filter_dims = extract_buffer_dims(op, filter_dim_pos);
+    CMSISNNDims bias_dims = extract_buffer_dims(op, bias_dim_pos);
+    CMSISNNDims output_dims = extract_buffer_dims(op, output_dim_pos);
+
+    // Emit CMSIS-NN API arguments
+    std::string context = EmitCMSISNNContext(stream, context_buffer);
+    std::string conv_params = EmitCMSISNNConvParams(stream, conv2d_params);
     std::string quant_params = EmitCMSISNNPerChannelQuantParams(stream, multiplier, shift);
-    std::string input_dim = EmitCMSISNNDims(stream, "input", input_n, input_h, input_w, input_c);
-    std::string filter_dim =
-        EmitCMSISNNDims(stream, "filter", filter_n, filter_h, filter_w, filter_c);
-    std::string bias_dim = EmitCMSISNNDims(stream, "bias", bias_n, bias_h, bias_w, bias_c);
-    std::string output_dim =
-        EmitCMSISNNDims(stream, "output", output_n, output_h, output_w, output_c);
+    std::string input_dim = EmitCMSISNNDims(stream, "input", input_dims);
+    std::string filter_dim = EmitCMSISNNDims(stream, "filter", filter_dims);
+    std::string bias_dim = EmitCMSISNNDims(stream, "bias", bias_dims);
+    std::string output_dim = EmitCMSISNNDims(stream, "output", output_dims);
 
+    // Emit CMSIS-NN API
     PrintIndent();
     stream << "arm_status status = ";
     stream << cmsis_func_name << "(";
diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py
index 243197e..8d62763 100644
--- a/tests/python/contrib/test_cmsisnn/test_conv2d.py
+++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py
@@ -67,31 +67,30 @@ def make_model(
     w_index = weight_format.index("W")
     kernel_h = kernel_shape[h_index]
     kernel_w = kernel_shape[w_index]
-    a = relay.var("input", shape=shape, dtype=dtype)
+    invar = relay.var("input", shape=shape, dtype=dtype)
     p = (0, 0, 0, 0)
     if padding == "SAME":
         p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides)
-        a = relay.nn.pad(
-            a,
+        invar = relay.nn.pad(
+            invar,
             pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)],
             pad_value=input_zero_point,
             pad_mode="constant",
         )
         shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3])
 
-    weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels)
     rng = np.random.default_rng(12321)
     w = tvm.nd.array(
         rng.integers(
             np.iinfo(kernel_dtype).min,
             high=np.iinfo(kernel_dtype).max,
-            size=weight_shape,
+            size=kernel_shape,
             dtype=kernel_dtype,
         )
     )
     weight_const = relay.const(w, kernel_dtype)
     conv = relay.qnn.op.conv2d(
-        a,
+        invar,
         weight_const,
         input_zero_point=relay.const(input_zero_point, "int32"),
         kernel_zero_point=relay.const(kernel_zero_point, "int32"),
@@ -128,14 +127,14 @@ def make_model(
 @pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
 @pytest.mark.parametrize("kernel_size", [(3, 3)])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
-@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1)), ((1, 1), (1, 1))])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
+@pytest.mark.parametrize("relu_type", ["RELU"])
 @pytest.mark.parametrize("enable_bias", [True, False])
-@pytest.mark.parametrize("relu_type", ["NONE", "RELU"])
 @pytest.mark.parametrize(
     "input_zero_point, input_scale, kernel_scale, out_channels",
     [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)],
 )
-def test_op_int8(
+def test_conv2d_int8(
     ifm_shape,
     kernel_size,
     padding,
@@ -152,22 +151,134 @@ def test_op_int8(
     use_unpacked_api = True
     test_runner = AOT_CORSTONE300_RUNNER
 
-    kernel_zero_point = 0
+    dtype = "int8"
     groups = 1
     weight_format = "HWIO"
     kernel_h = kernel_size[0]
     kernel_w = kernel_size[1]
+    kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels)
+    kernel_zero_point = 0
+    in_min, in_max = get_range_for_dtype_str(dtype)
+
+    output_scale, output_zero_point = get_conv2d_qnn_params(
+        kernel_shape,
+        input_scale,
+        input_zero_point,
+        kernel_scale,
+        kernel_zero_point,
+        dtype,
+        dtype,
+        dtype,
+    )
+
+    model, params = make_model(
+        ifm_shape,
+        kernel_shape,
+        input_zero_point,
+        input_scale,
+        kernel_zero_point,
+        kernel_scale,
+        output_zero_point,
+        output_scale,
+        padding,
+        strides,
+        dilation,
+        groups,
+        dtype,
+        dtype,
+        out_channels,
+        weight_format,
+        enable_bias,
+        relu_type,
+    )
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain function for cmsis-nn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    rng = np.random.default_rng(12345)
+    inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)}
+    output_list = generate_ref_data(orig_mod["main"], inputs, params)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            params=params,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
+@pytest.mark.parametrize("kernel_size", [(3, 3)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
+@pytest.mark.parametrize("relu_type", ["RELU"])
+@pytest.mark.parametrize(
+    "depth_multiplier, enable_bias",
+    [(1, True), (3, True)],
+)
+@pytest.mark.parametrize(
+    "input_zero_point, input_scale, kernel_scale, out_channels",
+    [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)],
+)
+def test_depthwise_int8(
+    ifm_shape,
+    kernel_size,
+    padding,
+    strides,
+    dilation,
+    enable_bias,
+    relu_type,
+    input_zero_point,
+    input_scale,
+    kernel_scale,
+    out_channels,
+    depth_multiplier,
+):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
     dtype = "int8"
+    groups = 1
+    weight_format = "HWIO"
+    kernel_h = kernel_size[0]
+    kernel_w = kernel_size[1]
+    kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels)
+    kernel_zero_point = 0
     in_min, in_max = get_range_for_dtype_str(dtype)
 
-    weight_shape = None
-    if weight_format == "HWIO":
-        weight_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels)
-    else:
-        weight_shape = (kernel_h, kernel_w, ifm_shape[3], out_channels)
+    groups = ifm_shape[3]
+    weight_format = "HWOI"
+    kernel_shape = (kernel_h, kernel_w, ifm_shape[3], depth_multiplier)
+    out_channels = ifm_shape[3] * depth_multiplier
+    ks_len = len(kernel_scale)
+    kernel_scale = [kernel_scale[i % ks_len] for i in range(out_channels)]
 
     output_scale, output_zero_point = get_conv2d_qnn_params(
-        weight_shape,
+        kernel_shape,
         input_scale,
         input_zero_point,
         kernel_scale,
@@ -175,12 +286,12 @@ def test_op_int8(
         dtype,
         dtype,
         dtype,
-        False,
+        True,
     )
 
     model, params = make_model(
         ifm_shape,
-        weight_shape,
+        kernel_shape,
         input_zero_point,
         input_scale,
         kernel_zero_point,