You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/10/01 23:32:48 UTC

[tvm] branch main updated: [CMSIS-NN] Initial operator support for Mul (#9163)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d120c0  [CMSIS-NN] Initial operator support for Mul (#9163)
6d120c0 is described below

commit 6d120c0a52d4623cdad8116ababaa1c917674054
Author: Christopher Sidebottom <ch...@arm.com>
AuthorDate: Sat Oct 2 00:32:31 2021 +0100

    [CMSIS-NN] Initial operator support for Mul (#9163)
    
    This is largely as it says on the tin: it adds Mul support to CMSIS-NN.
---
 python/tvm/relay/op/contrib/cmsisnn.py             |  21 +++
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  |  89 ++++++++++--
 tests/python/contrib/test_cmsisnn/test_mul.py      | 154 +++++++++++++++++++++
 tests/python/contrib/test_cmsisnn/test_networks.py |  40 +-----
 tests/python/contrib/test_cmsisnn/test_softmax.py  |  74 ++--------
 tests/python/contrib/test_cmsisnn/utils.py         |  83 +++++++++++
 6 files changed, 350 insertions(+), 111 deletions(-)

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index b74a09c..c28e97b 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -79,6 +79,27 @@ def pattern_table():
             and dequantize_call.args[0].checked_type.dtype == "int8"
         )
 
+    def mul_pattern():
+        """Matcher for QNN multiplication"""
+        return is_op("qnn.mul")(
+            wildcard(),
+            wildcard(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+        )
+
+    def check_quantized_mul(extract):
+        """Check if multiply is supported by CMSIS-NN."""
+        return (
+            extract.args[0].checked_type.dtype == "int8"
+            and extract.args[1].checked_type.dtype == "int8"
+        )
+
     return [
         ("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
+        ("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul),
     ]
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 7c1728c..bcb171c 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -32,17 +32,37 @@ namespace relay {
 namespace contrib {
 namespace cmsisnn {
 
-class RelayToTIR : public MixedModeVisitor {
+class RelayToTIRVisitor : public MixedModeVisitor {
  public:
-  explicit RelayToTIR(String func_name) : func_name_(func_name) {}
+  explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}
+
+  tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }
 
  private:
-  void emit_softmax_tir(const Expr& expr) {
+  template <typename T>
+  const T ArgumentToConstantValue(const Expr& arg) {
+    const ConstantNode* constant_node = arg.as<ConstantNode>();
+    return static_cast<const T*>(constant_node->data->data)[0];
+  }
+
+  void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
+                               tvm::Array<PrimExpr> call_extern_args) {
+    Map<String, ObjectRef> dict_attrs;
+    dict_attrs.Set("global_symbol", func_name_);
+    dict_attrs.Set("tir.noalias", Bool(true));
+
+    tir::Stmt body = tir::Evaluate(
+        tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));
+
+    primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
+                              DictAttrs(dict_attrs));
+  }
+
+  void EmitSoftMax(const Expr& expr) {
     auto* quantize_call = expr.as<CallNode>();
     auto* softmax_call = quantize_call->args[0].as<CallNode>();
     auto* dequant_call = softmax_call->args[0].as<CallNode>();
-    auto* scale_const = dequant_call->args[1].as<ConstantNode>();
-    const float quant_scale = static_cast<const float*>(scale_const->data->data)[0];
+    const float quant_scale = ArgumentToConstantValue<float>(dequant_call->args[1]);
 
     // assuming layout as NHWC
     auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor {
         IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size),
         IntImm(DataType::Int(32), mult),     IntImm(DataType::Int(32), shift),
         IntImm(DataType::Int(32), diff_min), out_var};
-    tir::Stmt body =
-        tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args));
 
-    Map<String, ObjectRef> dict_attrs;
-    dict_attrs.Set("global_symbol", func_name_);
-    dict_attrs.Set("tir.noalias", Bool(true));
+    CreatePrimFuncForExtern(func_signature, args);
+  }
 
-    primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
-                              DictAttrs(dict_attrs));
+  void EmitMul(const Expr& expr) {
+    auto* mul_call = expr.as<CallNode>();
+
+    const float input_0_scale = ArgumentToConstantValue<float>(mul_call->args[2]);
+    const int32_t input_0_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[3]);
+    const float input_1_scale = ArgumentToConstantValue<float>(mul_call->args[4]);
+    const int32_t input_1_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[5]);
+    const float output_scale = ArgumentToConstantValue<float>(mul_call->args[6]);
+    const int32_t output_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[7]);
+
+    double quantized_multiplier = static_cast<double>(input_0_scale) *
+                                  static_cast<double>(input_1_scale) /
+                                  static_cast<double>(output_scale);
+    auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
+    int32_t output_multiplier = std::get<0>(mult_shift_pair);
+    int32_t output_shift = std::get<1>(mult_shift_pair);
+
+    PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();
+
+    tir::Var input_0("input_0", DataType::Handle(8));
+    tir::Var input_1("input_1", DataType::Handle(8));
+    tir::Var output("output", DataType::Handle(8));
+
+    Array<tir::Var> func_signature{input_0, input_1, output};
+
+    tvm::Array<PrimExpr> args = {
+        tir::StringImm("arm_elementwise_mul_s8"),
+        input_0,
+        input_1,
+        IntImm(DataType::Int(32), -input_0_zero_point),
+        IntImm(DataType::Int(32), -input_1_zero_point),
+        output,
+        IntImm(DataType::Int(32), output_zero_point),
+        IntImm(DataType::Int(32), output_multiplier),
+        IntImm(DataType::Int(32), output_shift),
+        IntImm(DataType::Int(32), std::numeric_limits<int8_t>::min()),
+        IntImm(DataType::Int(32), std::numeric_limits<int8_t>::max()),
+        tensor_size,
+    };
+
+    CreatePrimFuncForExtern(func_signature, args);
   }
 
   void VisitExpr_(const CallNode* call) final {
@@ -98,7 +154,10 @@ class RelayToTIR : public MixedModeVisitor {
 
     auto comp_name = func->GetAttr<String>(attr::kComposite);
     if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") {
-      emit_softmax_tir(func->body);
+      EmitSoftMax(func->body);
+    }
+    if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") {
+      EmitMul(func->body);
     }
   }
 
