Posted to commits@tvm.apache.org by as...@apache.org on 2022/11/29 15:08:54 UTC

[tvm] branch main updated: [CMSIS-NN] Support int16 handling for pooling functions (#13498)

This is an automated email from the ASF dual-hosted git repository.

ashutoshp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f6f7feafb2 [CMSIS-NN] Support int16 handling for pooling functions (#13498)
f6f7feafb2 is described below

commit f6f7feafb297993f5f035de7f814407a2b876967
Author: neildhickey <ne...@arm.com>
AuthorDate: Tue Nov 29 15:08:48 2022 +0000

    [CMSIS-NN] Support int16 handling for pooling functions (#13498)
    
    [CMSIS-NN] Support int16 handling for pooling functions
    
    - Pattern matching and RelayToTIR introduce int16 support
    - Added int16 variants to pooling tests
---
 python/tvm/relay/op/contrib/cmsisnn.py             | 12 ++++++---
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  | 29 +++++++++++++++++-----
 .../backend/contrib/cmsisnn/tir_to_runtime.cc      |  3 ++-
 tests/python/contrib/test_cmsisnn/test_pooling.py  | 12 ++++-----
 4 files changed, 39 insertions(+), 17 deletions(-)
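
For context, here is a minimal sketch (not part of this commit) of an int16 pooling graph that the updated patterns can now offload. The shape, pool attributes, and use of a bare nn.max_pool2d are illustrative assumptions, not taken from the patch:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import cmsisnn

    # Hypothetical int16 NHWC max pool; previously only int8 would match.
    inp = relay.var("input", shape=(1, 28, 28, 12), dtype="int16")
    pool = relay.nn.max_pool2d(inp, pool_size=(3, 3), strides=(2, 2), layout="NHWC")
    mod = tvm.IRModule.from_expr(relay.Function([inp], pool))

    # With this change the partitioner may hand the pooling off to CMSIS-NN.
    partitioned = cmsisnn.partition_for_cmsisnn(mod)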

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 779fe35c37..4581378dcd 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -287,8 +287,10 @@ def pattern_table():
         return (
             pooling.attrs.layout == "NHWC"
             and int(input_op.checked_type.shape[0]) == 1
-            and input_op.checked_type.dtype == "int8"
-            and output.checked_type.dtype == "int8"
+            and (
+                (input_op.checked_type.dtype == "int8" and output.checked_type.dtype == "int8")
+                or (input_op.checked_type.dtype == "int16" and output.checked_type.dtype == "int16")
+            )
         )
 
     def qnn_max_pool2d_pattern():
@@ -310,8 +312,10 @@ def pattern_table():
         return (
             pooling.attrs.layout == "NHWC"
             and int(input_op.checked_type.shape[0]) == 1
-            and input_op.checked_type.dtype == "int8"
-            and output.checked_type.dtype == "int8"
+            and (
+                (input_op.checked_type.dtype == "int8" and output.checked_type.dtype == "int8")
+                or (input_op.checked_type.dtype == "int16" and output.checked_type.dtype == "int16")
+            )
         )
 
     def binary_op_pattern(op):
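
The predicate change is identical for both pooling patterns: the input and output dtypes must agree, and both must be one of the supported widths. A standalone sketch of that check (the names here are hypothetical, not TVM API):

    SUPPORTED_POOL_DTYPES = ("int8", "int16")

    def pool_dtypes_supported(input_dtype, output_dtype):
        # Both tensors must share one supported dtype; mixed widths are rejected.
        return input_dtype == output_dtype and input_dtype in SUPPORTED_POOL_DTYPES

    assert pool_dtypes_supported("int16", "int16")
    assert not pool_dtypes_supported("int8", "int16")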
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index c9e41589fb..f8685dc4df 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -428,12 +428,19 @@ class RelayToTIRVisitor : public MixedModeMutator {
       pool = final_call;
     }
 
+    int32_t dtype_bits = final_call->type_as<TensorTypeNode>()->dtype.bits();
+
     // prepare cmsis_nn_pool_params
     int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
     int32_t clip_min, clip_max;
     std::string cmsisnn_api;
     if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
-      cmsisnn_api = "arm_avgpool_s8";
+      if (dtype_bits == 8) {
+        cmsisnn_api = "arm_avgpool_s8";
+      } else {
+        cmsisnn_api = "arm_avgpool_s16";
+      }
+
       const AvgPool2DAttrs* attrs = pool->attrs.as<AvgPool2DAttrs>();
       stride_h = qnn::get_const_int(attrs->strides[0]);
       stride_w = qnn::get_const_int(attrs->strides[1]);
@@ -442,7 +449,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
       pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
       pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
     } else {
-      cmsisnn_api = "arm_max_pool_s8";
+      if (dtype_bits == 8) {
+        cmsisnn_api = "arm_max_pool_s8";
+      } else {
+        cmsisnn_api = "arm_max_pool_s16";
+      }
+
       const MaxPool2DAttrs* attrs = pool->attrs.as<MaxPool2DAttrs>();
       stride_h = qnn::get_const_int(attrs->strides[0]);
       stride_w = qnn::get_const_int(attrs->strides[1]);
@@ -456,8 +468,13 @@ class RelayToTIRVisitor : public MixedModeMutator {
       clip_min = clip_attrs->a_min;
       clip_max = clip_attrs->a_max;
     } else {
-      clip_min = -128;
-      clip_max = 127;
+      if (dtype_bits == 8) {
+        clip_min = std::numeric_limits<int8_t>::min();
+        clip_max = std::numeric_limits<int8_t>::max();
+      } else {
+        clip_min = std::numeric_limits<int16_t>::min();
+        clip_max = std::numeric_limits<int16_t>::max();
+      }
     }
 
     tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h),  ToArg(stride_w), ToArg(padding_h),
@@ -472,8 +489,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
     Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};
 
     BufferCreator buffer_creator;
-    tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
-    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+    tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(dtype_bits));
+    tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(dtype_bits));
     tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};
 
     int context_buffer_size = 0;
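
In short, RelayToTIR now derives the CMSIS-NN entry point, the default saturation bounds, and the buffer element width from dtype_bits. A hypothetical Python mirror of that selection, for illustration only (the authoritative logic is the C++ above):

    def select_pool_lowering(pool_name, dtype_bits):
        suffix = "s8" if dtype_bits == 8 else "s16"
        if pool_name == "cmsis-nn.qnn_avg_pool2d":
            api = "arm_avgpool_" + suffix
        else:
            api = "arm_max_pool_" + suffix
        clip_min = -(2 ** (dtype_bits - 1))   # -128 for int8, -32768 for int16
        clip_max = 2 ** (dtype_bits - 1) - 1  # 127 for int8, 32767 for int16
        return api, clip_min, clip_max

    assert select_pool_lowering("cmsis-nn.qnn_avg_pool2d", 16) == ("arm_avgpool_s16", -32768, 32767)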
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index b5c5058ddb..420e8618a4 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -118,7 +118,8 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     } else if (cmsis_func_name == "arm_fully_connected_s8" ||
                cmsis_func_name == "arm_fully_connected_s16") {
       EmitFullyConnected(op);
-    } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_max_pool_s8") {
+    } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_avgpool_s16" ||
+               cmsis_func_name == "arm_max_pool_s8" || cmsis_func_name == "arm_max_pool_s16") {
       EmitPool2D(op);
     }
     return;
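
The codegen dispatch simply grows to cover the two s16 entry points. A set lookup expresses the same routing and scales better than chained string comparisons as more kernels are added; a hypothetical sketch of that alternative (the committed code is the C++ above):

    POOL2D_FUNCS = {
        "arm_avgpool_s8", "arm_avgpool_s16",
        "arm_max_pool_s8", "arm_max_pool_s16",
    }

    def routes_to_pool2d(cmsis_func_name):
        return cmsis_func_name in POOL2D_FUNCS

    assert routes_to_pool2d("arm_max_pool_s16")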
diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py
index 29140ad2e6..7657e0e632 100644
--- a/tests/python/contrib/test_cmsisnn/test_pooling.py
+++ b/tests/python/contrib/test_cmsisnn/test_pooling.py
@@ -81,6 +81,7 @@ def make_model(
 
 
 @tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("dtype", ["int16", "int8"])
 @pytest.mark.parametrize("in_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
 @pytest.mark.parametrize(
     "pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")]
@@ -91,7 +92,8 @@ def make_model(
 @pytest.mark.parametrize(
     "compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), ("cortex-m7", "")]
 )
-def test_op_int8(
+def test_ops(
+    dtype,
     in_shape,
     pool_size,
     strides,
@@ -103,18 +105,17 @@ def test_op_int8(
     compiler_cpu,
     cpu_flags,
 ):
-    """Tests QNN pooling op for int8 inputs"""
+    """Tests QNN pooling op for int8 and int16 pooling"""
     interface_api = "c"
     use_unpacked_api = True
 
-    dtype = "int8"
-
     model = make_model(
         pool_op=pool_type,
         shape=in_shape,
         pool_size=pool_size,
         strides=strides,
         padding=padding,
+        dtype=dtype,
         scale=scale,
         zero_point=zero_point,
         relu_type=relu_type,
@@ -130,7 +131,7 @@ def test_op_int8(
     in_min, in_max = get_range_for_dtype_str(dtype)
     np.random.seed(0)
     inputs = {
-        "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype="int8"),
+        "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype=dtype),
     }
     output_list = generate_ref_data(orig_mod["main"], inputs)
     compile_and_run(
@@ -211,7 +212,6 @@ def test_int8_pool_with_float32_input(
 def test_invalid_datatype(op):
     """Checks CMSIS-NN partitioning for non int8 dtype"""
     model = make_model(pool_op=op, dtype="int64")
-
     orig_mod = make_module(model)
     cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
     assert_no_external_function(cmsisnn_mod)
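
With the dtype parametrized, the test draws random inputs from the full range of the dtype under test. For reference, a numpy sketch of the ranges get_range_for_dtype_str is expected to yield here (an assumption about that helper, based on its use above):

    import numpy as np

    def dtype_range(dtype):
        # Min/max representable values for an integer dtype string.
        info = np.iinfo(dtype)
        return int(info.min), int(info.max)

    assert dtype_range("int8") == (-128, 127)
    assert dtype_range("int16") == (-32768, 32767)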