You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/06/09 09:19:37 UTC

[tvm] branch main updated: [CMSIS-NN] Removed redundant arguments to CMSIS-NN wrapper function (#11431)

This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 236eea0f49 [CMSIS-NN] Removed redundant arguments to CMSIS-NN wrapper function (#11431)
236eea0f49 is described below

commit 236eea0f49b4ca9a30e99d54f2ceb7ee3ef836f7
Author: Ashutosh Parkhi <86...@users.noreply.github.com>
AuthorDate: Thu Jun 9 10:19:31 2022 +0100

    [CMSIS-NN] Removed redundant arguments to CMSIS-NN wrapper function (#11431)
    
    Removed input_scale and filter_scale from the CMSIS-NN
    wrapper function. These arguments are not needed by the
    CMSIS-NN API, which is called from the generated C
    wrapper function for Conv2D.
---
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 29 ++++++-
 tests/python/contrib/test_cmsisnn/test_conv2d.py  | 96 ++++++++++++++++++++++-
 2 files changed, 121 insertions(+), 4 deletions(-)

diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index dc5537ee90..524735caa9 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -141,18 +141,24 @@ class RelayToTIRVisitor : public MixedModeMutator {
     // %3 = qnn.requantize(%2, %input_scale_const_4, %cmsisnn_shift_const_5,
     //                     %output_scale_scalar, %output_zero_point_scalar)
     // clip(%3, a_min=%min_scalar, a_max=%max_scalar)
+    // Position of scales in the global function for Conv2D
+    const int filter_scale_pos = 3;
+    const int input_scale_pos = bias_add_call ? 5 : 4;
     BufferCreator buffer_creator;
     tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
     tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(8));
     tir::Var multiplier = buffer_creator.CreateBufferVar("multiplier", DataType::Handle(32));
-    tir::Var filter_scale = buffer_creator.CreateBufferVar("filter_scale", DataType::Handle(32));
     if (bias_add_call) {
       buffer_creator.CreateBufferVar("bias", DataType::Handle(32));
     }
-    tir::Var input_scale = buffer_creator.CreateBufferVar("input_scale", DataType::Handle(32));
     tir::Var shift = buffer_creator.CreateBufferVar("shift", DataType::Handle(32));
     tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
 
+    // The global partitioned Relay function for Conv2D takes input_scale and filter_scale as
+    // parameters at these positions; CMSIS-NN does not need them, so drop them from the call
+    skip_call_args_.insert(filter_scale_pos);
+    skip_call_args_.insert(input_scale_pos);
+
     // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern
     // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50
 
@@ -742,11 +748,25 @@ class RelayToTIRVisitor : public MixedModeMutator {
                                                GetRef<Function>(func));
         }
 
+        // Drop the redundant arguments and their types from the call to the global function
         Array<Expr> args;
+        Array<Type> arg_types;
+        auto* func_type = new_global_var->checked_type_.as<FuncTypeNode>();
+        int arg_id = -1;
         for (const auto& arg : call->args) {
+          ++arg_id;
+          if (std::find(skip_call_args_.begin(), skip_call_args_.end(), arg_id) !=
+              skip_call_args_.end()) {
+            continue;
+          }
           args.push_back(VisitExpr(arg));
+          arg_types.push_back(func_type->arg_types[arg_id]);
         }
-
+        if (arg_types.size() != func_type->arg_types.size()) {
+          new_global_var->checked_type_ =
+              FuncType(arg_types, func_type->ret_type, {}, func_type->type_constraints);
+        }
+        skip_call_args_.clear();
         return Call(new_global_var, args, call->attrs, call->type_args, call->span);
       }
     }
@@ -757,7 +777,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
   static constexpr int32_t kScaledDiffIntegerBits = 5;
   static constexpr int32_t kInputBits = 5;
   static constexpr double kBeta = 1.0;
+  /*! \brief Unique id for context buffer needed by CMSIS-NN layers. */
   int32_t context_buffer_id_;
+  /*! \brief Argument positions to skip in the call to the global partitioned function. */
+  std::unordered_set<int32_t> skip_call_args_;
   IRModule ir_module_;
   Target target_;
 };
diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py
index 439a3ec39c..90261e540a 100644
--- a/tests/python/contrib/test_cmsisnn/test_conv2d.py
+++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py
@@ -23,7 +23,7 @@ import tvm
 from tvm import relay
 from tvm.relay.op.contrib import cmsisnn
 
-from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_and_run
+from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_models, compile_and_run
 
 from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER
 from utils import (
@@ -119,6 +119,100 @@ def make_model(
     return last_op, params
 
 
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("enable_bias", [True, False])
+@pytest.mark.parametrize(
+    "input_zero_point, input_scale, kernel_scale, out_channels",
+    [(10, 0.0128, [0.11, 0.22], 2)],
+)
+def test_conv2d_number_primfunc_args(
+    padding,
+    enable_bias,
+    input_zero_point,
+    input_scale,
+    kernel_scale,
+    out_channels,
+):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_USMP_CORSTONE300_RUNNER
+
+    ifm_shape = (1, 64, 100, 4)
+    kernel_size = (3, 3)
+    strides = (1, 1)
+    dilation = (1, 1)
+    dtype = "int8"
+    groups = 1
+    weight_format = "HWIO"
+    kernel_h = kernel_size[0]
+    kernel_w = kernel_size[1]
+    kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels)
+    kernel_zero_point = 0
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    relu_type = "RELU"
+
+    output_scale, output_zero_point = get_conv2d_qnn_params(
+        kernel_shape,
+        input_scale,
+        input_zero_point,
+        kernel_scale,
+        kernel_zero_point,
+        dtype,
+        dtype,
+        dtype,
+    )
+
+    model, params = make_model(
+        ifm_shape,
+        kernel_shape,
+        input_zero_point,
+        input_scale,
+        kernel_zero_point,
+        kernel_scale,
+        output_zero_point,
+        output_scale,
+        padding,
+        strides,
+        dilation,
+        groups,
+        dtype,
+        dtype,
+        out_channels,
+        weight_format,
+        enable_bias,
+        relu_type,
+    )
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
+
+    # validate pattern matching
+    assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+    # compile the model
+    rng = np.random.default_rng(12345)
+    inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)}
+    output_list = generate_ref_data(orig_mod["main"], inputs, params)
+
+    compiled_models = compile_models(
+        AOTTestModel(module=cmsisnn_mod, inputs=inputs, outputs=output_list, params=params),
+        interface_api,
+        use_unpacked_api,
+    )
+
+    # validate number of TIR primfunc args
+    expected_num_params = 6 if enable_bias else 5
+    cmsisnn_tir_mod = None
+    for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items():
+        if "cmsis-nn" == target.kind.name:
+            cmsisnn_tir_mod = mod
+
+    cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"]
+    assert (
+        len(cmsisnn_func.params) == expected_num_params
+    ), "Generated unexpected number of function arguments"
+
+
 @tvm.testing.requires_cmsisnn
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
 @pytest.mark.parametrize("relu_type", ["RELU"])