@@ -119,12 +178,12 @@ IRModule GenerateTIR(IRModule mod) {
   }
 
   // Prepare PrimFunc from Relay Function
-  auto relay_to_tir = RelayToTIR(func_name);
+  auto relay_to_tir = RelayToTIRVisitor(func_name);
   relay_to_tir.VisitExpr(func->body);
 
   // Build the TIR IRModule from the generated PrimFunc
   Map<GlobalVar, BaseFunc> var_func_map;
-  var_func_map.Set(GlobalVar(func_name), relay_to_tir.primfunc_);
+  var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc());
   return IRModule(var_func_map);
 }
 
diff --git a/tests/python/contrib/test_cmsisnn/test_mul.py b/tests/python/contrib/test_cmsisnn/test_mul.py
new file mode 100644
index 0000000..88fbeb2
--- /dev/null
+++ b/tests/python/contrib/test_cmsisnn/test_mul.py
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: mul"""
+
+import sys
+
+import numpy as np
+import pytest
+
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+
+from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+
+
+def make_model(
+    shape,
+    input_0_dtype,
+    input_1_dtype,
+    input_0_scale,
+    input_0_zero_point,
+    input_1_scale,
+    input_1_zero_point,
+    out_scale=1.0 / 256,
+    out_zero_point=-128,
+):
+    """Create a Relay Function / network model"""
+
+    return relay.qnn.op.mul(
+        relay.var("input_0", shape=shape, dtype=input_0_dtype),
+        relay.var("input_1", shape=shape, dtype=input_1_dtype),
+        relay.const(input_0_scale, "float32"),
+        relay.const(input_0_zero_point, "int32"),
+        relay.const(input_1_scale, "float32"),
+        relay.const(input_1_zero_point, "int32"),
+        relay.const(out_scale, "float32"),
+        relay.const(out_zero_point, "int32"),
+    )
+
+
+@skip_if_no_reference_system
+@pytest.mark.parametrize(
+    [
+        "input_0_scale",
+        "input_0_zero_point",
+        "input_1_scale",
+        "input_1_zero_point",
+        "output_tolerance",
+    ],
+    [[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]],
+)
+def test_mul_int8(
+    input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance
+):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    model = make_model(
+        shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain function for cmsisnn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    inputs = {
+        "input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+        "input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+    }
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            output_tolerance=output_tolerance,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
+def test_invalid_parameters(
+    input_dtype,
+):
+    input_scale = 0.256
+    input_zero_point = 33
+    model = make_model(
+        [1, 16, 16, 3],
+        input_dtype,
+        input_dtype,
+        input_scale,
+        input_zero_point,
+        input_scale,
+        input_zero_point,
+    )
+
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert not any(attrs), "No function should have an external attribute."
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py
index 1f6e0e7..b14a15c 100644
--- a/tests/python/contrib/test_cmsisnn/test_networks.py
+++ b/tests/python/contrib/test_cmsisnn/test_networks.py
@@ -17,18 +17,16 @@
 
 """CMSIS-NN: testing with networks"""
 
-import platform
 import sys
-import os
-import pathlib
-import tvm
+
+import numpy as np
+import pytest
+
 from tvm import relay
 from tvm.contrib.download import download_testdata
 from tvm.relay.op.contrib import cmsisnn
-import numpy as np
-import pytest
-import itertools
 
+from utils import skip_if_no_reference_system, get_range_for_dtype_str
 from tests.python.relay.aot.aot_test_utils import (
     AOTTestModel,
     AOT_CORSTONE300_RUNNER,
@@ -37,30 +35,6 @@ from tests.python.relay.aot.aot_test_utils import (
 )
 
 
-def get_range_for_dtype_str(dtype):
-    """
-    Produce the min,max for a give data type.
-
-    Parameters
-    ----------
-    dtype : str
-        a type string (e.g., int8)
-
-    Returns
-    -------
-    type_info.min : int
-        the minimum of the range
-    type_info.max : int
-        the maximum of the range
-    """
-
-    try:
-        type_info = np.iinfo(dtype)
-    except ValueError:
-        type_info = np.finfo(dtype)
-    return type_info.min, type_info.max
-
-
 def convert_to_relay(
     tflite_model_buf,
     input_data,
@@ -99,9 +73,7 @@ def convert_to_relay(
     return mod, params
 
 
-@pytest.mark.skipif(
-    platform.machine() == "i686", reason="Reference system unavailable in i386 container"
-)
+@skip_if_no_reference_system
 def test_cnn_small():
     # download the model
     base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8"
diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py
index c1951d1..12e11c3 100644
--- a/tests/python/contrib/test_cmsisnn/test_softmax.py
+++ b/tests/python/contrib/test_cmsisnn/test_softmax.py
@@ -17,17 +17,21 @@
 
 """CMSIS-NN integration tests: softmax"""
 
-import platform
 import sys
-import os
-import pathlib
-import tvm
-from tvm import relay
-from tvm.relay.op.contrib import cmsisnn
+import itertools
+
 import numpy as np
 import pytest
-import itertools
 
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+
+from utils import (
+    skip_if_no_reference_system,
+    make_module,
+    count_num_calls,
+    get_range_for_dtype_str,
+)
 from tests.python.relay.aot.aot_test_utils import (
     AOTTestModel,
     AOT_CORSTONE300_RUNNER,
@@ -36,61 +40,9 @@ from tests.python.relay.aot.aot_test_utils import (
 )
 
 
-def get_range_for_dtype_str(dtype):
-    """
-    Produce the min,max for a give data type.
-
-    Parameters
-    ----------
-    dtype : str
-        a type string (e.g., int8)
-
-    Returns
-    -------
-    type_info.min : int
-        the minimum of the range
-    type_info.max : int
-        the maximum of the range
-    """
-
-    try:
-        type_info = np.iinfo(dtype)
-    except ValueError:
-        type_info = np.finfo(dtype)
-    return type_info.min, type_info.max
-
-
-def count_num_calls(mod):
-    """Count number of CallNode in the IRModule"""
-
-    class CallCounter(relay.ExprVisitor):
-        def __init__(self):
-            super().__init__()
-            self.count = 0
-
-        def visit_call(self, call):
-            if isinstance(call.op, tvm.ir.Op):
-                self.count += 1
-
-            super().visit_call(call)
-
-    counter = CallCounter()
-    for var in mod.get_global_vars():
-        counter.visit(mod[var.name_hint])
-    return counter.count
-
-
-def make_module(func):
-    """Create IRModule from Function"""
-    func = relay.Function(relay.analysis.free_vars(func), func)
-    mod = tvm.IRModule.from_expr(func)
-    return relay.transform.InferType()(mod)
-
-
 def make_model(
     shape, in_dtype, out_dtype, in_zero_point, in_scale, out_zero_point=-128, out_scale=1.0 / 256
 ):
-
     """Create a Relay Function / network model"""
     a = relay.var("in0", shape=shape, dtype=in_dtype)
     dequantize = relay.qnn.op.dequantize(
@@ -108,9 +60,7 @@ def make_model(
     return model
 
 
-@pytest.mark.skipif(
-    platform.machine() == "i686", reason="Reference system unavailable in i386 container"
-)
+@skip_if_no_reference_system
 @pytest.mark.parametrize(["zero_point", "scale"], [[33, 0.256], [-64, 0.0128]])
 def test_softmax_int8(zero_point, scale):
     interface_api = "c"
diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py
new file mode 100644
index 0000000..3fd12ef
--- /dev/null
+++ b/tests/python/contrib/test_cmsisnn/utils.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN functions for testing networks"""
+
+import platform
+
+import numpy as np
+import pytest
+
+import tvm
+from tvm import relay
+
+
+def skip_if_no_reference_system(func):
+    return pytest.mark.skipif(
+        platform.machine() == "i686", reason="Reference system unavailable in i386 container"
+    )(func)
+
+
+def count_num_calls(mod):
+    """Count number of CallNode in the IRModule"""
+
+    class CallCounter(relay.ExprVisitor):
+        def __init__(self):
+            super().__init__()
+            self.count = 0
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                self.count += 1
+
+            super().visit_call(call)
+
+    counter = CallCounter()
+    for var in mod.get_global_vars():
+        counter.visit(mod[var.name_hint])
+    return counter.count
+
+
+def get_range_for_dtype_str(dtype):
+    """
+    Produce the min,max for a give data type.
+
+    Parameters
+    ----------
+    dtype : str
+        a type string (e.g., int8)
+
+    Returns
+    -------
+    type_info.min : int
+        the minimum of the range
+    type_info.max : int
+        the maximum of the range
+    """
+
+    try:
+        type_info = np.iinfo(dtype)
+    except ValueError:
+        type_info = np.finfo(dtype)
+    return type_info.min, type_info.max
+
+
+def make_module(func):
+    """Create IRModule from Function"""
+    func = relay.Function(relay.analysis.free_vars(func), func)
+    mod = tvm.IRModule.from_expr(func)
+    return relay.transform.InferType()(mod)