You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/11/07 11:29:55 UTC

[tvm] branch main updated: Support quantised SQRT operator in TFLite (#9258)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 01141d4  Support quantised SQRT operator in TFLite (#9258)
01141d4 is described below

commit 01141d4acf913e946d774c728f4d42f3702040ba
Author: Elinx Hsi <xi...@163.com>
AuthorDate: Sun Nov 7 19:29:14 2021 +0800

    Support quantised SQRT operator in TFLite (#9258)
---
 python/tvm/relay/frontend/tflite.py          |  2 --
 tests/python/frontend/tflite/test_forward.py | 53 +++++++++++++++++++++-------
 2 files changed, 41 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 97382ff..5da6fd2 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1191,8 +1191,6 @@ class OperatorConverter(object):
 
     def convert_sqrt(self, op):
         """Convert TFLite SQRT"""
-        if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented("TFlite quantized SQRT operator is not supported yet.")
         return self._convert_unary_elemwise(_op.sqrt, op)
 
     def convert_rsqrt(self, op):
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index f8a603c..8bedb23 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1851,16 +1851,6 @@ def _test_tan(data):
 
 
 #######################################################################
-# Sqrt
-# ----
-
-
-def _test_sqrt(data):
-    """One iteration of sqrt"""
-    return _test_unary_elemwise(math_ops.sqrt, data)
-
-
-#######################################################################
 # Square
 # ------
 
@@ -1882,7 +1872,7 @@ def _test_elu(data):
 
 def _test_forward_unary_elemwise(test_op):
     # functions that need positive input
-    if test_op.__name__ in {"_test_log", "_test_sqrt"}:
+    if test_op.__name__ in {"_test_log"}:
         test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)))
     else:
         test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32))
@@ -1893,7 +1883,6 @@ def test_all_unary_elemwise():
     _test_forward_unary_elemwise(_test_exp)
     _test_forward_unary_elemwise(_test_log)
     _test_forward_unary_elemwise(_test_sin)
-    _test_forward_unary_elemwise(_test_sqrt)
     _test_forward_unary_elemwise(_test_square)
     # ceil and cos come with TFLite 1.14.0.post1 fbs schema
     if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
@@ -3361,6 +3350,45 @@ def test_forward_rsqrt():
 
 
 #######################################################################
+# SQRT
+# ----
+
+
+def _test_sqrt(data, quantized=False):
+    """One iteration of SQRT"""
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
+
+        if quantized:
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=1, max=6, name="inq_0"
+            )
+            input_range = {"inq_0": (1, 6)}
+            out = math_ops.sqrt(inq_data)
+            out = tf.quantization.fake_quant_with_min_max_args(out, min=1, max=6, name="out")
+            compare_tflite_with_tvm(
+                data,
+                "inq_0:0",
+                [inq_data],
+                [out],
+                quantized=True,
+                input_range=input_range,
+                experimental_new_converter=True,
+            )
+        else:
+            out = math_ops.sqrt(in_data)
+            compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
+
+
+def test_forward_sqrt():
+    """SQRT"""
+    _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32), quantized=False)
+    _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False)
+    _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8), quantized=True)
+    _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True)
+
+
+#######################################################################
 # NEG
 # ----
 
@@ -4742,6 +4770,7 @@ if __name__ == "__main__":
     test_forward_rsqrt()
     test_forward_neg()
     test_forward_abs()
+    test_forward_sqrt()
     test_forward_relu()
     test_forward_relu6()
     test_forward_leaky_relu()