You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/11/07 11:29:55 UTC
[tvm] branch main updated: Support quantised SQRT operator in TFLite (#9258)
This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 01141d4 Support quantised SQRT operator in TFLite (#9258)
01141d4 is described below
commit 01141d4acf913e946d774c728f4d42f3702040ba
Author: Elinx Hsi <xi...@163.com>
AuthorDate: Sun Nov 7 19:29:14 2021 +0800
Support quantised SQRT operator in TFLite (#9258)
---
python/tvm/relay/frontend/tflite.py | 2 --
tests/python/frontend/tflite/test_forward.py | 53 +++++++++++++++++++++-------
2 files changed, 41 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 97382ff..5da6fd2 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1191,8 +1191,6 @@ class OperatorConverter(object):
def convert_sqrt(self, op):
"""Convert TFLite SQRT"""
- if self.is_quantized(op):
- raise tvm.error.OpNotImplemented("TFlite quantized SQRT operator is not supported yet.")
return self._convert_unary_elemwise(_op.sqrt, op)
def convert_rsqrt(self, op):
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index f8a603c..8bedb23 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1851,16 +1851,6 @@ def _test_tan(data):
#######################################################################
-# Sqrt
-# ----
-
-
-def _test_sqrt(data):
- """One iteration of sqrt"""
- return _test_unary_elemwise(math_ops.sqrt, data)
-
-
-#######################################################################
# Square
# ------
@@ -1882,7 +1872,7 @@ def _test_elu(data):
def _test_forward_unary_elemwise(test_op):
# functions that need positive input
- if test_op.__name__ in {"_test_log", "_test_sqrt"}:
+ if test_op.__name__ in {"_test_log"}:
test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)))
else:
test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32))
@@ -1893,7 +1883,6 @@ def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_exp)
_test_forward_unary_elemwise(_test_log)
_test_forward_unary_elemwise(_test_sin)
- _test_forward_unary_elemwise(_test_sqrt)
_test_forward_unary_elemwise(_test_square)
# ceil and cos come with TFLite 1.14.0.post1 fbs schema
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
@@ -3361,6 +3350,45 @@ def test_forward_rsqrt():
#######################################################################
+# SQRT
+# ----
+
+
+def _test_sqrt(data, quantized=False):
+ """One iteration of SQRT"""
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")
+
+ if quantized:
+ inq_data = tf.quantization.fake_quant_with_min_max_args(
+ in_data, min=1, max=6, name="inq_0"
+ )
+ input_range = {"inq_0": (1, 6)}
+ out = math_ops.sqrt(inq_data)
+ out = tf.quantization.fake_quant_with_min_max_args(out, min=1, max=6, name="out")
+ compare_tflite_with_tvm(
+ data,
+ "inq_0:0",
+ [inq_data],
+ [out],
+ quantized=True,
+ input_range=input_range,
+ experimental_new_converter=True,
+ )
+ else:
+ out = math_ops.sqrt(in_data)
+ compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
+
+
+def test_forward_sqrt():
+ """SQRT"""
+ _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32), quantized=False)
+ _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False)
+ _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8), quantized=True)
+ _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True)
+
+
+#######################################################################
# NEG
# ----
@@ -4742,6 +4770,7 @@ if __name__ == "__main__":
test_forward_rsqrt()
test_forward_neg()
test_forward_abs()
+ test_forward_sqrt()
test_forward_relu()
test_forward_relu6()
test_forward_leaky_relu()