Posted to commits@tvm.apache.org by mo...@apache.org on 2022/06/17 16:42:58 UTC

[tvm] branch main updated: [CMSIS-NN] Fixed the case with repeating operands in the QNN binary ops (#11732)

This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dffc3108bb  [CMSIS-NN] Fixed the case with repeating operands in the QNN binary ops (#11732)
dffc3108bb is described below

commit dffc3108bbd406e7da49693533983adba634b19d
Author: Ashutosh Parkhi <86...@users.noreply.github.com>
AuthorDate: Fri Jun 17 17:42:49 2022 +0100

     [CMSIS-NN] Fixed the case with repeating operands in the QNN binary ops (#11732)
---
 python/tvm/relay/op/contrib/cmsisnn.py             |  1 -
 .../backend/contrib/cmsisnn/extract_constants.cc   | 13 ++++-
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  | 14 ++++-
 .../contrib/cmsisnn/scalar_to_tensor_constant.cc   |  6 +++
 .../python/contrib/test_cmsisnn/test_binary_ops.py | 61 +++++++++++++++++++++-
 .../contrib/test_cmsisnn/test_extract_constants.py | 34 ++++++++++++
 .../test_cmsisnn/test_scalar_to_tensor_constant.py | 41 +++++++++++++++
 7 files changed, 165 insertions(+), 5 deletions(-)
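
For context (not part of the commit message itself): the problem addressed here shows up when both operands of a CMSIS-NN offloaded qnn.add or qnn.mul are the same Relay expression. The partitioned composite function then carries only one parameter, while the previous lowering unconditionally created two input buffers and the constant-extraction pass pushed the shared argument into the signature twice. Below is a minimal sketch of the offending graph shape, modelled on the new test added in this patch; the shapes and quantization parameters are illustrative only.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import cmsisnn

    dtype = "int8"
    shape = (1, 16, 16, 3)
    inp = relay.var("input", shape=shape, dtype=dtype)
    scale = relay.const(0.256, "float32")
    zero_point = relay.const(33, "int32")

    # Both operands of the QNN binary op are the very same expression.
    out = relay.qnn.op.add(
        inp, inp,
        scale, zero_point,   # lhs scale / zero point
        scale, zero_point,   # rhs scale / zero point
        scale, zero_point,   # output scale / zero point
    )
    mod = tvm.IRModule.from_expr(relay.Function([inp], out))

    # Partitioning for CMSIS-NN previously produced mismatched signatures here.
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(mod)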

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 09831929e5..8d714b7269 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -223,7 +223,6 @@ def pattern_table():
     def check_qnn_max_pool2d(pattern):
         """Check if max pool2d is supported by CMSIS-NN."""
         output = pattern
-        input_op = None
 
         if str(pattern.op.name) == "clip":
             pooling = pattern.args[0]
diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
index 1cbe36e30f..c6ed7af9ff 100644
--- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc
+++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -164,7 +164,18 @@ class ExtractConstantsMutator : public MixedModeMutator {
           function_signature.push_back(arg);
         } else {
           if (arg.as<VarNode>()) {
-            function_signature.push_back(arg);
+            // Only push the argument if it is not already present: an input var with
+            // multiple consumers must appear only once in the function signature.
+            bool found_in_existing_signature = false;
+            for (auto& sign : function_signature) {
+              if (arg.same_as(sign)) {
+                found_in_existing_signature = true;
+                break;
+              }
+            }
+            if (!found_in_existing_signature) {
+              function_signature.push_back(arg);
+            }
           }
           new_args.push_back(arg);
         }
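
The loop added above is an ordered de-duplication over the extracted function's signature. In Python terms the same rule reads roughly as below; the helper name build_signature and the identity comparison are illustrative stand-ins for the C++ same_as check, not TVM APIs.

    def build_signature(call_args, is_var):
        """Collect each distinct var argument once, preserving order."""
        signature = []
        for arg in call_args:
            if is_var(arg) and not any(arg is seen for seen in signature):
                signature.append(arg)
        return signature

    shared = object()  # stands in for a relay.Var consumed twice
    assert build_signature([shared, shared], is_var=lambda a: True) == [shared]
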
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 524735caa9..5c99061fa8 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -556,7 +556,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     BufferCreator buffer_creator;
     tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
-    tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    tir::Var input_1;
+    if (mul_call->args[0].same_as(mul_call->args[1])) {
+      input_1 = input_0;
+    } else {
+      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    }
     tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
 
     tvm::Array<PrimExpr> args = {
@@ -626,7 +631,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     BufferCreator buffer_creator;
     tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
-    tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    tir::Var input_1;
+    if (add_call->args[0].same_as(add_call->args[1])) {
+      input_1 = input_0;
+    } else {
+      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    }
     tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
 
     tvm::Array<PrimExpr> args = {
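
With the changes above, the TIR lowering reuses the input_0 buffer var whenever the two call arguments alias, so the emitted CMSIS-NN call receives the same buffer twice instead of expecting a second, never-bound input. A rough Python analogue of that decision follows; pick_input_buffers and create_buffer are illustrative names, not TVM APIs.

    def pick_input_buffers(arg0, arg1, create_buffer):
        """Create a second input buffer only when the operands are distinct."""
        input_0 = create_buffer("input_0")
        input_1 = input_0 if arg0 is arg1 else create_buffer("input_1")
        return input_0, input_1

    operand = object()
    buf0, buf1 = pick_input_buffers(operand, operand, create_buffer=lambda name: name)
    assert buf0 is buf1  # the same buffer is passed for both operands
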
diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
index 2448bfc766..40fd773eb2 100644
--- a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
+++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
@@ -179,6 +179,12 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     auto new_body = VisitExpr(func->body);
     Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                                    FreeTypeVars(new_body, mod_), func->attrs);
+
+    // Rebuilding new_func's parameters may deduplicate repeated function parameters.
+    // The call arguments must be aligned with the number of parameters expected by new_func.
+    if (new_args[0].same_as(new_args[1])) {
+      new_args.erase(new_args.begin());
+    }
     return Call(new_func, new_args);
   }
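
Because WithFields rebuilds new_func over FreeVars(new_body), a var that is passed twice collapses into a single parameter; the erase above keeps the call's argument list in sync with that deduplicated parameter list. A small Python analogue of the realignment, with align_call_args as an illustrative name:

    def align_call_args(new_args):
        """Drop the duplicate when both call args are the same expression."""
        if len(new_args) == 2 and new_args[0] is new_args[1]:
            return new_args[1:]
        return new_args

    shared = object()
    assert len(align_call_args([shared, shared])) == 1
    assert len(align_call_args([object(), object()])) == 2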
 
diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
index fec18c197e..26604da0a6 100644
--- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py
+++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -101,7 +101,7 @@ def make_model(
 def test_op_int8(
     op, relu_type, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
 ):
-    """Tests QNN Conv2D operator for CMSIS-NN"""
+    """Tests QNN binary operator for CMSIS-NN"""
     interface_api = "c"
     use_unpacked_api = True
     test_runner = AOT_USMP_CORSTONE300_RUNNER
@@ -145,6 +145,65 @@ def test_op_int8(
     )
 
 
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
+@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
+def test_same_input_to_binary_op(op, relu_type):
+    """Tests QNN binary operator for CMSIS-NN where both inputs are the same"""
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_USMP_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    input_ = generate_variable("input")
+    input_scale = 0.256
+    input_zero_point = 33
+
+    model = make_model(
+        op,
+        input_,
+        input_,
+        input_scale,
+        input_zero_point,
+        input_scale,
+        input_zero_point,
+        relu_type,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+    # Check that the internal function has exactly one parameter
+    cmsisnn_global_func = cmsisnn_mod["tvmgen_default_cmsis_nn_main_0"]
+    assert (
+        isinstance(cmsisnn_global_func.body, tvm.relay.expr.Call)
+        and len(cmsisnn_global_func.body.args) == 1
+    ), "Composite function for the binary op should have only 1 parameter."
+
+    # validate the output
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    inputs = {
+        "input": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+    }
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
 def parameterize_for_constant_inputs(test):
     """Generates parameters in such a way so that at least one of the inputs is a constant,
     both can't be variables, both can't be scalars.
diff --git a/tests/python/contrib/test_cmsisnn/test_extract_constants.py b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
index 8831596d40..7d3e81a9c7 100644
--- a/tests/python/contrib/test_cmsisnn/test_extract_constants.py
+++ b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
@@ -116,6 +116,40 @@ def test_nested_function():
     relay.transform.InferType()(mod)
 
 
+@tvm.testing.requires_cmsisnn
+def test_internal_function_with_duplicate_arguments():
+    """Tests the pass ExternConstants when a composite function
+    is present within global function with repeating arguments
+    to one of the binary ops.
+    """
+    input0 = relay.var("input0", shape=(8, 8))
+    binary_op0 = input0 + input0
+    binary_op1 = binary_op0 * relay.const(5.0, "float32")
+    local_func = relay.Function([input0], binary_op1, relay.TensorType((8, 8), "float32"))
+    local_func = set_composite_func_attr(local_func, "cmsis-nn")
+
+    arg = relay.var("arg", shape=(8, 8))
+    call_local_func = relay.Call(local_func, [arg])
+    extern_func = relay.Function([arg], call_local_func, relay.TensorType((8, 8), "float32"))
+
+    global_arg = relay.var("global_var", shape=(8, 8))
+    global_var = relay.GlobalVar("external_function")
+    extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
+    call_extern_func = relay.Call(global_var, [global_arg])
+    main_func = relay.Function([global_arg], call_extern_func, relay.TensorType((8, 8), "float32"))
+    main_var = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[global_var] = extern_func
+    mod[main_var] = main_func
+
+    mod = ExtractConstantsFromPartitionedFunction()(mod)
+    constant_verifier = CheckFunctionsForConstants()
+    constant_verifier.visit_function(mod[global_var])
+    constant_verifier.check_num_constants()
+    relay.transform.InferType()(mod)
+
+
 @tvm.testing.requires_cmsisnn
 def test_multiple_functions():
     """Tests the pass ExternConstants when global function
diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
index 557a65aeff..df54f7ce55 100644
--- a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
+++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
@@ -256,6 +256,47 @@ def test_all_primary_operands_tensor_constants():
     assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
 
 
+@tvm.testing.requires_cmsisnn
+def test_duplicate_constant_arguments():
+    """Tests the pass when repeating operands are arguments to the binary op"""
+    dtype = "int8"
+    shape = (1, 3, 3, 32)
+    operand0 = generate_variable("operand0", shape, dtype)
+    operand1 = generate_variable("operand1", shape, dtype)
+    binary_op = make_binary_op(
+        relay.qnn.op.add,
+        operand0,
+        operand0,
+        input_0_scale=0.0128,
+        input_0_zero_point=32,
+        input_1_scale=0.256,
+        input_1_zero_point=-64,
+    )
+
+    local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype))
+    local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add")
+
+    rng = np.random.default_rng(12345)
+    arg0 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype))
+    call_local_func = relay.Call(local_func, [arg0, arg0])
+    extern_func = relay.Function([], call_local_func, relay.TensorType(shape, dtype))
+
+    global_var = relay.GlobalVar("external_function")
+    extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
+    call_extern_func = relay.Call(global_var, [])
+    main_func = relay.Function([], call_extern_func, relay.TensorType(shape, dtype))
+    main_var = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[global_var] = extern_func
+    mod[main_var] = main_func
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    new_mod = relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
+
+
 @tvm.testing.requires_cmsisnn
 def test_non_cmsisnn_ext_func():
     """Non CMSISNN functions should not be altered."""