You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/16 23:58:16 UTC
[tvm] branch aluo/rebase-08312022-autotensorization-fq2i-changes updated: pattern matching
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/rebase-08312022-autotensorization-fq2i-changes
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/aluo/rebase-08312022-autotensorization-fq2i-changes by this push:
new 72373ea46c pattern matching
72373ea46c is described below
commit 72373ea46cd9519d104c177e890b8e5082825f0a
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 16 16:58:09 2022 -0700
pattern matching
---
python/tvm/relay/qnn/transform.py | 78 ++++++++++++++++++++++
.../transform/fake_quantization_to_integer.py | 2 +-
2 files changed, 79 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
index 0485cecb99..7b42942c8b 100644
--- a/python/tvm/relay/qnn/transform.py
+++ b/python/tvm/relay/qnn/transform.py
@@ -114,3 +114,81 @@ def Legalize():
"""
return relay.transform.Legalize("FTVMQnnLegalize")
+
+
+from tvm.relay.dataflow_pattern import (
+ DFPatternCallback,
+ is_constant,
+ is_expr,
+ is_op,
+ rewrite,
+ wildcard,
+)
+
+
+class RSqrtPattern(DFPatternCallback):
+ """
+ Rewrites a quantized ``numerator / sqrt(x)`` into ``numerator * rsqrt(x)``.
+
+ Matches ``qnn.div(numerator, qnn.sqrt(x, ...), ...)`` and replaces it with
+ ``qnn.mul(numerator, qnn.rsqrt(x, ...), ...)``. The replacement keeps the
+ matched div's output scale and zero point.
+
+ NOTE(review): the intermediate rsqrt result is given the numerator's
+ scale/zero point (see ``callback``). Whether that quantization range suits
+ 1/sqrt values is not established here; confirm.
+ """
+
+ def __init__(self):
+ super(RSqrtPattern, self).__init__()
+
+ # Input tensor of the inner qnn.sqrt and its input scale / zero point.
+ self.sqrt_data = wildcard()
+ self.sqrt_data_input_scale = wildcard()
+ self.sqrt_data_input_zp = wildcard()
+
+ # First operand of the qnn.div and its scale / zero point.
+ self.numerator = wildcard()
+ self.numerator_scale = wildcard()
+ self.numerator_zp = wildcard()
+
+ # qnn.div output scale / zero point; reused for the rewritten qnn.mul.
+ self.output_scale = wildcard()
+ self.output_zp = wildcard()
+
+ # qnn.sqrt with 5 arguments. The last two (presumably the sqrt's output
+ # scale / zero point) are anonymous wildcards and are not captured.
+ self.sqrt = is_op("qnn.sqrt")(
+ self.sqrt_data,
+ self.sqrt_data_input_scale,
+ self.sqrt_data_input_zp,
+ wildcard(),
+ wildcard(),
+ )
+
+ # TODO: match axis properly
+ # NOTE(review): the 5th/6th div arguments (presumably the scale / zero
+ # point of the sqrt result) are unconstrained wildcards. They are not
+ # checked against the qnn.sqrt output params above.
+ self.rsqrt = is_op("qnn.div")(
+ self.numerator,
+ self.sqrt,
+ self.numerator_scale,
+ self.numerator_zp,
+ wildcard(),
+ wildcard(),
+ self.output_scale,
+ self.output_zp,
+ )
+
+ # Root pattern matched by the DFPatternCallback machinery.
+ self.pattern = self.rsqrt
+
+ def callback(self, pre, post, node_map):
+ """
+ Build the replacement for a matched ``qnn.div(numerator, qnn.sqrt(...))``.
+
+ Each captured sub-pattern is looked up in ``node_map``, and its first
+ match is used. ``pre`` and ``post`` are not used. Returns a ``qnn.mul``
+ of the numerator and a new ``qnn.rsqrt`` of the original sqrt input. The
+ mul uses the matched div's output scale / zero point.
+ """
+ sqrt_data = node_map[self.sqrt_data][0]
+ sqrt_data_scale = node_map[self.sqrt_data_input_scale][0]
+ sqrt_data_zp = node_map[self.sqrt_data_input_zp][0]
+
+ numerator = node_map[self.numerator][0]
+ numerator_scale = node_map[self.numerator_scale][0]
+ numerator_zp = node_map[self.numerator_zp][0]
+
+ output_scale = node_map[self.output_scale][0]
+ output_zp = node_map[self.output_zp][0]
+
+ # NOTE(review): the last two rsqrt arguments (presumably its output scale /
+ # zero point) are the numerator's params, not anything derived from the
+ # matched sqrt. Confirm this range is intended for 1/sqrt values.
+ rsqrt = relay.qnn.op.rsqrt(
+ sqrt_data, sqrt_data_scale, sqrt_data_zp, numerator_scale, numerator_zp
+ )
+ # Assuming qnn.mul's argument order mirrors the qnn.div pattern
+ # (lhs, rhs, lhs params, rhs params, output params). Both input param
+ # pairs are numerator_scale / numerator_zp: the numerator's own, and the
+ # ones just given to the rsqrt result.
+ return relay.qnn.op.mul(
+ numerator,
+ rsqrt,
+ numerator_scale,
+ numerator_zp,
+ numerator_scale,
+ numerator_zp,
+ output_scale,
+ output_zp,
+ )
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 5b6845bd63..0b31474eb4 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -664,4 +664,4 @@ register_unary_qnn("sigmoid", relay.qnn.op.sigmoid)
register_unary_qnn("hardswish", relay.qnn.op.hardswish)
register_unary_qnn("tanh", relay.qnn.op.tanh)
register_unary_qnn("abs", relay.qnn.op.abs)
-register_unary_qnn("log", relay.qnn.op.log)
\ No newline at end of file
+register_unary_qnn("log", relay.qnn.op.log)