You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/07/12 16:57:59 UTC
[tvm] branch main updated: [QNN] Replace nn.leaky_relu with qnn.leaky_relu (#11930)
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6d676badff [QNN] Replace nn.leaky_relu with qnn.leaky_relu (#11930)
6d676badff is described below
commit 6d676badff499a3b87fb47370f2f0d1d1318e8ed
Author: zhaoyang-star <zh...@foxmail.com>
AuthorDate: Wed Jul 13 00:57:54 2022 +0800
[QNN] Replace nn.leaky_relu with qnn.leaky_relu (#11930)
* [QNN] Replace nn.leaky_relu with qnn.leaky_relu
* jostle ci
* fix typo
---
python/tvm/relay/frontend/qnn_torch.py | 17 ++++++++++++++---
1 file changed, 14 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 0485a993ac..824d3bbe64 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -937,10 +937,9 @@ def _relu6():
return _impl
-def _leaky_relu():
+def _leaky_relu(fp32_piggy_back=False):
# refer to src/ATen/native/quantized/cpu/qrelu.cpp
- def _impl(inputs, _):
- assert len(inputs) == 7, "Input quant params not found in op inputs"
+ def _impl_fp32(inputs, _):
alpha = inputs[1]
output_scale = _expr.const(inputs[3])
output_zero_point = _expr.const(inputs[4])
@@ -952,6 +951,18 @@ def _leaky_relu():
dequantized, output_scale, output_zero_point, out_dtype="uint8"
)
+ def _impl_int8(inputs, _):
+ alpha = inputs[1]
+ output_scale = _expr.const(inputs[3])
+ output_zero_point = _expr.const(inputs[4])
+ return relay.qnn.op.leaky_relu(inputs[0], alpha, output_scale, output_zero_point)
+
+ def _impl(inputs, _):
+ assert len(inputs) == 7, "Input quant params not found in op inputs"
+ if fp32_piggy_back:
+ return _impl_fp32(inputs, _)
+ return _impl_int8(inputs, _)
+
return _impl