You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/30 16:46:05 UTC

[GitHub] [tvm] manupa-arm commented on a change in pull request #9163: [CMSIS-NN] Initial operator support for Mul

manupa-arm commented on a change in pull request #9163:
URL: https://github.com/apache/tvm/pull/9163#discussion_r719578588



##########
File path: src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
##########
@@ -32,17 +32,37 @@ namespace relay {
 namespace contrib {
 namespace cmsisnn {
 
-class RelayToTIR : public MixedModeVisitor {
+class RelayToTIRVisitor : public MixedModeVisitor {
  public:
-  explicit RelayToTIR(String func_name) : func_name_(func_name) {}
+  explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}
+
+  tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }
 
  private:
-  void emit_softmax_tir(const Expr& expr) {
+  template <typename T>
+  const T ArgumentToConstantValue(const Expr& arg) {

Review comment:
       nit: Why is the function specific to "Argument"? It seems like it could work for any Expr.

##########
File path: src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
##########
@@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor {
         IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size),
         IntImm(DataType::Int(32), mult),     IntImm(DataType::Int(32), shift),
         IntImm(DataType::Int(32), diff_min), out_var};
-    tir::Stmt body =
-        tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args));
 
-    Map<String, ObjectRef> dict_attrs;
-    dict_attrs.Set("global_symbol", func_name_);
-    dict_attrs.Set("tir.noalias", Bool(true));
+    CreatePrimFuncForExtern(func_signature, args);
+  }
 
-    primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
-                              DictAttrs(dict_attrs));
+  void EmitMul(const Expr& expr) {
+    auto* mul_call = expr.as<CallNode>();
+
+    const float input_0_scale = ArgumentToConstantValue<float>(mul_call->args[2]);
+    const int32_t input_0_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[3]);
+    const float input_1_scale = ArgumentToConstantValue<float>(mul_call->args[4]);
+    const int32_t input_1_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[5]);
+    const float output_scale = ArgumentToConstantValue<float>(mul_call->args[6]);
+    const int32_t output_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[7]);
+
+    double quantized_multiplier = static_cast<double>(input_0_scale) *
+                                  static_cast<double>(input_1_scale) /
+                                  static_cast<double>(output_scale);
+    auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
+    int32_t output_multiplier = std::get<0>(mult_shift_pair);
+    int32_t output_shift = std::get<1>(mult_shift_pair);
+
+    PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();
+
+    tir::Var input_0("input_0", DataType::Handle(8));

Review comment:
       It's worth a comment explaining why we create an 8-bit handle here. Why do we think 8 bits are sufficient for all cases?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org