Posted to commits@tvm.apache.org by tq...@apache.org on 2020/10/29 13:52:32 UTC

[incubator-tvm] branch main updated: [VTA] quant support for alu-only op (#6191)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b6ef5c  [VTA] quant support for alu-only op (#6191)
1b6ef5c is described below

commit 1b6ef5c8022986664e043be01526c31b3e3a9449
Author: ZHANG Hao <zh...@4paradigm.com>
AuthorDate: Thu Oct 29 21:52:17 2020 +0800

    [VTA] quant support for alu-only op (#6191)
---
 python/tvm/relay/quantize/_annotate.py  |  1 +
 python/tvm/relay/quantize/_partition.py | 15 ++++++++++++++-
 src/relay/quantize/realize.cc           | 28 +++++++++++++++++++++-------
 3 files changed, 36 insertions(+), 8 deletions(-)
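
The changes below teach Relay's quantization pass to handle partitions built
only from ALU-style ops (elementwise add and multiply, reshape, global average
pooling) so that such regions can execute on VTA's ALU. For context, a minimal
sketch of how the affected pass is normally driven; mod and params are
placeholders for a Relay module and its parameter dict, and global_scale is an
arbitrary illustrative value:

    import tvm
    from tvm import relay

    # mod, params: assumed Relay module and parameter dict
    with relay.quantize.qconfig(global_scale=8.0):
        qmod = relay.quantize.quantize(mod, params)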

diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index 329ba64..b187387 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -284,6 +284,7 @@ def identity_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(ret_expr, x_kind)
 
 
+register_annotate_function("reshape", identity_rewrite)
 register_annotate_function("clip", identity_rewrite)
 register_annotate_function("nn.relu", identity_rewrite)
 register_annotate_function("strided_slice", identity_rewrite)
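
Registering "reshape" with identity_rewrite makes the annotate step forward the
input's annotation kind through reshape, the same treatment already given to
clip, nn.relu, and strided_slice. A hedged illustration (shapes are invented):

    from tvm import relay

    x = relay.var("x", shape=(1, 16, 7, 7))
    y = relay.nn.relu(x)
    # reshape now inherits y's annotation kind instead of
    # breaking the quantized region
    z = relay.reshape(y, (1, 784))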
diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py
index 6892e86..563d283 100644
--- a/python/tvm/relay/quantize/_partition.py
+++ b/python/tvm/relay/quantize/_partition.py
@@ -82,7 +82,7 @@ def add_partition_generic(ref_call, new_args, ctx):
         #     ...
         lhs = new_args[0].realize()
         rhs = new_args[1].realize()
-        return _forward_op(ref_call, [lhs, rhs])
+        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
     if not lhs_cond and rhs_cond:
         # - introduced by residual connection in ResNet
         #     ...
@@ -130,6 +130,7 @@ def mul_partition_generic(ref_call, new_args, ctx):
 
     if lhs_cond:
         # introduced by bn: multiply(out, scale)
+        lhs = new_args[0].realize()
         return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
 
     if not lhs_cond and not rhs_cond:
@@ -155,3 +156,15 @@ def add_partition_function(ref_call, new_args, ctx):
 def multiply_partition_function(ref_call, new_args, ctx):
     """Rewrite function for ewise multiply for partition"""
     return mul_partition_generic(ref_call, new_args, ctx)
+
+
+# realize the input (typically the output of a relu) so a cast is inserted
+# and global_avg_pool2d can run on vta
+@register_partition_function("nn.global_avg_pool2d")
+def global_avg_pool2d_partition_function(ref_call, new_args, ctx):
+    cond, expr = partition_expr_check(new_args[0])
+    if cond:
+        expr = new_args[0].realize()
+    else:
+        expr = QPartitionExpr(new_args[0]).realize()
+
+    return _forward_op(ref_call, [expr])
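
Three partition fixes here: add_partition_generic now wraps the branch where
both inputs are realized in QPartitionExpr, so the add stays inside the
partition; mul_partition_generic realizes lhs before forwarding (previously it
forwarded the unrealized expression); and the new nn.global_avg_pool2d
partition function realizes its input so the cast lands at the partition
boundary. A sketch of the kind of classifier head this targets (all shapes
hypothetical):

    from tvm import relay

    x = relay.var("x", shape=(1, 64, 7, 7))
    body = relay.nn.relu(x)                    # inside the quantized region
    # the pool's input is realized, so the cast is inserted here
    pooled = relay.nn.global_avg_pool2d(body)
    out = relay.reshape(pooled, (1, 64))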
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index c96a1b0..8db72a3 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -309,7 +309,8 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
 
 /* \brief Unify the dom scale of arguments */
 Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
-                            DataType* dtype_ptr, Expr* scale_ptr) {
+                            DataType* dtype_ptr, Expr* scale_ptr,
+                            DataType dtype = DataType::Void()) {
   static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
   const QConfig& cfg = QConfig::Current();
 
@@ -324,13 +325,15 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
 
   // unify the data type
   ICHECK_EQ(ref_args.size(), args.size());
-  DataType dtype;
 
-  if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
-    dtype = cfg->dtype_input;
-  } else {
-    dtype = cfg->dtype_activation;
+  if (dtype.is_void()) {
+    if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
+      dtype = cfg->dtype_input;
+    } else {
+      dtype = cfg->dtype_activation;
+    }
   }
+
   for (size_t i = 0; i < ret.size(); ++i) {
     auto ref_arg = ref_args[i].as<CallNode>();
     if (nptrs[i]->dtype != dtype) {
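
UnifyDTypeScale gains an optional dtype argument: when it is left as
DataType::Void(), the original heuristic picks between dtype_input and
dtype_activation; otherwise the caller's choice wins. A Python analogue of the
defaulting rule (a hypothetical helper, for illustration only):

    def unify_dtype(arg_dtypes, cfg, dtype=None):
        """dtype=None plays the role of DataType::Void() in the C++ code."""
        if dtype is None:
            if len(arg_dtypes) == 2 and arg_dtypes[1] == cfg["dtype_input"]:
                dtype = cfg["dtype_input"]
            else:
                dtype = cfg["dtype_activation"]
        return dtype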
@@ -361,7 +364,16 @@ Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
   if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+    // execute the operation with the activation data type.
+    const QConfig& cfg = QConfig::Current();
+    Array<Expr> ret_args =
+        UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation);
+    for (size_t i = 0; i < ret_args.size(); ++i) {
+      // do not fuse float32 arg
+      if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
+        ret_args.Set(i, StopFusion(ret_args[i]));
+      }
+    }
     Expr ret = ForwardOp(ref_call, ret_args);
     return QRealizeIntExpr(ret, dom_scale, dtype);
   }
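
AddRealize now pins the unified type to cfg->dtype_activation, so an add of
two quantized tensors is computed in the activation type that VTA's ALU
operates on, and any float32 argument gets StopFusion so it is not fused into
the quantized kernel. The relevant qconfig knobs, shown with TVM's defaults
purely as an illustration:

    from tvm import relay

    # mod, params: assumed Relay module and parameter dict
    with relay.quantize.qconfig(dtype_input="int8",
                                dtype_weight="int8",
                                dtype_activation="int32"):
        qmod = relay.quantize.quantize(mod, params)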
@@ -430,6 +442,8 @@ Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
 
 RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
+RELAY_REGISTER_OP("reshape").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
 RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
 RELAY_REGISTER_OP("nn.batch_flatten")