You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/16 23:58:16 UTC

[tvm] branch aluo/rebase-08312022-autotensorization-fq2i-changes updated: pattern matching

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-08312022-autotensorization-fq2i-changes
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/aluo/rebase-08312022-autotensorization-fq2i-changes by this push:
     new 72373ea46c pattern matching
72373ea46c is described below

commit 72373ea46cd9519d104c177e890b8e5082825f0a
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 16 16:58:09 2022 -0700

    pattern matching
---
 python/tvm/relay/qnn/transform.py                  | 78 ++++++++++++++++++++++
 .../transform/fake_quantization_to_integer.py      |  2 +-
 2 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
index 0485cecb99..7b42942c8b 100644
--- a/python/tvm/relay/qnn/transform.py
+++ b/python/tvm/relay/qnn/transform.py
@@ -114,3 +114,81 @@ def Legalize():
     """
 
     return relay.transform.Legalize("FTVMQnnLegalize")
+
+
+from tvm.relay.dataflow_pattern import (
+    DFPatternCallback,
+    is_constant,
+    is_expr,
+    is_op,
+    rewrite,
+    wildcard,
+)
+
+
+class RSqrtPattern(DFPatternCallback):
+    """
+    Rewrites qnn.div(numerator, qnn.sqrt(x)) into qnn.mul(numerator, qnn.rsqrt(x)).
+    """
+
+    def __init__(self):
+        super(RSqrtPattern, self).__init__()
+        # First three operands of the inner qnn.sqrt call, captured for the rewrite.
+        self.sqrt_data = wildcard()
+        self.sqrt_data_input_scale = wildcard()
+        self.sqrt_data_input_zp = wildcard()
+        # Left operand of qnn.div plus the two operands that follow the sqrt result.
+        self.numerator = wildcard()
+        self.numerator_scale = wildcard()
+        self.numerator_zp = wildcard()
+        # Final two operands of qnn.div.
+        self.output_scale = wildcard()
+        self.output_zp = wildcard()
+
+        self.sqrt = is_op("qnn.sqrt")(
+            self.sqrt_data,
+            self.sqrt_data_input_scale,
+            self.sqrt_data_input_zp,
+            wildcard(),  # matched but not captured (presumably sqrt's output scale)
+            wildcard(),  # matched but not captured (presumably sqrt's output zero point)
+        )
+
+        # TODO: match axis properly
+        self.rsqrt = is_op("qnn.div")(
+            self.numerator,
+            self.sqrt,
+            self.numerator_scale,
+            self.numerator_zp,
+            wildcard(),  # matched but not captured (presumably the divisor's scale)
+            wildcard(),  # matched but not captured (presumably the divisor's zero point)
+            self.output_scale,
+            self.output_zp,
+        )
+
+        self.pattern = self.rsqrt  # the pattern attribute is the outer qnn.div pattern
+
+    def callback(self, pre, post, node_map):  # pre and post are not read here
+        sqrt_data = node_map[self.sqrt_data][0]
+        sqrt_data_scale = node_map[self.sqrt_data_input_scale][0]
+        sqrt_data_zp = node_map[self.sqrt_data_input_zp][0]
+
+        numerator = node_map[self.numerator][0]
+        numerator_scale = node_map[self.numerator_scale][0]
+        numerator_zp = node_map[self.numerator_zp][0]
+
+        output_scale = node_map[self.output_scale][0]
+        output_zp = node_map[self.output_zp][0]
+        # NOTE(review): numerator scale/zp are reused as rsqrt's trailing args; confirm intent.
+        rsqrt = relay.qnn.op.rsqrt(
+            sqrt_data, sqrt_data_scale, sqrt_data_zp, numerator_scale, numerator_zp
+        )
+        return relay.qnn.op.mul(
+            numerator,
+            rsqrt,
+            numerator_scale,
+            numerator_zp,
+            numerator_scale,  # same value handed to rsqrt as its fourth argument above
+            numerator_zp,  # same value handed to rsqrt as its fifth argument above
+            output_scale,
+            output_zp,
+        )
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 5b6845bd63..0b31474eb4 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -664,4 +664,4 @@ register_unary_qnn("sigmoid", relay.qnn.op.sigmoid)
 register_unary_qnn("hardswish", relay.qnn.op.hardswish)
 register_unary_qnn("tanh", relay.qnn.op.tanh)
 register_unary_qnn("abs", relay.qnn.op.abs)
-register_unary_qnn("log", relay.qnn.op.log)
\ No newline at end of file
+register_unary_qnn("log", relay.qnn.op.log)