You are viewing a plain-text version of this content; the link to the canonical version ("here") was not preserved in this rendering.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/10/01 23:32:48 UTC
[tvm] branch main updated: [CMSIS-NN] Initial operator support for Mul (#9163)
This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6d120c0 [CMSIS-NN] Initial operator support for Mul (#9163)
6d120c0 is described below
commit 6d120c0a52d4623cdad8116ababaa1c917674054
Author: Christopher Sidebottom <ch...@arm.com>
AuthorDate: Sat Oct 2 00:32:31 2021 +0100
[CMSIS-NN] Initial operator support for Mul (#9163)
This is largely as it says on the tin: it adds Mul support to CMSIS-NN.
---
python/tvm/relay/op/contrib/cmsisnn.py | 21 +++
src/relay/backend/contrib/cmsisnn/relay_to_tir.cc | 89 ++++++++++--
tests/python/contrib/test_cmsisnn/test_mul.py | 154 +++++++++++++++++++++
tests/python/contrib/test_cmsisnn/test_networks.py | 40 +-----
tests/python/contrib/test_cmsisnn/test_softmax.py | 74 ++--------
tests/python/contrib/test_cmsisnn/utils.py | 83 +++++++++++
6 files changed, 350 insertions(+), 111 deletions(-)
diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index b74a09c..c28e97b 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -79,6 +79,27 @@ def pattern_table():
and dequantize_call.args[0].checked_type.dtype == "int8"
)
+ def mul_pattern():
+ """Matcher for QNN multiplication.
+
+ Matches qnn.mul whose two tensor operands are arbitrary expressions and whose
+ six quantization parameters (lhs scale, lhs zero point, rhs scale, rhs zero
+ point, output scale, output zero point) are all constants, because the
+ CMSIS-NN lowering reads them as compile-time scalar values.
+ """
+ return is_op("qnn.mul")(
+ wildcard(),  # lhs tensor
+ wildcard(),  # rhs tensor
+ is_constant(),  # lhs scale
+ is_constant(),  # lhs zero point
+ is_constant(),  # rhs scale
+ is_constant(),  # rhs zero point
+ is_constant(),  # output scale
+ is_constant(),  # output zero point
+ )
+
+ def check_quantized_mul(extract):
+ """Check if multiply is supported by CMSIS-NN.
+
+ Only int8 x int8 multiplication is offloaded (it is lowered to the
+ arm_elementwise_mul_s8 kernel); calls with any other input dtype are
+ left for TVM to compile.
+ """
+ return (
+ extract.args[0].checked_type.dtype == "int8"
+ and extract.args[1].checked_type.dtype == "int8"
+ )
+
return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
+ ("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul),
]
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 7c1728c..bcb171c 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -32,17 +32,37 @@ namespace relay {
namespace contrib {
namespace cmsisnn {
-class RelayToTIR : public MixedModeVisitor {
+class RelayToTIRVisitor : public MixedModeVisitor {
public:
- explicit RelayToTIR(String func_name) : func_name_(func_name) {}
+ explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}
+
+ /*!
+ * \brief Returns the PrimFunc produced while visiting (primfunc_, as set by
+ * CreatePrimFuncForExtern). NOTE(review): if no supported composite was visited,
+ * this is whatever primfunc_ was initialised to — confirm callers handle that case.
+ */
+ tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }
private:
- void emit_softmax_tir(const Expr& expr) {
+ /*!
+ * \brief Read the first scalar element of a Relay constant argument.
+ * \tparam T C++ type used to reinterpret the constant's data buffer (e.g. float,
+ * int32_t); the caller must choose the type matching the constant's dtype.
+ * \param arg Expression that must be a ConstantNode.
+ * \return The first element of the constant's data, read as T.
+ */
+ template <typename T>
+ T ArgumentToConstantValue(const Expr& arg) {
+ const ConstantNode* constant_node = arg.as<ConstantNode>();
+ // The composite patterns only match constants in these positions, but guard anyway so
+ // a pattern change surfaces as a clear error instead of a null-pointer dereference.
+ ICHECK(constant_node != nullptr) << "CMSIS-NN lowering expected a constant argument";
+ return static_cast<const T*>(constant_node->data->data)[0];
+ }
+
+ /*!
+ * \brief Set primfunc_ to a PrimFunc whose body is a single extern call.
+ * \param func_signature Parameters of the generated PrimFunc.
+ * \param call_extern_args Arguments forwarded to tir::builtin::call_extern; the extern
+ * symbol name comes first as a tir::StringImm (e.g. EmitMul passes
+ * "arm_elementwise_mul_s8"), followed by the kernel's arguments.
+ */
+ void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
+ tvm::Array<PrimExpr> call_extern_args) {
+ Map<String, ObjectRef> dict_attrs;
+ // Export the PrimFunc under the name of the Relay function being lowered.
+ dict_attrs.Set("global_symbol", func_name_);
+ dict_attrs.Set("tir.noalias", Bool(true));
+
+ // NOTE(review): the call is typed Int(8) and its value is only evaluated, not stored;
+ // confirm this matches the return type of the CMSIS-NN kernels being called.
+ tir::Stmt body = tir::Evaluate(
+ tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));
+
+ // Void return type and an empty buffer map: parameters are passed through as-is.
+ primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
+ DictAttrs(dict_attrs));
+ }
+
+ void EmitSoftMax(const Expr& expr) {
auto* quantize_call = expr.as<CallNode>();
auto* softmax_call = quantize_call->args[0].as<CallNode>();
auto* dequant_call = softmax_call->args[0].as<CallNode>();
- auto* scale_const = dequant_call->args[1].as<ConstantNode>();
- const float quant_scale = static_cast<const float*>(scale_const->data->data)[0];
+ const float quant_scale = ArgumentToConstantValue<float>(dequant_call->args[1]);
// assuming layout as NHWC
auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor {
IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size),
IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift),
IntImm(DataType::Int(32), diff_min), out_var};
- tir::Stmt body =
- tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args));
- Map<String, ObjectRef> dict_attrs;
- dict_attrs.Set("global_symbol", func_name_);
- dict_attrs.Set("tir.noalias", Bool(true));
+ CreatePrimFuncForExtern(func_signature, args);
+ }
- primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
- DictAttrs(dict_attrs));
+ /*!
+ * \brief Lower a matched "cmsisnn.quantized_mul" composite body (a qnn.mul call) to a
+ * PrimFunc that calls arm_elementwise_mul_s8.
+ * \param expr The qnn.mul call; args[2..7] must be constants (guaranteed by the pattern).
+ */
+ void EmitMul(const Expr& expr) {
+ auto* mul_call = expr.as<CallNode>();
+
+ // qnn.mul argument order: lhs, rhs, lhs scale/zp, rhs scale/zp, output scale/zp.
+ const float input_0_scale = ArgumentToConstantValue<float>(mul_call->args[2]);
+ const int32_t input_0_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[3]);
+ const float input_1_scale = ArgumentToConstantValue<float>(mul_call->args[4]);
+ const int32_t input_1_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[5]);
+ const float output_scale = ArgumentToConstantValue<float>(mul_call->args[6]);
+ const int32_t output_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[7]);
+
+ // Requantization scale of the product: (s0 * s1) / s_out, computed in double and then
+ // split into an integer fixed-point multiplier and shift for the kernel.
+ double quantized_multiplier = static_cast<double>(input_0_scale) *
+ static_cast<double>(input_1_scale) /
+ static_cast<double>(output_scale);
+ auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
+ int32_t output_multiplier = std::get<0>(mult_shift_pair);
+ int32_t output_shift = std::get<1>(mult_shift_pair);
+
+ // Number of elements in the output tensor; passed as the kernel's final (size) argument.
+ PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();
+
+ tir::Var input_0("input_0", DataType::Handle(8));
+ tir::Var input_1("input_1", DataType::Handle(8));
+ tir::Var output("output", DataType::Handle(8));
+
+ Array<tir::Var> func_signature{input_0, input_1, output};
+
+ // Argument order must follow arm_elementwise_mul_s8's parameter list (see CMSIS-NN's
+ // arm_nnfunctions.h). Input zero points are passed negated as offsets (CMSIS-NN
+ // convention — confirm in that header); the output zero point is passed as-is.
+ // The activation range is the full int8 range, i.e. no fused clamp.
+ tvm::Array<PrimExpr> args = {
+ tir::StringImm("arm_elementwise_mul_s8"),
+ input_0,
+ input_1,
+ IntImm(DataType::Int(32), -input_0_zero_point),
+ IntImm(DataType::Int(32), -input_1_zero_point),
+ output,
+ IntImm(DataType::Int(32), output_zero_point),
+ IntImm(DataType::Int(32), output_multiplier),
+ IntImm(DataType::Int(32), output_shift),
+ IntImm(DataType::Int(32), std::numeric_limits<int8_t>::min()),
+ IntImm(DataType::Int(32), std::numeric_limits<int8_t>::max()),
+ tensor_size,
+ };
+
+ CreatePrimFuncForExtern(func_signature, args);
}
void VisitExpr_(const CallNode* call) final {
@@ -98,7 +154,10 @@ class RelayToTIR : public MixedModeVisitor {
auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") {
- emit_softmax_tir(func->body);
+ EmitSoftMax(func->body);
+ }
+ if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") {
+ EmitMul(func->body);
}
}
@@ -119,12 +178,12 @@ IRModule GenerateTIR(IRModule mod) {
}
// Prepare PrimFunc from Relay Function
- auto relay_to_tir = RelayToTIR(func_name);
+ auto relay_to_tir = RelayToTIRVisitor(func_name);
relay_to_tir.VisitExpr(func->body);
// Build the TIR IRModule from the generated PrimFunc
Map<GlobalVar, BaseFunc> var_func_map;
- var_func_map.Set(GlobalVar(func_name), relay_to_tir.primfunc_);
+ var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc());
return IRModule(var_func_map);
}
diff --git a/tests/python/contrib/test_cmsisnn/test_mul.py b/tests/python/contrib/test_cmsisnn/test_mul.py
new file mode 100644
index 0000000..88fbeb2
--- /dev/null
+++ b/tests/python/contrib/test_cmsisnn/test_mul.py
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: mul"""
+
+import sys
+
+import numpy as np
+import pytest
+
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+
+from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str
+from tests.python.relay.aot.aot_test_utils import (
+ AOTTestModel,
+ AOT_CORSTONE300_RUNNER,
+ generate_ref_data,
+ compile_and_run,
+)
+
+
+def make_model(
+ shape,
+ input_0_dtype,
+ input_1_dtype,
+ input_0_scale,
+ input_0_zero_point,
+ input_1_scale,
+ input_1_zero_point,
+ out_scale=1.0 / 256,
+ out_zero_point=-128,
+):
+ """Create a Relay qnn.mul expression of two quantized input tensors.
+
+ The operands are Relay variables named "input_0" and "input_1" with the given
+ shape and dtypes. All scales (float32) and zero points (int32) are embedded as
+ constants, which is what the CMSIS-NN mul pattern requires. The returned value
+ is a call expression; wrap it with ``make_module`` to obtain an IRModule.
+ """
+
+ return relay.qnn.op.mul(
+ relay.var("input_0", shape=shape, dtype=input_0_dtype),
+ relay.var("input_1", shape=shape, dtype=input_1_dtype),
+ relay.const(input_0_scale, "float32"),
+ relay.const(input_0_zero_point, "int32"),
+ relay.const(input_1_scale, "float32"),
+ relay.const(input_1_zero_point, "int32"),
+ relay.const(out_scale, "float32"),
+ relay.const(out_zero_point, "int32"),
+ )
+
+
+@skip_if_no_reference_system
+@pytest.mark.parametrize(
+ [
+ "input_0_scale",
+ "input_0_zero_point",
+ "input_1_scale",
+ "input_1_zero_point",
+ "output_tolerance",
+ ],
+ [[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]],
+)
+def test_mul_int8(
+ input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance
+):
+ """Offload an int8 qnn.mul to CMSIS-NN and compare against the reference output.
+
+ The partitioned module is first checked structurally: a function marked
+ Compiler="cmsisnn" must exist, and the operator call count must be unchanged.
+ It is then compiled and run with AOT_CORSTONE300_RUNNER. ``output_tolerance``
+ is forwarded to AOTTestModel (presumably the allowed per-element deviation
+ from the reference — confirm in aot_test_utils).
+ """
+ interface_api = "c"
+ use_unpacked_api = True
+ test_runner = AOT_CORSTONE300_RUNNER
+
+ dtype = "int8"
+ shape = [1, 16, 16, 3]
+ model = make_model(
+ shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
+ )
+ orig_mod = make_module(model)
+
+ cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+ # validate pattern matching
+ # Collect the attributes of every function in the module that has any.
+ attrs = [
+ cmsisnn_mod[var.name_hint].attrs
+ for var in cmsisnn_mod.get_global_vars()
+ if cmsisnn_mod[var.name_hint].attrs
+ ]
+ assert any(attrs), "At least one function with external attributes was expected."
+
+ compilers = [
+ key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items()
+ ]
+ assert any(compilers), "Module does not contain function for cmsisnn target."
+
+ # Partitioning should only regroup operator calls, never add or drop any.
+ assert count_num_calls(orig_mod) == count_num_calls(
+ cmsisnn_mod
+ ), "Number of calls changed during partitioning"
+
+ # validate the output
+ # NOTE(review): numpy's randint excludes `high`, so in_max itself is never generated.
+ in_min, in_max = get_range_for_dtype_str(dtype)
+ inputs = {
+ "input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+ "input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+ }
+ # Reference outputs come from the unpartitioned module.
+ output_list = generate_ref_data(orig_mod["main"], inputs)
+ compile_and_run(
+ AOTTestModel(
+ module=cmsisnn_mod,
+ inputs=inputs,
+ outputs=output_list,
+ output_tolerance=output_tolerance,
+ ),
+ test_runner,
+ interface_api,
+ use_unpacked_api,
+ )
+
+
+@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
+def test_invalid_parameters(
+ input_dtype,
+):
+ """A qnn.mul whose inputs are not int8 must not be offloaded to CMSIS-NN.
+
+ After partitioning, no function in the module may carry attributes (i.e. no
+ external/composite function was created).
+ """
+ input_scale = 0.256
+ input_zero_point = 33
+ model = make_model(
+ [1, 16, 16, 3],
+ input_dtype,
+ input_dtype,
+ input_scale,
+ input_zero_point,
+ input_scale,
+ input_zero_point,
+ )
+
+ orig_mod = make_module(model)
+ cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+ attrs = [
+ cmsisnn_mod[var.name_hint].attrs
+ for var in cmsisnn_mod.get_global_vars()
+ if cmsisnn_mod[var.name_hint].attrs
+ ]
+ assert not any(attrs), "No function should have an external attribute."
+
+
+# Allow running this file directly; extra command-line arguments are forwarded to pytest.
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py
index 1f6e0e7..b14a15c 100644
--- a/tests/python/contrib/test_cmsisnn/test_networks.py
+++ b/tests/python/contrib/test_cmsisnn/test_networks.py
@@ -17,18 +17,16 @@
"""CMSIS-NN: testing with networks"""
-import platform
import sys
-import os
-import pathlib
-import tvm
+
+import numpy as np
+import pytest
+
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib import cmsisnn
-import numpy as np
-import pytest
-import itertools
+from utils import skip_if_no_reference_system, get_range_for_dtype_str
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
@@ -37,30 +35,6 @@ from tests.python.relay.aot.aot_test_utils import (
)
-def get_range_for_dtype_str(dtype):
- """
- Produce the min,max for a give data type.
-
- Parameters
- ----------
- dtype : str
- a type string (e.g., int8)
-
- Returns
- -------
- type_info.min : int
- the minimum of the range
- type_info.max : int
- the maximum of the range
- """
-
- try:
- type_info = np.iinfo(dtype)
- except ValueError:
- type_info = np.finfo(dtype)
- return type_info.min, type_info.max
-
-
def convert_to_relay(
tflite_model_buf,
input_data,
@@ -99,9 +73,7 @@ def convert_to_relay(
return mod, params
-@pytest.mark.skipif(
- platform.machine() == "i686", reason="Reference system unavailable in i386 container"
-)
+@skip_if_no_reference_system
def test_cnn_small():
# download the model
base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8"
diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py
index c1951d1..12e11c3 100644
--- a/tests/python/contrib/test_cmsisnn/test_softmax.py
+++ b/tests/python/contrib/test_cmsisnn/test_softmax.py
@@ -17,17 +17,21 @@
"""CMSIS-NN integration tests: softmax"""
-import platform
import sys
-import os
-import pathlib
-import tvm
-from tvm import relay
-from tvm.relay.op.contrib import cmsisnn
+import itertools
+
import numpy as np
import pytest
-import itertools
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+
+from utils import (
+ skip_if_no_reference_system,
+ make_module,
+ count_num_calls,
+ get_range_for_dtype_str,
+)
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
@@ -36,61 +40,9 @@ from tests.python.relay.aot.aot_test_utils import (
)
-def get_range_for_dtype_str(dtype):
- """
- Produce the min,max for a give data type.
-
- Parameters
- ----------
- dtype : str
- a type string (e.g., int8)
-
- Returns
- -------
- type_info.min : int
- the minimum of the range
- type_info.max : int
- the maximum of the range
- """
-
- try:
- type_info = np.iinfo(dtype)
- except ValueError:
- type_info = np.finfo(dtype)
- return type_info.min, type_info.max
-
-
-def count_num_calls(mod):
- """Count number of CallNode in the IRModule"""
-
- class CallCounter(relay.ExprVisitor):
- def __init__(self):
- super().__init__()
- self.count = 0
-
- def visit_call(self, call):
- if isinstance(call.op, tvm.ir.Op):
- self.count += 1
-
- super().visit_call(call)
-
- counter = CallCounter()
- for var in mod.get_global_vars():
- counter.visit(mod[var.name_hint])
- return counter.count
-
-
-def make_module(func):
- """Create IRModule from Function"""
- func = relay.Function(relay.analysis.free_vars(func), func)
- mod = tvm.IRModule.from_expr(func)
- return relay.transform.InferType()(mod)
-
-
def make_model(
shape, in_dtype, out_dtype, in_zero_point, in_scale, out_zero_point=-128, out_scale=1.0 / 256
):
-
"""Create a Relay Function / network model"""
a = relay.var("in0", shape=shape, dtype=in_dtype)
dequantize = relay.qnn.op.dequantize(
@@ -108,9 +60,7 @@ def make_model(
return model
-@pytest.mark.skipif(
- platform.machine() == "i686", reason="Reference system unavailable in i386 container"
-)
+@skip_if_no_reference_system
@pytest.mark.parametrize(["zero_point", "scale"], [[33, 0.256], [-64, 0.0128]])
def test_softmax_int8(zero_point, scale):
interface_api = "c"
diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py
new file mode 100644
index 0000000..3fd12ef
--- /dev/null
+++ b/tests/python/contrib/test_cmsisnn/utils.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN functions for testing networks"""
+
+import platform
+
+import numpy as np
+import pytest
+
+import tvm
+from tvm import relay
+
+
+def skip_if_no_reference_system(func):
+ """Decorator: skip the test when running on an i686 machine.
+
+ The reference system needed to run the compiled code is unavailable in the
+ i386 CI container, so such tests are reported as skipped there.
+ """
+ return pytest.mark.skipif(
+ platform.machine() == "i686", reason="Reference system unavailable in i386 container"
+ )(func)
+
+
+def count_num_calls(mod):
+ """Count calls to primitive operators across all functions in the IRModule.
+
+ Only CallNodes whose ``op`` is a ``tvm.ir.Op`` are counted; calls to Relay
+ functions or global vars are not counted, but their arguments (and the bodies
+ of the module's functions) are still traversed.
+ """
+
+ class CallCounter(relay.ExprVisitor):
+ def __init__(self):
+ super().__init__()
+ self.count = 0
+
+ def visit_call(self, call):
+ if isinstance(call.op, tvm.ir.Op):
+ self.count += 1
+
+ super().visit_call(call)
+
+ # A single visitor is reused so the count accumulates over every function.
+ counter = CallCounter()
+ for var in mod.get_global_vars():
+ counter.visit(mod[var.name_hint])
+ return counter.count
+
+
+def get_range_for_dtype_str(dtype):
+ """
+ Produce the min and max values of a given data type.
+
+ Integer types are looked up with ``np.iinfo``; anything ``np.iinfo`` rejects
+ (e.g. floating-point types) falls back to ``np.finfo``.
+
+ Parameters
+ ----------
+ dtype : str
+ a type string (e.g., int8)
+
+ Returns
+ -------
+ type_info.min : int
+ the minimum of the range
+ type_info.max : int
+ the maximum of the range
+ """
+
+ try:
+ type_info = np.iinfo(dtype)
+ except ValueError:
+ type_info = np.finfo(dtype)
+ return type_info.min, type_info.max
+
+
+def make_module(func):
+ """Create a type-inferred IRModule from a Relay expression.
+
+ The expression is wrapped in a Relay Function whose parameters are its free
+ variables, placed in a new IRModule, and run through InferType.
+ """
+ func = relay.Function(relay.analysis.free_vars(func), func)
+ mod = tvm.IRModule.from_expr(func)
+ return relay.transform.InferType()(mod)