You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/07/12 16:57:59 UTC

[tvm] branch main updated: [QNN] Replace nn.leaky_relu with qnn.leaky_relu (#11930)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d676badff [QNN] Replace nn.leaky_relu with qnn.leaky_relu (#11930)
6d676badff is described below

commit 6d676badff499a3b87fb47370f2f0d1d1318e8ed
Author: zhaoyang-star <zh...@foxmail.com>
AuthorDate: Wed Jul 13 00:57:54 2022 +0800

    [QNN] Replace nn.leaky_relu with qnn.leaky_relu (#11930)
    
    * [QNN] Replace nn.leaky_relu with qnn.leaky_relu
    
    * jostle ci
    
    * fix typo
---
 python/tvm/relay/frontend/qnn_torch.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 0485a993ac..824d3bbe64 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -937,10 +937,9 @@ def _relu6():
     return _impl
 
 
-def _leaky_relu():
+def _leaky_relu(fp32_piggy_back=False):
     # refer to src/ATen/native/quantized/cpu/qrelu.cpp
-    def _impl(inputs, _):
-        assert len(inputs) == 7, "Input quant params not found in op inputs"
+    def _impl_fp32(inputs, _):
         alpha = inputs[1]
         output_scale = _expr.const(inputs[3])
         output_zero_point = _expr.const(inputs[4])
@@ -952,6 +951,18 @@ def _leaky_relu():
             dequantized, output_scale, output_zero_point, out_dtype="uint8"
         )
 
+    def _impl_int8(inputs, _):
+        alpha = inputs[1]
+        output_scale = _expr.const(inputs[3])
+        output_zero_point = _expr.const(inputs[4])
+        return relay.qnn.op.leaky_relu(inputs[0], alpha, output_scale, output_zero_point)
+
+    def _impl(inputs, _):
+        assert len(inputs) == 7, "Input quant params not found in op inputs"
+        if fp32_piggy_back:
+            return _impl_fp32(inputs, _)
+        return _impl_int8(inputs, _)
+
     return _impl