You are viewing a plain text version of this content; the canonical (linked) version is available in the original mailing-list archive.
Posted to commits@tvm.apache.org by wu...@apache.org on 2020/11/07 02:21:14 UTC
[incubator-tvm] branch main updated: making quantization tweaks
(#6731)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff9c480 making quantization tweaks (#6731)
ff9c480 is described below
commit ff9c4803913b82085f281c98afbd54feedefeb7c
Author: Thierry Moreau <tm...@octoml.ai>
AuthorDate: Fri Nov 6 18:20:56 2020 -0800
making quantization tweaks (#6731)
---
python/tvm/relay/quantize/_annotate.py | 43 ++++++++++++++++++++++++++++++++++
src/relay/quantize/realize.cc | 36 ++++++++++++++++++++++++++++
2 files changed, 79 insertions(+)
diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index b187387..6c395e2 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -175,6 +175,28 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
+@register_annotate_function("nn.conv1d")
+def conv1d_rewrite(ref_call, new_args, ctx):
+ """Rewrite function for conv1d. Lhs of conv will be quantized to
+ input field, and rhs of conv will be quantized to weight field.
+ Output would be in activation field"""
+ if quantize_context().check_to_skip(ref_call):
+ return None
+
+ lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
+ rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
+
+ if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
+ lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
+
+ assert rhs_kind is None
+ rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
+
+ expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
+
+ return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
+
+
@register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
@@ -289,6 +311,8 @@ register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
+register_annotate_function("nn.batch_flatten", identity_rewrite)
+register_annotate_function("transpose", identity_rewrite)
register_annotate_function("annotation.stop_fusion", identity_rewrite)
@@ -311,6 +335,25 @@ def pool2d_rewrite(ref_call, new_args, ctx):
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
+def pool1d_rewrite(ref_call, new_args, ctx):
+ """Rewrite function for max pool1d"""
+ if quantize_context().check_to_skip(ref_call):
+ return None
+
+ expr, x_kind = _get_expr_kind(new_args[0])
+
+ if x_kind is None:
+ return None
+ if x_kind == QAnnotateKind.ACTIVATION:
+ expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
+
+ expr = _forward_op(ref_call, [expr])
+ return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+
+
+register_annotate_function("nn.max_pool1d", pool1d_rewrite)
+
+
@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 8db72a3..2716c6e 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -234,6 +234,37 @@ Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const Obje
RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
+Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
+ const QConfig& cfg = QConfig::Current();
+ CHECK_EQ(new_args.size(), 2);
+ if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
+ return Expr(nullptr);
+ }
+ const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+ CHECK(lhs);
+ const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+ CHECK(rhs);
+
+ Expr ldata = lhs->data;
+ if (lhs->dtype != cfg->dtype_input) {
+ ldata = Cast(ldata, cfg->dtype_input);
+ }
+ Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+ const auto ref_attrs = ref_call->attrs.as<Conv1DAttrs>();
+ auto attrs = make_object<Conv1DAttrs>();
+ *attrs = *ref_attrs;
+ DataType out_dtype = cfg->dtype_activation;
+ attrs->out_dtype = out_dtype;
+
+ Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+ Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+ Expr dom_scale = FoldConstantOpt(mul);
+ return QRealizeIntExpr(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.conv1d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv1dRealize);
+
Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
ICHECK_EQ(new_args.size(), 2);
@@ -449,6 +480,8 @@ RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite",
RELAY_REGISTER_OP("nn.batch_flatten")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+RELAY_REGISTER_OP("transpose").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
@@ -469,6 +502,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
+RELAY_REGISTER_OP("nn.max_pool1d")
+ .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
+
Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
ICHECK_EQ(new_args.size(), 1);