You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/10/29 13:52:32 UTC
[incubator-tvm] branch main updated: [VTA] quant support for
alu-only op (#6191)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1b6ef5c [VTA] quant support for alu-only op (#6191)
1b6ef5c is described below
commit 1b6ef5c8022986664e043be01526c31b3e3a9449
Author: ZHANG Hao <zh...@4paradigm.com>
AuthorDate: Thu Oct 29 21:52:17 2020 +0800
[VTA] quant support for alu-only op (#6191)
---
python/tvm/relay/quantize/_annotate.py | 1 +
python/tvm/relay/quantize/_partition.py | 15 ++++++++++++++-
src/relay/quantize/realize.cc | 28 +++++++++++++++++++++-------
3 files changed, 36 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index 329ba64..b187387 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -284,6 +284,7 @@ def identity_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(ret_expr, x_kind)
+register_annotate_function("reshape", identity_rewrite)
register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py
index 6892e86..563d283 100644
--- a/python/tvm/relay/quantize/_partition.py
+++ b/python/tvm/relay/quantize/_partition.py
@@ -82,7 +82,7 @@ def add_partition_generic(ref_call, new_args, ctx):
# ...
lhs = new_args[0].realize()
rhs = new_args[1].realize()
- return _forward_op(ref_call, [lhs, rhs])
+ return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
if not lhs_cond and rhs_cond:
# - introduced by residual connection in ResNet
# ...
@@ -130,6 +130,7 @@ def mul_partition_generic(ref_call, new_args, ctx):
if lhs_cond:
# introduced by bn: multiply(out, scale)
+ lhs = new_args[0].realize()
return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
if not lhs_cond and not rhs_cond:
@@ -155,3 +156,15 @@ def add_partition_function(ref_call, new_args, ctx):
def multiply_partition_function(ref_call, new_args, ctx):
"""Rewrite function for ewise multiply for partition"""
return mul_partition_generic(ref_call, new_args, ctx)
+
+
+# add cast after the global_avg_pool2d op to make it run on vta
+@register_partition_function("nn.global_avg_pool2d")
+def global_avg_pool2d_partition_function(ref_call, new_args, ctx):
+ cond, expr = partition_expr_check(new_args[0])
+ if cond:
+ expr = new_args[0].realize()
+ else:
+ expr = QPartitionExpr(new_args[0]).realize()
+
+ return _forward_op(ref_call, [expr])
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index c96a1b0..8db72a3 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -309,7 +309,8 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
- DataType* dtype_ptr, Expr* scale_ptr) {
+ DataType* dtype_ptr, Expr* scale_ptr,
+ DataType dtype = DataType::Void()) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();
@@ -324,13 +325,15 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
// unify the data type
ICHECK_EQ(ref_args.size(), args.size());
- DataType dtype;
- if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
- dtype = cfg->dtype_input;
- } else {
- dtype = cfg->dtype_activation;
+ if (dtype.is_void()) {
+ if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
+ dtype = cfg->dtype_input;
+ } else {
+ dtype = cfg->dtype_activation;
+ }
}
+
for (size_t i = 0; i < ret.size(); ++i) {
auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
@@ -361,7 +364,16 @@ Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
- Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+ // execute the operation with activation data type.
+ const QConfig& cfg = QConfig::Current();
+ Array<Expr> ret_args =
+ UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation);
+ for (size_t i = 0; i < ret_args.size(); ++i) {
+ // do not fuse float32 arg
+ if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
+ ret_args.Set(i, StopFusion(ret_args[i]));
+ }
+ }
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExpr(ret, dom_scale, dtype);
}
@@ -430,6 +442,8 @@ Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+RELAY_REGISTER_OP("reshape").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("nn.batch_flatten")