Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/22 11:09:59 UTC

[GitHub] [tvm] manupa-arm commented on a change in pull request #9331: [4/10] Code generation for Conv2D via CMSIS-NN

manupa-arm commented on a change in pull request #9331:
URL: https://github.com/apache/tvm/pull/9331#discussion_r734368762



##########
File path: python/tvm/relay/op/contrib/cmsisnn.py
##########
@@ -47,42 +47,93 @@ def partition_for_cmsisnn(mod, params=None, **opts):
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
 
+    tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
+
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
             transform.MergeComposite(pattern_table()),
             transform.AnnotateTarget("cmsisnn"),
-            transform.MergeCompilerRegions(),
             transform.PartitionGraph(),
+            GenerateCMSISNNConstants(),
+            ExtractConstantsFromPartitionedFunction(),
+            transform.InferType(),
         ]
     )
-
     return seq(mod)
 
 
 @register_pattern_table("cmsisnn")
 def pattern_table():
     """Get the cmsisnn compiler pattern table."""
 
-    def softmax_pattern():
+    def qnn_softmax_pattern():
+        """Create pattern for quantized softmax"""
         pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
         pattern = is_op("nn.softmax")(pattern)
         pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
         return pattern
 
-    def check_quantized_softmax(extract):
+    def check_qnn_softmax(pattern):
         """Check if softmax is supported by CMSIS-NN."""
-        dequantize_call = extract.args[0].args[0]
-        scale = extract.args[1].data.numpy().item(0)
-        zero_point = extract.args[2].data.numpy().item(0)
+        dequantize_call = pattern.args[0].args[0]
+        scale = pattern.args[1].data.numpy().item(0)
+        zero_point = pattern.args[2].data.numpy().item(0)
 
         # check for dtypes of quantize and dequantize
         return (
             (scale == 1.0 / 256 and zero_point == -128)
-            and extract.attrs.out_dtype == "int8"
+            and pattern.attrs.out_dtype == "int8"
             and dequantize_call.args[0].checked_type.dtype == "int8"
         )
 
+    def qnn_conv2d_pattern():
+        """Create pattern for qnn.conv2D with optional fused relu."""
+        qnn_conv2d = is_op("qnn.conv2d")(
+            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        ).has_attr({"kernel_layout": "HWIO"})
+        bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
+        req = is_op("qnn.requantize")(
+            qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        clip_or_req = req.optional(is_op("clip"))
+        return clip_or_req
+
+    def check_qnn_conv2d(pattern):
+        """Check if the Conv2D is supported by CMSIS-NN."""
+        if str(pattern.op.name) == "clip":
+            relu = pattern
+            requantize = relu.args[0]
+        else:
+            requantize = pattern
+        requantize_input = requantize.args[0]
+        bias_add = None
+        bias_dtype = "int32"
+        if str(requantize_input.op.name) == "nn.bias_add":
+            bias_add = requantize_input
+            conv2d = bias_add.args[0]
+            bias_dtype = bias_add.args[1].checked_type.dtype
+        else:
+            conv2d = requantize_input
+        conv2d_input = conv2d.args[0]
+        conv2d_weight = conv2d.args[1]
+
+        # kernel zero_point should be 0
+        kernel_zp = conv2d.args[3].data.numpy()
+        kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp
+
+        return (
+            conv2d.attrs.kernel_layout == "HWIO"

Review comment:
       Why can't we offload any other kernel layout? We are converting to OHWI anyway, right?

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {

Review comment:
       I think this condition is not useful.

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {

Review comment:
       nit: I think you can just do Downcast<Call> to specialize the Expr to a Call Expr and access the args.
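
   For illustration, a minimal sketch of that suggestion (hypothetical, not from the PR; it assumes the surrounding ExtractConstantsMutator class):

       Expr Rewrite_(const CallNode* call, const Expr& post) final {
         // Downcast once to the Call ObjectRef and work with it directly;
         // no nullable CallNode* is needed to reach the args.
         Call post_call = Downcast<Call>(post);
         Array<Expr> new_args;
         for (const auto& arg : post_call->args) {
           new_args.push_back(arg);  // constant-to-Var replacement goes here
         }
         return Call(post_call->op, new_args, post_call->attrs, {});
       }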

##########
File path: src/relay/backend/contrib/cmsisnn/generate_constants.cc
##########
@@ -0,0 +1,230 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */

Review comment:
       Can we add documentation as to what this Pass would do?
   
   Also, we would need unit testing for this pass.
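
   For example, a file-level doc comment along these lines (wording is only a suggestion, based on what the quoted code does) would cover the documentation part:

       /*!
        * \file src/relay/backend/contrib/cmsisnn/generate_constants.cc
        * \brief Pass that generates constants required by CMSIS-NN:
        *        it transposes Conv2D kernels from HWIO to OHWI and
        *        replaces the float requantization scales with
        *        per-channel fixed-point multiplier/shift constants.
        */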

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {
+          Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
+          new_args.push_back(var_arg);
+          constants_within_function_.push_back(GetRef<Constant>(const_arg));
+        } else {
+          new_args.push_back(arg);
+        }
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});
+    }
+
+    // Since the constants are kicked out of partitioned functions
+    // a new call to global function is needed
+    if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
+      auto glob_var = GetRef<GlobalVar>(glob_var_node);
+      auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
+      auto new_glob_func = VisitExpr(glob_func);
+      if (!new_glob_func.same_as(glob_func)) {
+        mod_->Update(glob_var, Downcast<Function>(new_glob_func));
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
+        for (auto constant : function_to_constants_.at(glob_func)) {
+          new_args.push_back(constant);
+        }
+        final_call = Call(glob_var, new_args);
+      }
+    }
+
+    // Since the constants are kicked out of the local partitioned functions
+    // a new call to local function is needed
+    if (auto* func_node = call->op.as<FunctionNode>()) {
+      Function func = GetRef<Function>(func_node);
+      auto new_func = VisitExpr(func);
+      if (!new_func.same_as(func)) {
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
+        for (auto constant : function_to_constants_.at(func)) {
+          constants_within_function_.push_back(constant);
+          Var var_arg = Var(gen_var_name(), constant->tensor_type());
+          new_args.push_back(var_arg);
+        }
+        final_call = Call(new_func, new_args);
+      }
+    }
+
+    return final_call;
+  }
+
+ private:
+  /* \brief Updated module where all calls have replaced constants with new variables */
+  IRModule mod_;
+  /* \brief Maintains mapping of original function to the replaced constants */
+  Map<Function, Array<Constant>> function_to_constants_;
+  /* \brief Constants being kicked out of a function during the function visit */
+  Array<Constant> constants_within_function_;
+  /* \brief Keeps track of variables being created */
+  int var_count_ = 0;
+  /* \brief Keeps track of function scope */
+  int func_nesting_level_ = 0;
+};
+
+/*!  * \brief Kicks out all constants out of the partitioned function into main()  */
+IRModule ExtractConstants(IRModule mod) {

Review comment:
       Could we use const IRModule& here?
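
   Something like this sketch is what I have in mind (names are illustrative; IRModule is a reference-counted handle, so a const reference is cheap):

       IRModule ExtractConstants(const IRModule& mod) {
         // The handle itself is copied; the underlying module is shared,
         // so this only changes the signature, not the mutation behaviour.
         IRModule out_mod = mod;
         ExtractConstantsMutator mutator(out_mod);
         // ... rewrite main() and the partitioned functions as before ...
         return out_mod;
       }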

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();

Review comment:
       nit: We can avoid this with the above suggestion. Let me know what you think.

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}

Review comment:
       Would it be possible not to modify the input IRModule?

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {
+          Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
+          new_args.push_back(var_arg);
+          constants_within_function_.push_back(GetRef<Constant>(const_arg));
+        } else {
+          new_args.push_back(arg);
+        }
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});
+    }
+
+    // Since the constants are kicked out of partitioned functions
+    // a new call to global function is needed
+    if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
+      auto glob_var = GetRef<GlobalVar>(glob_var_node);
+      auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
+      auto new_glob_func = VisitExpr(glob_func);
+      if (!new_glob_func.same_as(glob_func)) {
+        mod_->Update(glob_var, Downcast<Function>(new_glob_func));
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
+        for (auto constant : function_to_constants_.at(glob_func)) {
+          new_args.push_back(constant);
+        }
+        final_call = Call(glob_var, new_args);
+      }
+    }
+
+    // Since the constants are kicked out of the local partitioned functions
+    // a new call to local function is needed
+    if (auto* func_node = call->op.as<FunctionNode>()) {

Review comment:
       Can we refactor this section and the one above? I see a bit of code duplication, and I think the difference is the origin of the Function (local vs. global).
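
   As a sketch of the refactor (hypothetical helper; only the local case needs the constants re-hoisted as fresh Vars):

       Array<Expr> ExtraArgsFor(const Function& func, bool hoist_as_vars) {
         Array<Expr> extra_args;
         ICHECK(function_to_constants_.count(func));
         for (const Constant& constant : function_to_constants_.at(func)) {
           if (hoist_as_vars) {
             // Local function: bubble the constant up to the caller.
             constants_within_function_.push_back(constant);
             extra_args.push_back(Var(gen_var_name(), constant->tensor_type()));
           } else {
             // Global function: pass the constant in at the call site.
             extra_args.push_back(constant);
           }
         }
         return extra_args;
       }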

##########
File path: src/relay/backend/contrib/cmsisnn/generate_constants.cc
##########
@@ -0,0 +1,230 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+Expr MakeTranspose(Expr data, Array<Integer> axes);
+namespace contrib {
+namespace cmsisnn {
+
+class GenerateConstantsMutator : public MixedModeMutator {
+ public:
+  explicit GenerateConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*!  * \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */
+  Expr ConvertKernelLayout(Expr kernel_expr, const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) {
+    auto attrs = make_object<Conv2DAttrs>();
+    attrs->strides = std::move(conv2d_attrs->strides);
+    attrs->padding = std::move(conv2d_attrs->padding);
+    attrs->dilation = std::move(conv2d_attrs->dilation);
+    attrs->groups = conv2d_attrs->groups;
+    attrs->channels = std::move(conv2d_attrs->channels);
+    attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
+    attrs->data_layout = std::move(conv2d_attrs->data_layout);
+    attrs->kernel_layout = runtime::String("OHWI");
+    attrs->out_layout = std::move(conv2d_attrs->out_layout);
+    attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
+    *new_attrs = tvm::Attrs{attrs};
+
+    IRModule kernel_module;
+    auto func_body = MakeTranspose(kernel_expr, {Integer(3), Integer(0), Integer(1), Integer(2)});
+    auto kernel_func =
+        Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module));
+    GlobalVar kernel_var("main");
+    kernel_module->Add(kernel_var, kernel_func);
+    kernel_module = relay::transform::FoldConstant()(kernel_module);
+    kernel_func = Downcast<Function>(kernel_module->Lookup("main"));
+    return kernel_func->body;
+  }
+
+  /*!  * \brief Performs weight transpose and substitutes existing constants in the composite
+   *            function for Conv2D with CMSIS-NN Requantize constants */
+  Expr GenerateConv2dRequantConstants(const Expr& expr) {
+    const CallNode* clip_call = nullptr;
+    const CallNode* requantize_call = nullptr;
+    const CallNode* bias_add_call = nullptr;
+    const CallNode* conv2d_call = nullptr;
+    auto* final_call = expr.as<CallNode>();
+    auto* final_op = final_call->op.as<OpNode>();
+    if (final_op->name == "clip") {
+      clip_call = final_call;
+      requantize_call = clip_call->args[0].as<CallNode>();
+    } else {
+      requantize_call = final_call;
+    }
+    auto* requantize_input = requantize_call->args[0].as<CallNode>();
+    auto* requantize_input_op = requantize_input->op.as<OpNode>();
+    if (requantize_input_op->name == "nn.bias_add") {
+      bias_add_call = requantize_input;
+      conv2d_call = bias_add_call->args[0].as<CallNode>();
+    } else {
+      conv2d_call = requantize_input;
+    }
+
+    // Transpose weights: HWIO -> OHWI
+    auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
+    tvm::Attrs new_conv2d_attrs;
+    Expr transposed_kernel =
+        ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs);
+
+    // Obtain input and output scales from Relay's Requantization
+    int64_t out_channels = conv2d_attrs->channels.as<IntImmNode>()->value;
+    float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
+    auto input_scales = tvm::relay::qnn::GetFloatVectorFromConstant(requantize_call->args[1]);
+    ICHECK(input_scales.size() == static_cast<size_t>(out_channels));
+
+    // Calculate requantization multiplier and shift

Review comment:
       Would you be able to give some comments as to what is being calculated here?
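
   For reference, a comment along these lines would capture it (this paraphrases what qnn::GetFixedPointMultiplierShift computes, so please double-check the exact convention):

       // Each effective scale S = input_scale[i] / output_scale is
       // decomposed into a normalized fixed-point multiplier M and a
       // shift s such that S ~= M * 2^(s - 31), with M an int32 in
       // [2^30, 2^31). CMSIS-NN can then requantize the int32
       // accumulator with integer-only arithmetic, roughly
       // out = round((acc * M) >> (31 - s)), instead of multiplying
       // by a float scale at inference time.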

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>

Review comment:
       Can we add documentation on what this pass would do?
   
   Also, we would need unit testing for this pass.

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {

Review comment:
       I do not follow this logic. Why is this needed?

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {
+          Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
+          new_args.push_back(var_arg);
+          constants_within_function_.push_back(GetRef<Constant>(const_arg));
+        } else {
+          new_args.push_back(arg);
+        }
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});
+    }
+
+    // Since the constants are kicked out of partitioned functions
+    // a new call to global function is needed
+    if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
+      auto glob_var = GetRef<GlobalVar>(glob_var_node);
+      auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
+      auto new_glob_func = VisitExpr(glob_func);
+      if (!new_glob_func.same_as(glob_func)) {
+        mod_->Update(glob_var, Downcast<Function>(new_glob_func));
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());

Review comment:
       This would be a bit compute intensive, as this is an Array and I think the lookup takes O(N).
   If we require this, I would suggest using an unordered_set with constant access time.
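
   A sketch of that change, using a std::unordered_map keyed by the Function ObjectRef (ObjectPtrHash/ObjectPtrEqual come from tvm/runtime/object.h; add #include <unordered_map>):

       std::unordered_map<Function, Array<Constant>, ObjectPtrHash, ObjectPtrEqual>
           function_to_constants_;

       // The lookup then has O(1) expected cost:
       auto it = function_to_constants_.find(glob_func);
       ICHECK(it != function_to_constants_.end());
       for (const auto& constant : it->second) {
         new_args.push_back(constant);
       }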

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {
+          Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
+          new_args.push_back(var_arg);
+          constants_within_function_.push_back(GetRef<Constant>(const_arg));

Review comment:
       nit: I would Downcast the ObjectRef (expr) to Constant here throughout, because ultimately you would need the ObjectRef again.

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {

Review comment:
       nit: it is not always required to explicitly access the ConstantNode; you should be able to use Constant directly.
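
   i.e. something like this inside the args loop (sketch):

       if (arg->IsInstance<ConstantNode>()) {
         Constant constant = Downcast<Constant>(arg);
         if (!constant->is_scalar()) {
           // We store the ObjectRef anyway, so no GetRef round trip.
           constants_within_function_.push_back(constant);
           new_args.push_back(Var(gen_var_name(), constant->tensor_type()));
           continue;
         }
       }
       new_args.push_back(arg);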

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;

Review comment:
       A better design suggestion (feel free to disagree :) ): we could use a stack of functions that gets pushed and popped as the visitor goes in and out of scope.
   Thus, we don't need the stateful clearing of the constants_within_function_ array and the assignment performed here.
   
   Instead, as and when the visitor encounters a constant, it could be inserted into a Map<Function, Array<Constant>> keyed by the function at the top of the stack.
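
   A sketch of the stack-based bookkeeping (hypothetical names; needs #include <stack>):

       std::stack<Function> function_stack_;
       Map<Function, Array<Constant>> function_to_constants_;

       Expr VisitExpr_(const FunctionNode* func) final {
         function_stack_.push(GetRef<Function>(func));
         Expr new_body = VisitExpr(func->body);
         function_stack_.pop();
         if (new_body.same_as(func->body)) return GetRef<Function>(func);
         return Function(FreeVars(new_body), new_body, func->ret_type,
                         FreeTypeVars(new_body, mod_), func->attrs);
       }

       void RecordConstant(const Constant& constant) {
         // File the constant under the function currently in scope.
         Function current = function_stack_.top();
         Array<Constant> consts =
             function_to_constants_.Get(current).value_or(Array<Constant>());
         consts.push_back(constant);
         function_to_constants_.Set(current, consts);
       }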

##########
File path: src/relay/backend/contrib/cmsisnn/generate_constants.cc
##########
@@ -0,0 +1,230 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+Expr MakeTranspose(Expr data, Array<Integer> axes);
+namespace contrib {
+namespace cmsisnn {
+
+class GenerateConstantsMutator : public MixedModeMutator {
+ public:
+  explicit GenerateConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*!  * \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */
+  Expr ConvertKernelLayout(Expr kernel_expr, const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) {
+    auto attrs = make_object<Conv2DAttrs>();
+    attrs->strides = std::move(conv2d_attrs->strides);
+    attrs->padding = std::move(conv2d_attrs->padding);
+    attrs->dilation = std::move(conv2d_attrs->dilation);
+    attrs->groups = conv2d_attrs->groups;
+    attrs->channels = std::move(conv2d_attrs->channels);
+    attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
+    attrs->data_layout = std::move(conv2d_attrs->data_layout);
+    attrs->kernel_layout = runtime::String("OHWI");
+    attrs->out_layout = std::move(conv2d_attrs->out_layout);
+    attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
+    *new_attrs = tvm::Attrs{attrs};
+
+    IRModule kernel_module;
+    auto func_body = MakeTranspose(kernel_expr, {Integer(3), Integer(0), Integer(1), Integer(2)});
+    auto kernel_func =
+        Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module));
+    GlobalVar kernel_var("main");
+    kernel_module->Add(kernel_var, kernel_func);
+    kernel_module = relay::transform::FoldConstant()(kernel_module);
+    kernel_func = Downcast<Function>(kernel_module->Lookup("main"));
+    return kernel_func->body;
+  }
+
+  /*!  * \brief Performs weight transpose and substitutes existing constants in the composite
+   *            function for Conv2D with CMSIS-NN Requantize constants */
+  Expr GenerateConv2dRequantConstants(const Expr& expr) {
+    const CallNode* clip_call = nullptr;
+    const CallNode* requantize_call = nullptr;
+    const CallNode* bias_add_call = nullptr;
+    const CallNode* conv2d_call = nullptr;
+    auto* final_call = expr.as<CallNode>();
+    auto* final_op = final_call->op.as<OpNode>();
+    if (final_op->name == "clip") {
+      clip_call = final_call;
+      requantize_call = clip_call->args[0].as<CallNode>();

Review comment:
       stylistic nit: FYI, fields can be accessed from ObjectRefs as well.
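
   e.g. a sketch in the ObjectRef style (the Downcasts are safe here because the pattern matcher already guaranteed this structure):

       Call final_call = Downcast<Call>(expr);
       Call requantize_call = final_call;
       if (Downcast<Op>(final_call->op)->name == "clip") {
         Call clip_call = final_call;
         requantize_call = Downcast<Call>(clip_call->args[0]);
       }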

##########
File path: src/relay/backend/contrib/cmsisnn/generate_constants.cc
##########
@@ -0,0 +1,230 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+Expr MakeTranspose(Expr data, Array<Integer> axes);
+namespace contrib {
+namespace cmsisnn {
+
+class GenerateConstantsMutator : public MixedModeMutator {
+ public:
+  explicit GenerateConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*!  * \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */
+  Expr ConvertKernelLayout(Expr kernel_expr, const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) {
+    auto attrs = make_object<Conv2DAttrs>();
+    attrs->strides = std::move(conv2d_attrs->strides);
+    attrs->padding = std::move(conv2d_attrs->padding);
+    attrs->dilation = std::move(conv2d_attrs->dilation);
+    attrs->groups = conv2d_attrs->groups;
+    attrs->channels = std::move(conv2d_attrs->channels);
+    attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
+    attrs->data_layout = std::move(conv2d_attrs->data_layout);
+    attrs->kernel_layout = runtime::String("OHWI");
+    attrs->out_layout = std::move(conv2d_attrs->out_layout);
+    attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
+    *new_attrs = tvm::Attrs{attrs};
+
+    IRModule kernel_module;
+    auto func_body = MakeTranspose(kernel_expr, {Integer(3), Integer(0), Integer(1), Integer(2)});
+    auto kernel_func =
+        Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module));
+    GlobalVar kernel_var("main");
+    kernel_module->Add(kernel_var, kernel_func);
+    kernel_module = relay::transform::FoldConstant()(kernel_module);
+    kernel_func = Downcast<Function>(kernel_module->Lookup("main"));
+    return kernel_func->body;
+  }
+
+  /*!  * \brief Performs weight transpose and substitutes existing constants in the composite
+   *            function for Conv2D with CMSIS-NN Requantize constants */
+  Expr GenerateConv2dRequantConstants(const Expr& expr) {
+    const CallNode* clip_call = nullptr;
+    const CallNode* requantize_call = nullptr;
+    const CallNode* bias_add_call = nullptr;
+    const CallNode* conv2d_call = nullptr;
+    auto* final_call = expr.as<CallNode>();
+    auto* final_op = final_call->op.as<OpNode>();
+    if (final_op->name == "clip") {
+      clip_call = final_call;
+      requantize_call = clip_call->args[0].as<CallNode>();
+    } else {
+      requantize_call = final_call;
+    }
+    auto* requantize_input = requantize_call->args[0].as<CallNode>();
+    auto* requantize_input_op = requantize_input->op.as<OpNode>();
+    if (requantize_input_op->name == "nn.bias_add") {
+      bias_add_call = requantize_input;
+      conv2d_call = bias_add_call->args[0].as<CallNode>();
+    } else {
+      conv2d_call = requantize_input;
+    }
+
+    // Transpose weights: HWIO -> OHWI
+    auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
+    tvm::Attrs new_conv2d_attrs;
+    Expr transposed_kernel =
+        ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs);
+
+    // Obtain input and output scales from Relay's Requantization
+    int64_t out_channels = conv2d_attrs->channels.as<IntImmNode>()->value;
+    float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
+    auto input_scales = tvm::relay::qnn::GetFloatVectorFromConstant(requantize_call->args[1]);
+    ICHECK(input_scales.size() == static_cast<size_t>(out_channels));
+
+    // Calculate requantization multiplier and shift
+    Device dev{DLDeviceType::kDLCPU, 0};
+    runtime::NDArray multiplier_nda =
+        runtime::NDArray::Empty({out_channels}, DataType::Int(32), dev);
+    runtime::NDArray shift_nda = runtime::NDArray::Empty({out_channels}, DataType::Int(32), dev);
+    int32_t* multiplier = static_cast<int32_t*>(multiplier_nda->data);
+    int32_t* shift = static_cast<int32_t*>(shift_nda->data);
+    for (int i = 0; i < out_channels; ++i) {
+      double effective_output_scale =
+          static_cast<double>(input_scales[i]) / static_cast<double>(output_scale);
+      std::tie(*(multiplier + i), *(shift + i)) =
+          tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale);
+    }
+
+    // Create constants from requantization multiplier and shift
+    Constant multiplier_const(multiplier_nda);
+    Constant shift_const(shift_nda);
+
+    // Convert scale scalars into Constants
+    // Scales are expected as Constants by following passes
+    Expr weight_scale = conv2d_call->args[5];
+    Expr req_inp_scale = requantize_call->args[1];
+    if (out_channels == 1) {
+      runtime::NDArray weight_scale_nda =
+          runtime::NDArray::Empty({out_channels}, DataType::Float(32), dev);
+      float* weight_scale_p = static_cast<float*>(weight_scale_nda->data);
+      *weight_scale_p = GetScalarFromConstant<float>(weight_scale);
+      weight_scale = Constant(weight_scale_nda);
+
+      runtime::NDArray req_inp_scale_nda =
+          runtime::NDArray::Empty({out_channels}, DataType::Float(32), dev);
+      float* req_inp_scale_p = static_cast<float*>(req_inp_scale_nda->data);
+      *req_inp_scale_p = GetScalarFromConstant<float>(req_inp_scale);
+      req_inp_scale = Constant(req_inp_scale_nda);
+    }
+
+    // Replace existing weights (HWIO) with the transposed ones (OHWI)
+    // Substitute Conv2D weight_zero_point with the CMSIS-NN multiplier
+    // Substitute Requantize input_zero_point with CMSIS-NN shift
+    // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc
+    Array<Expr> conv2d_args = {conv2d_call->args[0], transposed_kernel,    conv2d_call->args[2],
+                               multiplier_const,     conv2d_call->args[4], weight_scale};
+    Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {});
+    if (bias_add_call) {
+      ret_call =
+          Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs, {});
+    }
+    Array<Expr> requantize_args = {ret_call, req_inp_scale, shift_const, requantize_call->args[3],
+                                   requantize_call->args[4]};
+    ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {});
+    if (clip_call) {
+      ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {});
+    }
+    return ret_call;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    auto* global_var = call->op.as<GlobalVarNode>();
+    if (global_var) {
+      // Update to global function call needed because the body changes while
+      // generating new constants
+      Function func = Downcast<Function>(mod_->Lookup(global_var->name_hint));
+      Expr new_body = VisitExpr(func->body);
+      if (!new_body.same_as(func->body)) {
+        Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                                     FreeTypeVars(new_body, mod_), func->attrs);
+        mod_->Update(GetRef<GlobalVar>(global_var), new_func);
+        final_call = Call(GetRef<GlobalVar>(global_var), post_call->args);
+      }
+    }
+
+    // Recreate composite function and corresponding call
+    // Updated composite function contains CMSIS-NN quantized multiplier and shift constants
+    if (call->op.as<FunctionNode>()) {
+      auto* func = call->op.as<FunctionNode>();
+      auto func_name = func->GetAttr<String>(attr::kComposite);
+      if (func_name.defined() && func_name == "cmsisnn.qnn_conv2d") {
+        Expr new_body = GenerateConv2dRequantConstants(func->body);
+        Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                                     FreeTypeVars(new_body, mod_), func->attrs);
+        final_call = Call(new_func, post_call->args);
+      }
+    }
+
+    return final_call;
+  }
+
+ private:
+  IRModule mod_;
+};
+
+IRModule GenerateConstants(IRModule mod) {

Review comment:
       It would be better if we didn't have to modify the input IRModule, to follow the design pattern -- unless there is a reason.
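
   A sketch of how the pass could return a fresh module instead (hypothetical; tvm Maps are copy-on-write, so updating new_mod should not touch the caller's module, but please verify):

       IRModule GenerateConstants(const IRModule& mod) {
         // Re-wrap the function map in a new module object.
         IRModule new_mod = IRModule(mod->functions, mod->type_definitions);
         GenerateConstantsMutator mutator(new_mod);
         // ... run the mutator over new_mod's main() as before ...
         return new_mod;
       }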

##########
File path: src/relay/backend/contrib/cmsisnn/generate_constants.cc
##########
@@ -0,0 +1,230 @@
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    auto* global_var = call->op.as<GlobalVarNode>();
+    if (global_var) {
+      // Update to global function call needed because the body changes while
+      // generating new constants
+      Function func = Downcast<Function>(mod_->Lookup(global_var->name_hint));
+      Expr new_body = VisitExpr(func->body);
+      if (!new_body.same_as(func->body)) {
+        Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                                     FreeTypeVars(new_body, mod_), func->attrs);
+        mod_->Update(GetRef<GlobalVar>(global_var), new_func);
+        final_call = Call(GetRef<GlobalVar>(global_var), post_call->args);
+      }
+    }
+
+    // Recreate composite function and corresponding call
+    // Updated composite function contains CMSIS-NN quantized multiplier and shift constants
+    if (call->op.as<FunctionNode>()) {

Review comment:
       This could be refactored together with the GlobalVar section above.
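       For instance, both this branch and the GlobalVar branch above rebuild a Function around a rewritten body; a small private helper along these lines (name hypothetical) could be shared between the two:

           // Rebuild a function around a new body, recomputing its free
           // variables and free type variables from that body.
           Function RebuildWithNewBody(const FunctionNode* func, const Expr& new_body) {
             return Function(FreeVars(new_body), new_body, func->ret_type,
                             FreeTypeVars(new_body, mod_), func->attrs);
           }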

##########
File path: src/relay/backend/contrib/cmsisnn/generate_constants.cc
##########
@@ -0,0 +1,230 @@
+    // Replace existing weights (HWIO) with the transposed ones (OHWI)
+    // Substitute Conv2D weight_zero_point with the CMSIS-NN multiplier
+    // Substitute Requantize input_zero_point with CMSIS-NN shift
+    // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc
+    Array<Expr> conv2d_args = {conv2d_call->args[0], transposed_kernel,    conv2d_call->args[2],

Review comment:
       nit: the spacing here looks a bit odd. Is clang-format doing this?
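       For comparison, I'd expect uniform single spaces here rather than column alignment, e.g.:

           Array<Expr> conv2d_args = {conv2d_call->args[0], transposed_kernel, conv2d_call->args[2],
                                      multiplier_const, conv2d_call->args[4], weight_scale};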




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org