You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by as...@apache.org on 2022/11/29 15:08:54 UTC
[tvm] branch main updated: [CMSIS-NN] Support int16 handling for pooling functions (#13498)
This is an automated email from the ASF dual-hosted git repository.
ashutoshp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f6f7feafb2 [CMSIS-NN] Support int16 handling for pooling functions (#13498)
f6f7feafb2 is described below
commit f6f7feafb297993f5f035de7f814407a2b876967
Author: neildhickey <ne...@arm.com>
AuthorDate: Tue Nov 29 15:08:48 2022 +0000
[CMSIS-NN] Support int16 handling for pooling functions (#13498)
[CMSIS-NN] Support int16 handling for pooling functions
-Pattern matching and RelayToTIR introduce int16 support
-Added int16 variants to pooling tests
---
python/tvm/relay/op/contrib/cmsisnn.py | 12 ++++++---
src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 29 +++++++++++++++++-----
.../backend/contrib/cmsisnn/tir_to_runtime.cc | 3 ++-
tests/python/contrib/test_cmsisnn/test_pooling.py | 12 ++++-----
4 files changed, 39 insertions(+), 17 deletions(-)
diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 779fe35c37..4581378dcd 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -287,8 +287,10 @@ def pattern_table():
return (
pooling.attrs.layout == "NHWC"
and int(input_op.checked_type.shape[0]) == 1
- and input_op.checked_type.dtype == "int8"
- and output.checked_type.dtype == "int8"
+ and (
+ (input_op.checked_type.dtype == "int8" and output.checked_type.dtype == "int8")
+ or (input_op.checked_type.dtype == "int16" and output.checked_type.dtype == "int16")
+ )
)
def qnn_max_pool2d_pattern():
@@ -310,8 +312,10 @@ def pattern_table():
return (
pooling.attrs.layout == "NHWC"
and int(input_op.checked_type.shape[0]) == 1
- and input_op.checked_type.dtype == "int8"
- and output.checked_type.dtype == "int8"
+ and (
+ (input_op.checked_type.dtype == "int8" and output.checked_type.dtype == "int8")
+ or (input_op.checked_type.dtype == "int16" and output.checked_type.dtype == "int16")
+ )
)
def binary_op_pattern(op):
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index c9e41589fb..f8685dc4df 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -428,12 +428,19 @@ class RelayToTIRVisitor : public MixedModeMutator {
pool = final_call;
}
+ int32_t dtype_bits = final_call->type_as<TensorTypeNode>()->dtype.bits();
+
// prepare cmsis_nn_pool_params
int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
int32_t clip_min, clip_max;
std::string cmsisnn_api;
if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
- cmsisnn_api = "arm_avgpool_s8";
+ if (dtype_bits == 8) {
+ cmsisnn_api = "arm_avgpool_s8";
+ } else {
+ cmsisnn_api = "arm_avgpool_s16";
+ }
+
const AvgPool2DAttrs* attrs = pool->attrs.as<AvgPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
@@ -442,7 +449,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
} else {
- cmsisnn_api = "arm_max_pool_s8";
+ if (dtype_bits == 8) {
+ cmsisnn_api = "arm_max_pool_s8";
+ } else {
+ cmsisnn_api = "arm_max_pool_s16";
+ }
+
const MaxPool2DAttrs* attrs = pool->attrs.as<MaxPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
@@ -456,8 +468,13 @@ class RelayToTIRVisitor : public MixedModeMutator {
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
- clip_min = -128;
- clip_max = 127;
+ if (dtype_bits == 8) {
+ clip_min = std::numeric_limits<int8_t>::min();
+ clip_max = std::numeric_limits<int8_t>::max();
+ } else {
+ clip_min = std::numeric_limits<int16_t>::min();
+ clip_max = std::numeric_limits<int16_t>::max();
+ }
}
tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
@@ -472,8 +489,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};
BufferCreator buffer_creator;
- tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
- tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+ tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(dtype_bits));
+ tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(dtype_bits));
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};
int context_buffer_size = 0;
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index b5c5058ddb..420e8618a4 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -118,7 +118,8 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
} else if (cmsis_func_name == "arm_fully_connected_s8" ||
cmsis_func_name == "arm_fully_connected_s16") {
EmitFullyConnected(op);
- } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_max_pool_s8") {
+ } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_avgpool_s16" ||
+ cmsis_func_name == "arm_max_pool_s8" || cmsis_func_name == "arm_max_pool_s16") {
EmitPool2D(op);
}
return;
diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py
index 29140ad2e6..7657e0e632 100644
--- a/tests/python/contrib/test_cmsisnn/test_pooling.py
+++ b/tests/python/contrib/test_cmsisnn/test_pooling.py
@@ -81,6 +81,7 @@ def make_model(
@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("dtype", ["int16", "int8"])
@pytest.mark.parametrize("in_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
@pytest.mark.parametrize(
"pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")]
@@ -91,7 +92,8 @@ def make_model(
@pytest.mark.parametrize(
"compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), ("cortex-m7", "")]
)
-def test_op_int8(
+def test_ops(
+ dtype,
in_shape,
pool_size,
strides,
@@ -103,18 +105,17 @@ def test_op_int8(
compiler_cpu,
cpu_flags,
):
- """Tests QNN pooling op for int8 inputs"""
+ """Tests QNN pooling op for int8 and int16 pooling"""
interface_api = "c"
use_unpacked_api = True
- dtype = "int8"
-
model = make_model(
pool_op=pool_type,
shape=in_shape,
pool_size=pool_size,
strides=strides,
padding=padding,
+ dtype=dtype,
scale=scale,
zero_point=zero_point,
relu_type=relu_type,
@@ -130,7 +131,7 @@ def test_op_int8(
in_min, in_max = get_range_for_dtype_str(dtype)
np.random.seed(0)
inputs = {
- "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype="int8"),
+ "input": np.random.randint(in_min, high=in_max, size=in_shape, dtype=dtype),
}
output_list = generate_ref_data(orig_mod["main"], inputs)
compile_and_run(
@@ -211,7 +212,6 @@ def test_int8_pool_with_float32_input(
def test_invalid_datatype(op):
"""Checks CMSIS-NN partitioning for non int8 dtype"""
model = make_model(pool_op=op, dtype="int64")
-
orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
assert_no_external_function(cmsisnn_mod)