Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/31 23:46:41 UTC

[GitHub] [tvm] Mousius commented on a change in pull request #10100: [CMSIS-NN] Convert scalar constants to tensor constants

Mousius commented on a change in pull request #10100:
URL: https://github.com/apache/tvm/pull/10100#discussion_r796156636



##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -62,17 +62,87 @@ class ExtractConstantsMutator : public MixedModeMutator {
       return func;
     }
 
-    function_to_constants_.Set(func, Array<Constant>{});
+    function_to_arguments_.Set(func, Array<Expr>{});
     functions_.push_back(func);
     auto new_body = VisitExpr(func->body);
     functions_.pop_back();
-    if (function_to_constants_[func].size()) {
+    if (function_to_arguments_[func].size()) {
       func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                         FreeTypeVars(new_body, mod_), func->attrs);
     }
     return std::move(func);
   }
 
+  // Creates new arguments from current call's arguments
+  // Updates constants into the caller arguments: here caller signifies caller that comprises call
+  // to func
+  Array<Expr> CreateNewCallArgsFromKickedoutConstants(Call call, Function func) {
+    ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
+    Array<Expr> fSignature(function_to_arguments_[func]);
+
+    // Is func a global_function?
+    // main() is not registered for kicking out constants

Review comment:
    ```suggestion
    // main() is not registered for extracting constants
    ```

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -62,17 +62,87 @@ class ExtractConstantsMutator : public MixedModeMutator {
       return func;
     }
 
-    function_to_constants_.Set(func, Array<Constant>{});
+    function_to_arguments_.Set(func, Array<Expr>{});
     functions_.push_back(func);
     auto new_body = VisitExpr(func->body);
     functions_.pop_back();
-    if (function_to_constants_[func].size()) {
+    if (function_to_arguments_[func].size()) {
       func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                         FreeTypeVars(new_body, mod_), func->attrs);
     }
     return std::move(func);
   }
 
+  // Creates new arguments from current call's arguments
+  // Updates constants into the caller arguments: here caller signifies caller that comprises call
+  // to func
+  Array<Expr> CreateNewCallArgsFromKickedoutConstants(Call call, Function func) {
+    ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
+    Array<Expr> fSignature(function_to_arguments_[func]);
+
+    // Is func a global_function?
+    // main() is not registered for kicking out constants
+    bool is_global_function = functions_.empty() ? true : false;
+
+    bool new_constants_added = false;
+    // This tracks arguments traversed inside fSignature
+    uint32_t fsignature_id = 0;

Review comment:
    ```suggestion
    uint32_t function_signature_id = 0;
    ```

##########
File path: tests/python/contrib/test_cmsisnn/test_binary_ops.py
##########
@@ -46,12 +54,26 @@ def make_model(
     input_1_zero_point,
     out_scale=1.0 / 256,
     out_zero_point=-128,
+    input_0_type=BinaryOpInputType.Variable,
+    input_1_type=BinaryOpInputType.Variable,
 ):
     """Create a Relay Function / network model"""
 
+    def create_input(name, input_type, shape, dtype):

Review comment:
       Does this have to be nested inside `make_model`? We should also consider whether we want this many branches in our test functions; you could parameterise with input-creation functions to flatten out the logic.
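
       For illustration, a minimal sketch of that flattening; the factory names, the fill values, and the `input_0_create`/`input_1_create` parameters are all hypothetical, not taken from this PR:

    ```python
    import numpy as np
    from tvm import relay

    # Hypothetical top-level factories, one per input kind; names and
    # fill values are illustrative only.
    def variable_input(name, shape, dtype):
        return relay.var(name, shape=shape, dtype=dtype)

    def tensor_constant_input(name, shape, dtype):
        # `name` is unused for constants; kept for a uniform signature.
        return relay.const(np.ones(shape, dtype=dtype), dtype)

    def scalar_constant_input(name, shape, dtype):
        # A 0-d numpy array produces a scalar Relay constant.
        return relay.const(np.array(1, dtype=dtype), dtype)

    # make_model could then accept the factories directly, e.g.
    #   make_model(op, shape, dtype, dtype, ...,
    #              input_0_create=variable_input,
    #              input_1_create=scalar_constant_input)
    # and simply call input_0_create("input_0", shape, dtype), with no branching.
    ```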

##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file scalar_to_tensor_constant.cc
+ * \brief Converts scalar constant into tensor constant for binary ops of CMSIS-NN
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator finds all partitioned functions meant for CMSIS-NN binary ops.
+ * Then, it substitutes the scalar constants with tensor constants. It makes the shape of this
+ * new constant same as that of the neighbouring constant of the other binary operand. The
+ * expectation is that the ExtractConstant pass would later kick this tensor constant out of the
+ * global partitioned function, thus making the entire global partitioned and its composite function
+ * constant free. This makes the TIR generation for binary ops via CMSIS-NN independent of
+ * constants.
+ */
+class ScalarToTensorConstantMutator : public MixedModeMutator {
+ public:
+  explicit ScalarToTensorConstantMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  using MixedModeMutator::VisitExpr_;
+
+  // Here is an example with the annotated scalar constant:
+  // def @tvmgen_default_cmsis_nn_main_1(%cmsis_nn_input: Tensor[], Inline=1, Compiler="cmsis-nn",
+  //                                     global_symbol="tvmgen_default_cmsis_nn_main",
+  //                                     Primitive=1) -> Tensor[] {
+  //   %56 = fn (%input0: _scalar_constant_, %input1: Tensor[],
+  //             PartitionedFromPattern="qnn.mul_", Composite="cmsis-nn.qnn_mul") -> Tensor[] {
+  //     qnn.mul(%input0, %input1, scale0, zero_point0,
+  //              scale1, zero_point_1, output_scale, output_zero_point)
+  //   };
+  //   %56(meta[relay.Constant] /* _scalar constant_ */, %cmsis-nn_input)
+  // }
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    call = post.as<CallNode>();
+
+    // Create a new variable argument that is of the same shape as the neibouring argument

Review comment:
    ```suggestion
    // Create a new variable argument that is of the same shape as the neighbouring argument
    ```

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -62,17 +62,87 @@ class ExtractConstantsMutator : public MixedModeMutator {
       return func;
     }
 
-    function_to_constants_.Set(func, Array<Constant>{});
+    function_to_arguments_.Set(func, Array<Expr>{});
     functions_.push_back(func);
     auto new_body = VisitExpr(func->body);
     functions_.pop_back();
-    if (function_to_constants_[func].size()) {
+    if (function_to_arguments_[func].size()) {
       func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                         FreeTypeVars(new_body, mod_), func->attrs);
     }
     return std::move(func);
   }
 
+  // Creates new arguments from current call's arguments
+  // Updates constants into the caller arguments: here caller signifies caller that comprises call
+  // to func
+  Array<Expr> CreateNewCallArgsFromKickedoutConstants(Call call, Function func) {
+    ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
+    Array<Expr> fSignature(function_to_arguments_[func]);

Review comment:
    ```suggestion
    Array<Expr> function_signature(function_to_arguments_[func]);
    ```

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -62,17 +62,87 @@ class ExtractConstantsMutator : public MixedModeMutator {
       return func;
     }
 
-    function_to_constants_.Set(func, Array<Constant>{});
+    function_to_arguments_.Set(func, Array<Expr>{});
     functions_.push_back(func);
     auto new_body = VisitExpr(func->body);
     functions_.pop_back();
-    if (function_to_constants_[func].size()) {
+    if (function_to_arguments_[func].size()) {
       func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                         FreeTypeVars(new_body, mod_), func->attrs);
     }
     return std::move(func);
   }
 
+  // Creates new arguments from current call's arguments
+  // Updates constants into the caller arguments: here caller signifies caller that comprises call
+  // to func
+  Array<Expr> CreateNewCallArgsFromKickedoutConstants(Call call, Function func) {
+    ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
+    Array<Expr> fSignature(function_to_arguments_[func]);
+
+    // Is func a global_function?
+    // main() is not registered for kicking out constants
+    bool is_global_function = functions_.empty() ? true : false;
+
+    bool new_constants_added = false;
+    // This tracks arguments traversed inside fSignature
+    uint32_t fsignature_id = 0;
+    // This contains arguments including constants for the caller of this function inside which
+    // post_call resides.
+    Array<Expr> new_caller_args;
+    // New arguments to post_call that includes new variables representing constants kicked out of
+    // the function
+    Array<Expr> new_call_args;
+    for (auto& arg : call->args) {
+      if (auto* constant = arg.as<ConstantNode>()) {
+        new_caller_args.push_back(arg);
+        new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+        ++fsignature_id;
+        new_constants_added = true;
+        continue;
+      }
+
+      // Push all constants from the fSignature until a variable corresponding to current argument
+      // is hit
+      while (fsignature_id < fSignature.size()) {
+        auto* constant = fSignature[fsignature_id].as<ConstantNode>();
+        if (constant == nullptr) {
+          break;
+        }
+        new_caller_args.push_back(fSignature[fsignature_id++]);
+        new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+        new_constants_added = true;
+      }
+
+      new_call_args.push_back(arg);
+      if (is_global_function || arg.as<VarNode>()) {
+        new_caller_args.push_back(arg);
+      }
+      ++fsignature_id;
+    }
+
+    // Push remaining constants as new arguments
+    for (uint32_t i = fsignature_id; i < fSignature.size(); ++i) {
+      auto* constant = fSignature[i].as<ConstantNode>();
+      ICHECK(constant)
+          << "Rest of the collected arguments should be constant in the partitioned function.";
+      new_caller_args.push_back(GetRef<Constant>(constant));
+      new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+      new_constants_added = true;
+    }
+
+    // Update the arguments of caller of local function
+    if (new_constants_added && !is_global_function) {
+      const Function& last_func = functions_.back();
+      Array<Expr> fconstants(function_to_arguments_[last_func]);

Review comment:
    ```suggestion
    Array<Expr> function_constants(function_to_arguments_[last_func]);
    ```

##########
File path: tests/python/contrib/test_cmsisnn/test_binary_ops.py
##########
@@ -131,6 +153,140 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
     )
 
 
+# At least one of the inputs is a constant
+def parameterize_for_constant_inputs(test):
+    op = [relay.qnn.op.mul, relay.qnn.op.add]
+    input_0_type = [
+        BinaryOpInputType.Variable,
+        BinaryOpInputType.TensorConstant,
+        BinaryOpInputType.ScalarConstant,
+    ]
+    input_1_type = [
+        BinaryOpInputType.Variable,
+        BinaryOpInputType.TensorConstant,
+        BinaryOpInputType.ScalarConstant,
+    ]
+    all_combinations = itertools.product(op, input_0_type, input_1_type)
+    all_combinations = filter(
+        lambda parameters: not (
+            (
+                parameters[1] == BinaryOpInputType.Variable
+                and parameters[2] == BinaryOpInputType.Variable
+            )
+            or (
+                parameters[1] == BinaryOpInputType.ScalarConstant
+                and parameters[2] == BinaryOpInputType.ScalarConstant
+            )
+        ),
+        all_combinations,
+    )
+    return pytest.mark.parametrize(
+        ["op", "input_0_type", "input_1_type"],
+        all_combinations,
+    )(test)
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@parameterize_for_constant_inputs
+def test_constant_input_int8(op, input_0_type, input_1_type):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    input_0_scale = 0.256
+    input_0_zero_point = 33
+    input_1_scale = 0.128
+    input_1_zero_point = -24
+    model = make_model(
+        op,
+        shape,
+        dtype,
+        dtype,
+        input_0_scale,
+        input_0_zero_point,
+        input_1_scale,
+        input_1_zero_point,
+        input_0_type=input_0_type,
+        input_1_type=input_1_type,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain function for cmsisnn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    inputs = {}
+    if input_0_type == BinaryOpInputType.Variable:
+        inputs.update({"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)})
+    if input_1_type == BinaryOpInputType.Variable:
+        inputs.update({"input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)})
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
+def test_both_scalar_inputs_int8(
+    op,
+):
+    input_scale = 0.256
+    input_zero_point = 33
+    dtype = "int8"
+    model = make_model(
+        op,
+        [1, 16, 16, 3],
+        dtype,
+        dtype,
+        input_scale,
+        input_zero_point,
+        input_scale,
+        input_zero_point,
+        input_0_type=BinaryOpInputType.ScalarConstant,
+        input_1_type=BinaryOpInputType.ScalarConstant,
+    )
+
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert not any(attrs), "No function should have an external attribute."

Review comment:
       Similar to the above, can we put this in a function for re-use?
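
       A sketch of what that could look like, reusing the assertion above verbatim; the helper name `assert_no_external_function` is a placeholder:

    ```python
    def assert_no_external_function(cmsisnn_mod):
        """Checks that partitioning produced no externally compiled functions."""
        attrs = [
            cmsisnn_mod[var.name_hint].attrs
            for var in cmsisnn_mod.get_global_vars()
            if cmsisnn_mod[var.name_hint].attrs
        ]
        assert not any(attrs), "No function should have an external attribute."
    ```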

##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file scalar_to_tensor_constant.cc
+ * \brief Converts scalar constant into tensor constant for binary ops of CMSIS-NN
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator finds all partitioned functions meant for CMSIS-NN binary ops.
+ * Then, it substitutes the scalar constants with tensor constants. It makes the shape of this
+ * new constant same as that of the neighbouring constant of the other binary operand. The
+ * expectation is that the ExtractConstant pass would later kick this tensor constant out of the
+ * global partitioned function, thus making the entire global partitioned and its composite function
+ * constant free. This makes the TIR generation for binary ops via CMSIS-NN independent of
+ * constants.
+ */
+class ScalarToTensorConstantMutator : public MixedModeMutator {
+ public:
+  explicit ScalarToTensorConstantMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  using MixedModeMutator::VisitExpr_;
+
+  // Here is an example with the annotated scalar constant:
+  // def @tvmgen_default_cmsis_nn_main_1(%cmsis_nn_input: Tensor[], Inline=1, Compiler="cmsis-nn",
+  //                                     global_symbol="tvmgen_default_cmsis_nn_main",
+  //                                     Primitive=1) -> Tensor[] {
+  //   %56 = fn (%input0: _scalar_constant_, %input1: Tensor[],
+  //             PartitionedFromPattern="qnn.mul_", Composite="cmsis-nn.qnn_mul") -> Tensor[] {
+  //     qnn.mul(%input0, %input1, scale0, zero_point0,
+  //              scale1, zero_point_1, output_scale, output_zero_point)
+  //   };
+  //   %56(meta[relay.Constant] /* _scalar constant_ */, %cmsis-nn_input)
+  // }
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    call = post.as<CallNode>();
+
+    // Create a new variable argument that is of the same shape as the neibouring argument
+    // in the binary op. This needs to be done only when one of the arguments is a scalar.
+    if (auto* opnode = call->op.as<OpNode>()) {
+      String op_name = opnode->name;
+      Array<Expr> new_args;
+      for (uint32_t i = 0; i < call->args.size(); ++i) {
+        Expr arg = call->args[i];
+        new_args.push_back(arg);
+        if (!arg->checked_type_.defined()) {
+          continue;
+        }
+        auto* arg_type = arg->type_as<TensorTypeNode>();
+        if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
+          continue;
+        }
+        String arg_name = arg.as<VarNode>()->name_hint();
+        int tensor_arg_id = (i + 1) % 2;
+        Expr tensor_arg = call->args[tensor_arg_id];
+        if (!tensor_arg->checked_type_.defined()) {
+          continue;
+        }
+        TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+        new_args.Set(i, Var(arg_name, tensor_type));
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});

Review comment:
       Can these larger blocks inside the if-statements be extracted into functions for clarity?
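
       As a rough sketch, the first block could move into a private helper on the mutator; the name `ReplaceScalarVarWithTensorVar` is invented here, and the body simply mirrors the loop above:

    ```cpp
    // Hypothetical helper extracted from Rewrite_: for a binary op call,
    // replace each scalar variable argument with a variable whose type
    // matches the neighbouring tensor argument.
    Expr ReplaceScalarVarWithTensorVar(const CallNode* call) {
      Array<Expr> new_args;
      for (uint32_t i = 0; i < call->args.size(); ++i) {
        Expr arg = call->args[i];
        new_args.push_back(arg);
        if (!arg->checked_type_.defined()) {
          continue;
        }
        auto* arg_type = arg->type_as<TensorTypeNode>();
        if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
          continue;
        }
        String arg_name = arg.as<VarNode>()->name_hint();
        Expr tensor_arg = call->args[(i + 1) % 2];
        if (!tensor_arg->checked_type_.defined()) {
          continue;
        }
        TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
        new_args.Set(i, Var(arg_name, tensor_type));
      }
      return Call(call->op, new_args, call->attrs, {});
    }
    ```

       `Rewrite_` would then reduce to `final_call = ReplaceScalarVarWithTensorVar(call);` in that branch.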

##########
File path: tests/python/contrib/test_cmsisnn/test_binary_ops.py
##########
@@ -131,6 +153,140 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
     )
 
 
+# At least one of the inputs is a constant
+def parameterize_for_constant_inputs(test):
+    op = [relay.qnn.op.mul, relay.qnn.op.add]
+    input_0_type = [
+        BinaryOpInputType.Variable,
+        BinaryOpInputType.TensorConstant,
+        BinaryOpInputType.ScalarConstant,
+    ]
+    input_1_type = [
+        BinaryOpInputType.Variable,
+        BinaryOpInputType.TensorConstant,
+        BinaryOpInputType.ScalarConstant,
+    ]
+    all_combinations = itertools.product(op, input_0_type, input_1_type)
+    all_combinations = filter(
+        lambda parameters: not (
+            (
+                parameters[1] == BinaryOpInputType.Variable
+                and parameters[2] == BinaryOpInputType.Variable
+            )
+            or (
+                parameters[1] == BinaryOpInputType.ScalarConstant
+                and parameters[2] == BinaryOpInputType.ScalarConstant
+            )
+        ),
+        all_combinations,
+    )
+    return pytest.mark.parametrize(
+        ["op", "input_0_type", "input_1_type"],
+        all_combinations,
+    )(test)
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@parameterize_for_constant_inputs
+def test_constant_input_int8(op, input_0_type, input_1_type):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    input_0_scale = 0.256
+    input_0_zero_point = 33
+    input_1_scale = 0.128
+    input_1_zero_point = -24
+    model = make_model(
+        op,
+        shape,
+        dtype,
+        dtype,
+        input_0_scale,
+        input_0_zero_point,
+        input_1_scale,
+        input_1_zero_point,
+        input_0_type=input_0_type,
+        input_1_type=input_1_type,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain function for cmsisnn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"

Review comment:
       We're re-using this across tests; should we store it in an `assert_partitioned` or similar?
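
       A sketch lifted directly from the assertions above; `assert_partitioned_function` is a placeholder name:

    ```python
    def assert_partitioned_function(orig_mod, cmsisnn_mod):
        """Checks for a cmsis-nn external function and an unchanged call count."""
        attrs = [
            cmsisnn_mod[var.name_hint].attrs
            for var in cmsisnn_mod.get_global_vars()
            if cmsisnn_mod[var.name_hint].attrs
        ]
        assert any(attrs), "At least one function with external attributes was expected."
        compilers = [
            key == "Compiler" and value == "cmsis-nn"
            for attr in attrs
            for key, value in attr.items()
        ]
        assert any(compilers), "Module does not contain function for cmsisnn target."
        assert count_num_calls(orig_mod) == count_num_calls(
            cmsisnn_mod
        ), "Number of calls changed during partitioning"
    ```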

##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file scalar_to_tensor_constant.cc
+ * \brief Converts scalar constant into tensor constant for binary ops of CMSIS-NN
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator finds all partitioned functions meant for CMSIS-NN binary ops.
+ * Then, it substitutes the scalar constants with tensor constants. It makes the shape of this
+ * new constant same as that of the neighbouring constant of the other binary operand. The
+ * expectation is that the ExtractConstant pass would later kick this tensor constant out of the

Review comment:
    ```suggestion
     * expectation is that the ExtractConstant pass would later extract the tensor constant out of the
    ```

##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file scalar_to_tensor_constant.cc
+ * \brief Converts scalar constant into tensor constant for binary ops of CMSIS-NN
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator finds all partitioned functions meant for CMSIS-NN binary ops.
+ * Then, it substitutes the scalar constants with tensor constants. It makes the shape of this
+ * new constant same as that of the neighbouring constant of the other binary operand. The
+ * expectation is that the ExtractConstant pass would later kick this tensor constant out of the
+ * global partitioned function, thus making the entire global partitioned and its composite function
+ * constant free. This makes the TIR generation for binary ops via CMSIS-NN independent of
+ * constants.
+ */
+class ScalarToTensorConstantMutator : public MixedModeMutator {
+ public:
+  explicit ScalarToTensorConstantMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  using MixedModeMutator::VisitExpr_;
+
+  // Here is an example with the annotated scalar constant:
+  // def @tvmgen_default_cmsis_nn_main_1(%cmsis_nn_input: Tensor[], Inline=1, Compiler="cmsis-nn",
+  //                                     global_symbol="tvmgen_default_cmsis_nn_main",
+  //                                     Primitive=1) -> Tensor[] {
+  //   %56 = fn (%input0: _scalar_constant_, %input1: Tensor[],
+  //             PartitionedFromPattern="qnn.mul_", Composite="cmsis-nn.qnn_mul") -> Tensor[] {
+  //     qnn.mul(%input0, %input1, scale0, zero_point0,
+  //              scale1, zero_point_1, output_scale, output_zero_point)
+  //   };
+  //   %56(meta[relay.Constant] /* _scalar constant_ */, %cmsis-nn_input)
+  // }
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    call = post.as<CallNode>();
+
+    // Create a new variable argument that is of the same shape as the neibouring argument
+    // in the binary op. This needs to be done only when one of the arguments is a scalar.
+    if (auto* opnode = call->op.as<OpNode>()) {
+      String op_name = opnode->name;
+      Array<Expr> new_args;
+      for (uint32_t i = 0; i < call->args.size(); ++i) {
+        Expr arg = call->args[i];
+        new_args.push_back(arg);
+        if (!arg->checked_type_.defined()) {
+          continue;
+        }
+        auto* arg_type = arg->type_as<TensorTypeNode>();
+        if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
+          continue;
+        }
+        String arg_name = arg.as<VarNode>()->name_hint();
+        int tensor_arg_id = (i + 1) % 2;
+        Expr tensor_arg = call->args[tensor_arg_id];
+        if (!tensor_arg->checked_type_.defined()) {
+          continue;
+        }
+        TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+        new_args.Set(i, Var(arg_name, tensor_type));
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});
+    }
+
+    if (auto* glob_var_node = call->op.as<GlobalVarNode>()) {
+      GlobalVar global_var = GetRef<GlobalVar>(glob_var_node);
+      Function func = Downcast<Function>(mod_->Lookup(global_var));
+      auto compiler_name = func->GetAttr<String>(::tvm::relay::attr::kCompiler);
+      if (!compiler_name.defined() || compiler_name != "cmsis-nn") {
+        return final_call;
+      }
+      auto new_body = VisitExpr(func->body);
+      if (new_body.same_as(func->body)) {
+        return final_call;
+      }
+      Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                                   FreeTypeVars(new_body, mod_), func->attrs);
+      mod_->Update(global_var, new_func);
+      final_call = Call(global_var, call->args);
+    }
+
+    // Substitute scalar constant with a tensor constant in the call to composite function
+    // comprising partitioned binary ops. Shape of the new constant should be same as its
+    // neighbouring tensor's shape.
+    if (auto* func_node = call->op.as<FunctionNode>()) {
+      Function func = GetRef<Function>(func_node);
+      auto func_name = func->GetAttr<String>(attr::kComposite);
+      if (!func_name.defined() ||
+          (func_name != "cmsis-nn.qnn_add" && func_name != "cmsis-nn.qnn_mul")) {
+        return final_call;
+      }
+      Array<Expr> new_args;
+      for (uint32_t i = 0; i < call->args.size(); ++i) {
+        Expr scalar_arg = call->args[i];
+        Array<PrimExpr> scalar_shape = scalar_arg->type_as<TensorTypeNode>()->shape;
+        if (scalar_shape.size() == 0 && scalar_arg.as<ConstantNode>()) {
+          int tensor_arg_id = (i + 1) % 2;
+          Expr tensor_arg = call->args[tensor_arg_id];
+          Constant tensor_constant = TensorConstantFromScalar(scalar_arg, tensor_arg);
+          new_args.push_back(tensor_constant);
+        } else {
+          new_args.push_back(call->args[i]);
+        }
+      }
+      auto new_body = VisitExpr(func->body);
+      Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                                   FreeTypeVars(new_body, mod_), func->attrs);
+      final_call = Call(new_func, new_args);
+    }
+
+    return final_call;
+  }
+
+  // Makes tensor constant of same shape as tensor_arg with values from scalar_arg
+  Constant TensorConstantFromScalar(Expr scalar_arg, Expr tensor_arg) {
+    int8_t scalar_value = GetScalarFromConstant<int8_t>(scalar_arg);
+    TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+    std::vector<int64_t> tensor_shape;
+    for (auto& dim : tensor_type->shape) {
+      tensor_shape.push_back(qnn::get_const_int(dim));
+    }
+    int tensor_num_elements = qnn::get_const_int(tensor_type->Size());
+    std::vector<int8_t> tensor_values(tensor_num_elements, scalar_value);
+    return MakeConstantTensor<int8_t>(DataType::Int(8), tensor_shape, tensor_values);
+  }
+
+ private:
+  IRModule mod_;
+};
+
+IRModule ScalarToTensorConstant(const IRModule& mod) {
+  auto mutator = ScalarToTensorConstantMutator(mod);
+  Function main_func = Downcast<Function>(mod->Lookup("main"));
+  auto new_main_body = mutator.VisitExpr(main_func->body);
+  if (!new_main_body.same_as(main_func->body)) {
+    auto main_var = mod->GetGlobalVar("main");
+    auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,

Review comment:
       Should this use `WithFields`?
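
       For reference, a sketch of how that could read, assuming the same `WithFields` overload already used in extract_constants.cc above (field order and defaults as in that call, not verified against this exact line):

    ```cpp
    // Sketch only: WithFields copies main_func and overrides just the body,
    // carrying over ret_type, type params and attrs unchanged.
    Function new_main_func = WithFields(main_func, main_func->params, new_main_body);
    ```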




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org