Posted to commits@tvm.apache.org by mo...@apache.org on 2022/02/02 17:06:19 UTC

[tvm] branch main updated: [CMSIS-NN] Convert scalar constants to tensor constants (#10100)

This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cb3d7e2  [CMSIS-NN] Convert scalar constants to tensor constants (#10100)
cb3d7e2 is described below

commit cb3d7e2271d0decc633492ca9ad7440b2dd3d5db
Author: Ashutosh Parkhi <86...@users.noreply.github.com>
AuthorDate: Wed Feb 2 17:05:32 2022 +0000

    [CMSIS-NN] Convert scalar constants to tensor constants (#10100)
---
 python/tvm/relay/op/contrib/cmsisnn.py             |  19 +-
 .../backend/contrib/cmsisnn/extract_constants.cc   | 125 +++++++++----
 .../contrib/cmsisnn/scalar_to_tensor_constant.cc   | 204 +++++++++++++++++++++
 .../python/contrib/test_cmsisnn/test_binary_ops.py | 163 ++++++++++++++--
 .../contrib/test_cmsisnn/test_extract_constants.py |  67 ++-----
 .../test_cmsisnn/test_scalar_to_tensor_constant.py | 187 +++++++++++++++++++
 6 files changed, 671 insertions(+), 94 deletions(-)
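
For orientation, this change lets a QNN binary op with a scalar constant operand be
offloaded to CMSIS-NN: a new ScalarToTensorConstants pass broadcasts the scalar into a
tensor constant of the other operand's shape before constants are extracted. A hedged
Python sketch of the user-visible flow (all values illustrative):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import cmsisnn

    x = relay.var("x", shape=(1, 16, 16, 3), dtype="int8")
    scalar = relay.const(-30, "int8")  # scalar constant operand
    call = relay.qnn.op.mul(
        x, scalar,
        relay.const(0.0128, "float32"), relay.const(32, "int32"),    # input_0 scale/zero point
        relay.const(0.0256, "float32"), relay.const(-64, "int32"),   # input_1 scale/zero point
        relay.const(0.0128, "float32"), relay.const(-128, "int32"),  # output scale/zero point
    )
    mod = tvm.IRModule.from_expr(relay.Function([x], call))
    # Inside partition_for_cmsisnn, ScalarToTensorConstants turns the scalar into a
    # (1, 16, 16, 3) int8 tensor constant, which ExtractConstants then hoists to main().
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(mod)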

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 7af47c3..e7bbfb6 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -57,6 +57,7 @@ def partition_for_cmsisnn(mod, params=None, **opts):
             transform.AnnotateTarget("cmsis-nn"),
             transform.PartitionGraph(),
             GenerateCMSISNNConstants(),
+            ScalarToTensorConstants(),
             ExtractConstantsFromPartitionedFunction(),
             transform.InferType(),
         ]
@@ -223,11 +224,23 @@ def pattern_table():
             is_constant(),
         )
 
-    def check_qnn_binary_op(extract):
+    def check_qnn_binary_op(pattern):
         """Check if multiply is supported by CMSIS-NN."""
+        arg0 = pattern.args[0]
+        arg1 = pattern.args[1]
+        both_args_scalar = False
+        if (
+            isinstance(arg0, tvm.relay.expr.Constant)
+            and len(arg0.checked_type.shape) == 0
+            and isinstance(arg1, tvm.relay.expr.Constant)
+            and len(arg1.checked_type.shape) == 0
+        ):
+            both_args_scalar = True
+
         return (
-            extract.args[0].checked_type.dtype == "int8"
-            and extract.args[1].checked_type.dtype == "int8"
+            arg0.checked_type.dtype == "int8"
+            and arg1.checked_type.dtype == "int8"
+            and not both_args_scalar
         )
 
     return [
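
The guard above rejects calls where both operands are rank-0 constants: with no tensor
operand available, ScalarToTensorConstants has no shape to broadcast to, so such calls are
left out of the partition (test_both_scalar_inputs_int8 below exercises this). A standalone
sketch of the same check, with an illustrative helper name:

    import tvm

    def is_scalar_constant(arg):
        """True when arg is a Relay constant of rank 0, i.e. a scalar."""
        return isinstance(arg, tvm.relay.expr.Constant) and len(arg.checked_type.shape) == 0

    # For a type-inferred binary call:
    # supported = not (is_scalar_constant(call.args[0]) and is_scalar_constant(call.args[1]))
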
diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
index 9b72403..1cbe36e 100644
--- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc
+++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -62,17 +62,88 @@ class ExtractConstantsMutator : public MixedModeMutator {
       return func;
     }
 
-    function_to_constants_.Set(func, Array<Constant>{});
+    function_to_arguments_.Set(func, Array<Expr>{});
     functions_.push_back(func);
     auto new_body = VisitExpr(func->body);
     functions_.pop_back();
-    if (function_to_constants_[func].size()) {
+    if (function_to_arguments_[func].size()) {
       func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                         FreeTypeVars(new_body, mod_), func->attrs);
     }
     return std::move(func);
   }
 
+  // Creates new arguments from the current call's arguments.
+  // Pushes constants into the arguments of the caller, i.e. the function that contains the call
+  // to func.
+  Array<Expr> CreateNewCallArgsFromExtractedConstants(Call call, Function func) {
+    ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
+    Array<Expr> function_signature(function_to_arguments_[func]);
+
+    // Is func a global function?
+    // main() is not registered for constant extraction, so an empty stack means global scope.
+    bool is_global_function = functions_.empty();
+
+    bool new_constants_added = false;
+    // This tracks arguments traversed inside function_signature
+    uint32_t function_signature_id = 0;
+    // Arguments, constants included, for the caller of this function, i.e. the function inside
+    // which post_call resides.
+    Array<Expr> new_caller_args;
+    // New arguments to post_call that include new variables representing constants extracted from
+    // the function
+    Array<Expr> new_call_args;
+    for (auto& arg : call->args) {
+      if (auto* constant = arg.as<ConstantNode>()) {
+        new_caller_args.push_back(arg);
+        new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+        ++function_signature_id;
+        new_constants_added = true;
+        continue;
+      }
+
+      // Push all constants from the function_signature until a variable corresponding to the
+      // current argument is hit
+      while (function_signature_id < function_signature.size()) {
+        auto* constant = function_signature[function_signature_id].as<ConstantNode>();
+        if (constant == nullptr) {
+          break;
+        }
+        new_caller_args.push_back(function_signature[function_signature_id++]);
+        new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+        new_constants_added = true;
+      }
+
+      new_call_args.push_back(arg);
+      if (is_global_function || arg.as<VarNode>()) {
+        new_caller_args.push_back(arg);
+      }
+      ++function_signature_id;
+    }
+
+    // Push remaining constants as new arguments
+    for (uint32_t i = function_signature_id; i < function_signature.size(); ++i) {
+      auto* constant = function_signature[i].as<ConstantNode>();
+      ICHECK(constant)
+          << "The remaining collected arguments should be constants in the partitioned function.";
+      new_caller_args.push_back(GetRef<Constant>(constant));
+      new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+      new_constants_added = true;
+    }
+
+    // Update the arguments of the caller of this local function
+    if (new_constants_added && !is_global_function) {
+      const Function& last_func = functions_.back();
+      Array<Expr> function_constants(function_to_arguments_[last_func]);
+      function_to_arguments_.Set(last_func,
+                                 tvm::runtime::Concat(function_constants, new_caller_args));
+    } else {
+      new_call_args = new_caller_args;
+    }
+
+    return new_call_args;
+  }
+
   Expr Rewrite_(const CallNode* call, const Expr& post) final {
     Expr final_call = post;
     auto* post_call = post.as<CallNode>();
@@ -81,23 +152,28 @@ class ExtractConstantsMutator : public MixedModeMutator {
     // Perform this for non-main Call Nodes only
     if (!functions_.empty() && call->op.as<OpNode>()) {
       Array<Expr> new_args;
+      const Function& last_func = functions_.back();
+      Array<Expr> function_signature(function_to_arguments_[last_func]);
       for (auto& arg : post_call->args) {
+        // Push all arguments, constants included, to maintain the correct order of
+        // variables and constants
         auto* const_arg = arg.as<ConstantNode>();
         if (const_arg && !const_arg->is_scalar()) {
           Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
           new_args.push_back(var_arg);
-          const Function& last_func = functions_.back();
-          Array<Constant> fconstants(function_to_constants_[last_func]);
-          fconstants.push_back(GetRef<Constant>(const_arg));
-          function_to_constants_.Set(last_func, fconstants);
+          function_signature.push_back(arg);
         } else {
+          if (arg.as<VarNode>()) {
+            function_signature.push_back(arg);
+          }
           new_args.push_back(arg);
         }
       }
+      function_to_arguments_.Set(last_func, function_signature);
       final_call = Call(call->op, new_args, call->attrs, {});
     }
 
-    // Since the constants are kicked out of partitioned functions
+    // Since the constants are extracted from partitioned functions
     // a new call to global function is needed
     if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
       auto glob_var = GetRef<GlobalVar>(glob_var_node);
@@ -105,34 +181,18 @@ class ExtractConstantsMutator : public MixedModeMutator {
       auto new_glob_func = VisitExpr(glob_func);
       if (!new_glob_func.same_as(glob_func)) {
         mod_->Update(glob_var, Downcast<Function>(new_glob_func));
-        Array<Expr> new_args = post_call->args;
-        ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
-        for (auto constant : function_to_constants_.at(glob_func)) {
-          new_args.push_back(constant);
-        }
+        auto new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), glob_func);
         final_call = Call(glob_var, new_args);
       }
     }
 
-    // Since the constants are kicked out of the local partitioned functions
+    // Since the constants are extracted from the local partitioned functions
     // a new call to local function is needed
-    // Also, pass on the constants to the callee of this function to support nested functions
     if (auto* func_node = call->op.as<FunctionNode>()) {
       Function func = GetRef<Function>(func_node);
       auto new_func = VisitExpr(func);
-      if (!new_func.same_as(func)) {
-        Array<Expr> new_args = post_call->args;
-        ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
-        const Function& last_func = functions_.back();
-        Array<Constant> fconstants(function_to_constants_[last_func]);
-        for (auto constant : function_to_constants_.at(func)) {
-          fconstants.push_back(constant);
-          Var var_arg = Var(gen_var_name(), constant->tensor_type());
-          new_args.push_back(var_arg);
-        }
-        function_to_constants_.Set(last_func, fconstants);
-        final_call = Call(new_func, new_args);
-      }
+      Array<Expr> new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), func);
+      final_call = Call(new_func, new_args);
     }
 
     return final_call;
@@ -141,15 +201,16 @@ class ExtractConstantsMutator : public MixedModeMutator {
  private:
   /* \brief Updated module where all calls have replaced constants with new variables */
   IRModule mod_;
-  /* \brief Maintains mapping of original function to the replaced constants */
-  Map<Function, Array<Constant>> function_to_constants_;
-  /* \brief Stack of functions to determine scope while filling up function_to_constants_ */
+  /* \brief Maps each original function to its replaced constants, interleaved with the other
+   * arguments so as to retain the order in which variables are used within the function */
+  Map<Function, Array<Expr>> function_to_arguments_;
+  /* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
   Array<Function> functions_;
   /* \brief Keeps track of variables being created */
   int var_count_ = 0;
 };
 
-/*!  * \brief Kicks out all constants out of the partitioned function into main()  */
+/*!  * \brief Extracts all constants out of the partitioned function into main()  */
 IRModule ExtractConstants(const IRModule& mod) {
   String func_name;
   Function func;
@@ -169,7 +230,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
       [=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
   return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
-                                          {});
+                                          {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
new file mode 100644
index 0000000..24ba073
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file scalar_to_tensor_constant.cc
+ * \brief Converts scalar constant into tensor constant for binary ops of CMSIS-NN
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator finds all partitioned functions meant for CMSIS-NN binary ops and
+ * substitutes their scalar constants with tensor constants, giving each new constant the same
+ * shape as the tensor operand on the other side of the binary op. The expectation is that the
+ * ExtractConstants pass will later extract these tensor constants out of the global partitioned
+ * function, making both the global partitioned function and its composite function constant-free.
+ * This makes the TIR generation for binary ops via CMSIS-NN independent of constants.
+ */
+class ScalarToTensorConstantMutator : public MixedModeMutator {
+ public:
+  explicit ScalarToTensorConstantMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  using MixedModeMutator::VisitExpr_;
+
+  // Here is an example with the annotated scalar constant:
+  // def @tvmgen_default_cmsis_nn_main_1(%cmsis_nn_input: Tensor[], Inline=1, Compiler="cmsis-nn",
+  //                                     global_symbol="tvmgen_default_cmsis_nn_main",
+  //                                     Primitive=1) -> Tensor[] {
+  //   %56 = fn (%input0: _scalar_constant_, %input1: Tensor[],
+  //             PartitionedFromPattern="qnn.mul_", Composite="cmsis-nn.qnn_mul") -> Tensor[] {
+  //     qnn.mul(%input0, %input1, scale0, zero_point0,
+  //              scale1, zero_point_1, output_scale, output_zero_point)
+  //   };
+  //   %56(meta[relay.Constant] /* _scalar constant_ */, %cmsis_nn_input)
+  // }
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    call = post.as<CallNode>();
+
+    // Create a new variable argument that is of the same shape as the neighbouring argument
+    // in the binary op. This needs to be done only when one of the arguments is a scalar.
+    if (call->op.as<OpNode>()) {
+      final_call = ReplaceScalarWithTensorVariable(GetRef<Call>(call));
+    }
+
+    if (auto* glob_var_node = call->op.as<GlobalVarNode>()) {
+      GlobalVar global_var = GetRef<GlobalVar>(glob_var_node);
+      Function func = Downcast<Function>(mod_->Lookup(global_var));
+      auto compiler_name = func->GetAttr<String>(::tvm::relay::attr::kCompiler);
+      if (!compiler_name.defined() || compiler_name != "cmsis-nn") {
+        return final_call;
+      }
+      auto new_body = VisitExpr(func->body);
+      if (new_body.same_as(func->body)) {
+        return final_call;
+      }
+      Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
+                                     FreeTypeVars(new_body, mod_), func->attrs);
+      mod_->Update(global_var, new_func);
+      final_call = Call(global_var, call->args);
+    }
+
+    // Substitute scalar constant with a tensor constant in the call to composite function
+    // comprising partitioned binary ops. Shape of the new constant should be same as its
+    // neighbouring tensor's shape.
+    if (auto* func_node = call->op.as<FunctionNode>()) {
+      Function func = GetRef<Function>(func_node);
+      auto func_name = func->GetAttr<String>(attr::kComposite);
+      if (func_name.defined() &&
+          (func_name == "cmsis-nn.qnn_add" || func_name == "cmsis-nn.qnn_mul")) {
+        final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
+      }
+    }
+
+    return final_call;
+  }
+
+  // Replaces a scalar variable with a tensor variable of the same shape as the neighbouring
+  // tensor operand in the binary op
+  Call ReplaceScalarWithTensorVariable(Call call) {
+    const OpNode* opnode = call->op.as<OpNode>();
+    if (opnode == nullptr) {
+      return call;
+    }
+    String op_name = opnode->name;
+    Array<Expr> new_args;
+    for (uint32_t i = 0; i < call->args.size(); ++i) {
+      Expr arg = call->args[i];
+      new_args.push_back(arg);
+      if (!arg->checked_type_.defined()) {
+        continue;
+      }
+      auto* arg_type = arg->type_as<TensorTypeNode>();
+      if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
+        continue;
+      }
+      String arg_name = arg.as<VarNode>()->name_hint();
+      int tensor_arg_id = (i + 1) % 2;
+      Expr tensor_arg = call->args[tensor_arg_id];
+      if (!tensor_arg->checked_type_.defined()) {
+        continue;
+      }
+      TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+      new_args.Set(i, Var(arg_name, tensor_type));
+    }
+    return Call(call->op, new_args, call->attrs, {});
+  }
+
+  // Makes a tensor constant of the same shape as tensor_arg, filled with the value of scalar_arg
+  Call ReplaceScalarWithTensorConstant(Call call, Function func) {
+    Array<Expr> new_args;
+    for (uint32_t i = 0; i < call->args.size(); ++i) {
+      new_args.push_back(call->args[i]);
+      Expr scalar_arg = call->args[i];
+      if (!scalar_arg->checked_type_.defined()) {
+        continue;
+      }
+      Array<PrimExpr> scalar_shape = scalar_arg->type_as<TensorTypeNode>()->shape;
+      if (scalar_shape.size() != 0 || scalar_arg.as<ConstantNode>() == nullptr) {
+        continue;
+      }
+      int tensor_arg_id = (i + 1) % 2;
+      Expr tensor_arg = call->args[tensor_arg_id];
+      if (!tensor_arg->checked_type_.defined()) {
+        continue;
+      }
+      TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+      std::vector<int64_t> tensor_shape;
+      for (auto& dim : tensor_type->shape) {
+        tensor_shape.push_back(qnn::get_const_int(dim));
+      }
+      int8_t scalar_value = GetScalarFromConstant<int8_t>(scalar_arg);
+      int tensor_num_elements = qnn::get_const_int(tensor_type->Size());
+      std::vector<int8_t> tensor_values(tensor_num_elements, scalar_value);
+      Constant tensor_constant =
+          MakeConstantTensor<int8_t>(DataType::Int(8), tensor_shape, tensor_values);
+      new_args.Set(i, tensor_constant);
+    }
+    auto new_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
+                                   FreeTypeVars(new_body, mod_), func->attrs);
+    return Call(new_func, new_args);
+  }
+
+ private:
+  IRModule mod_;
+};
+
+IRModule ScalarToTensorConstant(const IRModule& mod) {
+  auto mutator = ScalarToTensorConstantMutator(mod);
+  Function main_func = Downcast<Function>(mod->Lookup("main"));
+  auto new_main_body = mutator.VisitExpr(main_func->body);
+  if (!new_main_body.same_as(main_func->body)) {
+    auto main_var = mod->GetGlobalVar("main");
+    auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
+                                  main_func->type_params, main_func->attrs);
+    mod->Update(main_var, new_main_func);
+  }
+  return mod;
+}
+
+transform::Pass ScalarToTensorConstantPass() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [=](IRModule m, transform::PassContext pc) { return ScalarToTensorConstant(m); };
+  return tvm::transform::CreateModulePass(pass_func, 0, "ScalarToTensorConstant", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ScalarToTensorConstants")
+    .set_body_typed(ScalarToTensorConstantPass);
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
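
Like the other CMSIS-NN transforms, the pass is registered over the FFI and can be driven
from Python, which is how the new tests below use it. A minimal usage sketch, assuming
`mod` already contains a "cmsis-nn" partitioned function:

    import tvm
    from tvm import relay

    # Pull ScalarToTensorConstants (and the other registered transforms) into scope.
    tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)

    mod = relay.transform.InferType()(mod)  # the pass declares InferType as a prerequisite
    mod = ScalarToTensorConstants()(mod)
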
diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
index 39b8c5f..f6417ac 100644
--- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py
+++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -17,9 +17,11 @@
 
 """CMSIS-NN integration tests: binary ops"""
 
+import itertools
 import sys
 
 import numpy as np
+from enum import Enum
 import pytest
 
 import tvm
@@ -35,11 +37,29 @@ from tests.python.relay.aot.aot_test_utils import (
 )
 
 
+def generate_tensor_constant():
+    rng = np.random.default_rng(12321)
+    dtype = "int8"
+    shape = (1, 16, 16, 3)
+    values = tvm.nd.array(
+        rng.integers(np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=shape, dtype=dtype)
+    )
+    return relay.const(values, dtype)
+
+
+def generate_scalar_constant():
+    dtype = "int8"
+    return relay.const(-30, dtype)
+
+
+def generate_variable(name, dtype="int8"):
+    return relay.var(name, shape=(1, 16, 16, 3), dtype=dtype)
+
+
 def make_model(
     op,
-    shape,
-    input_0_dtype,
-    input_1_dtype,
+    input_0,
+    input_1,
     input_0_scale,
     input_0_zero_point,
     input_1_scale,
@@ -48,10 +68,9 @@ def make_model(
     out_zero_point=-128,
 ):
     """Create a Relay Function / network model"""
-
     return op(
-        relay.var("input_0", shape=shape, dtype=input_0_dtype),
-        relay.var("input_1", shape=shape, dtype=input_1_dtype),
+        input_0,
+        input_1,
         relay.const(input_0_scale, "float32"),
         relay.const(input_0_zero_point, "int32"),
         relay.const(input_1_scale, "float32"),
@@ -82,9 +101,8 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
     shape = [1, 16, 16, 3]
     model = make_model(
         op,
-        shape,
-        dtype,
-        dtype,
+        generate_variable("input_0"),
+        generate_variable("input_1"),
         input_0_scale,
         input_0_zero_point,
         input_1_scale,
@@ -131,6 +149,128 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
     )
 
 
+# At least one of the inputs has to be a constant: both can't be variables and both can't be scalars
+def parameterize_for_constant_inputs(test):
+    op = [relay.qnn.op.mul, relay.qnn.op.add]
+    input_0 = [generate_variable("input_0"), generate_tensor_constant(), generate_scalar_constant()]
+    input_1 = [generate_variable("input_1"), generate_tensor_constant(), generate_scalar_constant()]
+    all_combinations = itertools.product(op, input_0, input_1)
+    all_combinations = filter(
+        lambda parameters: not (
+            (
+                isinstance(parameters[1], tvm.relay.expr.Var)
+                and isinstance(parameters[2], tvm.relay.expr.Var)
+            )
+            or (
+                isinstance(parameters[1], tvm.relay.expr.Constant)
+                and isinstance(parameters[2], tvm.relay.expr.Constant)
+                and parameters[1].data.numpy().ndim == 0
+                and parameters[2].data.numpy().ndim == 0
+            )
+        ),
+        all_combinations,
+    )
+    return pytest.mark.parametrize(
+        ["op", "input_0", "input_1"],
+        all_combinations,
+    )(test)
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@parameterize_for_constant_inputs
+def test_constant_input_int8(op, input_0, input_1):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    input_0_scale = 0.256
+    input_0_zero_point = 33
+    input_1_scale = 0.128
+    input_1_zero_point = -24
+    model = make_model(
+        op,
+        input_0,
+        input_1,
+        input_0_scale,
+        input_0_zero_point,
+        input_1_scale,
+        input_1_zero_point,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain function for cmsisnn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    inputs = {}
+    if isinstance(input_0, tvm.relay.expr.Var):
+        inputs.update({"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)})
+    if isinstance(input_1, tvm.relay.expr.Var):
+        inputs.update({"input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)})
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
+def test_both_scalar_inputs_int8(op):
+    input_scale = 0.256
+    input_zero_point = 33
+    dtype = "int8"
+    model = make_model(
+        op,
+        generate_scalar_constant(),
+        generate_scalar_constant(),
+        input_scale,
+        input_zero_point,
+        input_scale,
+        input_zero_point,
+    )
+
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert not any(attrs), "No function should have an external attribute."
+
+
 @skip_if_no_reference_system
 @tvm.testing.requires_cmsisnn
 @pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@@ -143,9 +283,8 @@ def test_invalid_parameters(
     input_zero_point = 33
     model = make_model(
         op,
-        [1, 16, 16, 3],
-        input_dtype,
-        input_dtype,
+        generate_variable("input_0", input_dtype),
+        generate_variable("input_1", input_dtype),
         input_scale,
         input_zero_point,
         input_scale,
diff --git a/tests/python/contrib/test_cmsisnn/test_extract_constants.py b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
index ca3fbe6..8e25177 100644
--- a/tests/python/contrib/test_cmsisnn/test_extract_constants.py
+++ b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
@@ -23,15 +23,6 @@ import pytest
 import tvm
 from tvm import relay
 
-from utils import (
-    make_module,
-    count_num_calls,
-    get_range_for_dtype_str,
-    get_same_padding,
-    get_conv2d_qnn_params,
-    make_qnn_relu,
-)
-
 tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
 
 
@@ -136,7 +127,6 @@ def test_multiple_functions():
     c10 = relay.Call(f20, [x10])
     c11 = relay.Call(f21, [c10])
     ef = relay.Function([x10], c11, relay.TensorType((8, 8), "float32"))
-
     x0 = relay.var("x0", shape=(8, 8))
     ev = relay.GlobalVar("cmsis-nn")
     ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
@@ -182,56 +172,42 @@ def test_main_function():
     ), "main() should have same number of arguments as before"
 
 
-def parameterize_for_invalid_model(test):
-    local_func_1 = ["cmsis-nn.qnn_op_1", "local_function_1"]
-    local_func_2 = ["cmsis-nn.qnn_op_2", "local_function_2"]
-    compiler_name = ["cmsis-nn", "external_compiler"]
-    all_combinations = itertools.product(local_func_1, local_func_2, compiler_name)
-    all_combinations = filter(
-        lambda parameters: not (
-            parameters[2] == "cmsis-nn"
-            and parameters[0] == "cmsis-nn.qnn_op_1"
-            and parameters[1] == "cmsis-nn.qnn_op_2"
-        ),
-        all_combinations,
-    )
-    return pytest.mark.parametrize(
-        ["func_name_1", "func_name_2", "external_compiler"],
-        all_combinations,
-    )(test)
-
-
 @tvm.testing.requires_cmsisnn
-@parameterize_for_invalid_model
-def test_multiple_functions_non_cmsisnn_compiler(func_name_1, func_name_2, external_compiler):
+@pytest.mark.parametrize("external_compiler", ["cmsis-nn", "other_compiler"])
+def test_multiple_functions_non_cmsisnn_compiler(external_compiler):
     y20_data = np.random.uniform(0, 1, (8, 8)).astype("float32")
     x20 = relay.var("x20", shape=(8, 8))
     y20_const = relay.const(y20_data, "float32")
     z20 = x20 + y20_const
     f20 = relay.Function([x20], z20, relay.TensorType((8, 8), "float32"))
-    f20 = set_composite_func_attr(f20, func_name_1)
+    f20 = set_composite_func_attr(f20, "cmsis-nn.qnn_op_1")
+    x10 = relay.var("x10", shape=(8, 8))
+    c10 = relay.Call(f20, [x10])
+    ef0 = relay.Function([x10], c10, relay.TensorType((8, 8), "float32"))
 
     y21_data = np.random.uniform(0, 1, (8, 8)).astype("float32")
     x21 = relay.var("x21", shape=(8, 8))
     y21_const = relay.const(y21_data, "float32")
     z21 = x21 + y21_const
     f21 = relay.Function([x21], z21, relay.TensorType((8, 8), "float32"))
-    f21 = set_composite_func_attr(f21, func_name_2)
-
-    x10 = relay.var("x10", shape=(8, 8))
-    c10 = relay.Call(f20, [x10])
-    c11 = relay.Call(f21, [c10])
-    ef = relay.Function([x10], c11, relay.TensorType((8, 8), "float32"))
+    f21 = set_composite_func_attr(f21, "cmsis-nn.qnn_op_2")
+    x11 = relay.var("x11", shape=(8, 8))
+    c11 = relay.Call(f21, [x11])
+    ef1 = relay.Function([x11], c11, relay.TensorType((8, 8), "float32"))
 
     x0 = relay.var("x0", shape=(8, 8))
-    ev = relay.GlobalVar("external_function")
-    ef = set_external_func_attr(ef, external_compiler, ev.name_hint)
-    c = relay.Call(ev, [x0])
-    mf = relay.Function([x0], c, relay.TensorType((8, 8), "float32"))
+    ev0 = relay.GlobalVar("external_function_0")
+    ef0 = set_external_func_attr(ef0, external_compiler, ev0.name_hint)
+    c0 = relay.Call(ev0, [x0])
+    ev1 = relay.GlobalVar("external_function_1")
+    ef1 = set_external_func_attr(ef1, external_compiler, ev1.name_hint)
+    c1 = relay.Call(ev1, [c0])
+    mf = relay.Function([x0], c1, relay.TensorType((8, 8), "float32"))
     mv = relay.GlobalVar("main")
 
     mod = tvm.IRModule()
-    mod[ev] = ef
+    mod[ev0] = ef0
+    mod[ev1] = ef1
     mod[mv] = mf
 
     mod = ExtractConstantsFromPartitionedFunction()(mod)
@@ -240,10 +216,7 @@ def test_multiple_functions_non_cmsisnn_compiler(func_name_1, func_name_2, exter
 
     num_extracted_constants = 0
     if external_compiler == "cmsis-nn":
-        if "cmsis-nn" in func_name_1:
-            num_extracted_constants += 1
-        if "cmsis-nn" in func_name_2:
-            num_extracted_constants += 1
+        num_extracted_constants = 2
 
     assert (
         check_for_constants.num_constants_ == num_extracted_constants
diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
new file mode 100644
index 0000000..7039617
--- /dev/null
+++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
@@ -0,0 +1,187 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: scalar_to_tensor_constant pass"""
+import sys
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+
+tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
+
+
+class CheckFunctionsForConstants(tvm.relay.ExprVisitor):
+    def __init__(self):
+        super().__init__()
+        self.num_constants_ = 0
+
+    def visit_call(self, call):
+        super().visit_call(call)
+        for arg in call.args:
+            if isinstance(arg, relay.Constant) and arg.data.numpy().ndim > 0:
+                self.num_constants_ += 1
+
+    def check_num_constants(self, func):
+        assert self.num_constants_ == 0, "Functions should not have constant arguments in Calls"
+
+
+def set_external_func_attr(func, compiler, ext_symbol):
+    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    func = func.with_attr("Compiler", compiler)
+    func = func.with_attr("global_symbol", ext_symbol)
+    return func
+
+
+def set_composite_func_attr(func, name):
+    func = func.with_attr("Composite", name)
+    return func
+
+
+@tvm.testing.requires_cmsisnn
+def test_single_scalar_position_0():
+    x0 = relay.var("x0", shape=None)
+    x1 = relay.var("x1", shape=(8, 8))
+    z1 = x0 + x1
+    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.expr.const(3, "float32")
+    y1 = relay.var("y1", shape=(8, 8))
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([y1], c0, relay.TensorType((8, 8), "float32"))
+
+    x = relay.var("x", shape=(8, 8))
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [x])
+    mf = relay.Function([x], c, relay.TensorType((8, 8), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 1
+    ), "Scalar constant wasn't converted into tensor constant"
+
+
+@tvm.testing.requires_cmsisnn
+def test_single_scalar_position_1():
+    x0 = relay.var("x0", shape=(8, 8))
+    x1 = relay.var("x1", shape=None)
+    z1 = x0 + x1
+    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.var("y0", shape=(8, 8))
+    y1 = relay.expr.const(3, "float32")
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([y0], c0, relay.TensorType((8, 8), "float32"))
+
+    x = relay.var("x", shape=(8, 8))
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [x])
+    mf = relay.Function([x], c, relay.TensorType((8, 8), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 1
+    ), "Scalar constant wasn't converted into tensor constant"
+
+
+@tvm.testing.requires_cmsisnn
+def test_two_scalars():
+    x1 = relay.var("x1", shape=None)
+    x2 = relay.var("x2", shape=None)
+    z1 = x1 + x2
+    lf = relay.Function([x1, x2], z1, relay.TensorType((), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.expr.const(5, "float32")
+    y1 = relay.expr.const(3, "float32")
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([], c0, relay.TensorType((), "float32"))
+
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [])
+    mf = relay.Function([], c, relay.TensorType((), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 0
+    ), "Scalar constants should not be converted when both operands are scalar"
+
+
+@tvm.testing.requires_cmsisnn
+def test_two_tensor_constants():
+    x0 = relay.var("x0", shape=(8, 8))
+    x1 = relay.var("x1", shape=(8, 8))
+    z1 = x0 + x1
+    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.const(np.random.uniform(0, 1, (8, 8)).astype("float32"), "float32")
+    y1 = relay.const(np.random.uniform(0, 1, (8, 8)).astype("float32"), "float32")
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([], c0, relay.TensorType((8, 8), "float32"))
+
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [])
+    mf = relay.Function([], c, relay.TensorType((8, 8), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 2
+    ), "Tensor constants should be left untouched by the pass"
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))