You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2020/07/13 23:09:05 UTC

[incubator-tvm] branch master updated: Add support for tflite arg_min and arg_max (#5992)

This is an automated email from the ASF dual-hosted git repository.

anijain2305 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 712c82f  Add support for tflite arg_min and arg_max (#5992)
712c82f is described below

commit 712c82fb38ec2beea5a72662fb00899ab9bc0a08
Author: Dmitriy Smirnov <sm...@gmail.com>
AuthorDate: Tue Jul 14 00:08:56 2020 +0100

    Add support for tflite arg_min and arg_max (#5992)
    
    * [Relay][Frontend][TFLite] Add parser support for arg_min_max
    
    * this implementation supports only the case when the axis is a scalar
    * tflite 1.13 removes all dims of size 1, Relay doesn't do this
    * WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE
    
    * Migrated to tflite 2.1.0
    
    keepdims set to False and added some checks
    
    Note: the unit tests emitted the following warning:
    /workspace/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_multiply_red_temp, 0x53f5050)
    
    * linter
    
    * Removed quantized argmin
    
    Removed quantized argmin due to the inability to provide a proper test case
    
    * added negative ranges
    
    * re-trigger CI
    
    Co-authored-by: Ina_Dobreva <In...@arm.com>
---
 python/tvm/relay/frontend/tflite.py          | 50 ++++++++++++++++++++++++++++
 tests/python/frontend/tflite/test_forward.py | 34 +++++++++++++++++++
 2 files changed, 84 insertions(+)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 36221b7..1ec8237 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -67,6 +67,8 @@ class OperatorConverter(object):
             'ABS': self.convert_abs,
             'ADD': self.convert_add,
             'ADD_N': self.convert_add_n,
+            'ARG_MAX': self.convert_arg_max,
+            'ARG_MIN': self.convert_arg_min,
             'AVERAGE_POOL_2D': self.convert_average_pool2d,
             'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
             'CAST': self.convert_cast,
@@ -1634,6 +1636,54 @@ class OperatorConverter(object):
     def convert_reduce_any(self, op):
         return self._convert_reduce(_op.reduce.any, op)
 
+    def _convert_arg_min_max(self, relay_op, op):
+        """Convert TFLite ARG_MIN/ARG_MAX to relay_op (e.g. _op.argmin/_op.argmax)."""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.ArgMaxOptions import ArgMaxOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "one output tensor expected"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        axis_tensor = input_tensors[1]
+        # In TensorFlow the `axis` argument is a tensor, not an attribute. Only a
+        # single-element constant axis is supported (checked by the assert below).
+        axis_value = self.get_tensor_value(axis_tensor)
+        assert axis_value.size == 1
+        axis_value = axis_value.item()
+        # NOTE(review): no else branch; other option types raise UnboundLocalError at Init.
+        if op.BuiltinOptionsType() == BuiltinOptions.ArgMinOptions:
+            arg_min_max_options = ArgMinOptions()
+        elif op.BuiltinOptionsType() == BuiltinOptions.ArgMaxOptions:
+            arg_min_max_options = ArgMaxOptions()
+        op_options = op.BuiltinOptions()
+        arg_min_max_options.Init(op_options.Bytes, op_options.Pos)  # NOTE(review): unused below
+
+        # keepdims=False: the reduced axis is not kept in the output; this was set
+        # when migrating to tflite 2.1.0 (tflite 1.13 needed different handling).
+        out = relay_op(in_expr, axis=axis_value, keepdims=False, exclude=False)
+
+        return out
+
+    def convert_arg_min(self, op):
+        """Convert TFLite ARG_MIN to relay argmin; quantized inputs are not supported."""
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFlite quantized ARG_MIN operator is not supported yet.')
+        return self._convert_arg_min_max(_op.argmin, op)
+
+    def convert_arg_max(self, op):
+        """Convert TFLite ARG_MAX to relay argmax (float and quantized inputs alike)."""
+        return self._convert_arg_min_max(_op.argmax, op)
+
     def convert_fully_connected(self, op):
         """Convert TFLite fully connected"""
         try:
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 52491b2..5118467 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1755,6 +1755,39 @@ def test_all_reduce():
     if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
         _test_forward_reduce(_test_reduce_any, dtype="bool")
 
+#######################################################################
+# Arg_min_max
+# -----------
+
+def _test_arg_min_max(math_op, data, axis, quantized=False):
+    """Apply math_op (argmax/argmin) over axis in a TF graph; compare TFLite vs TVM."""
+    # quantized=True inserts a fake-quant node on [qmin, qmax] in front of math_op.
+    with tf.Graph().as_default():
+        t_name="in"
+        in_data = array_ops.placeholder(shape=data.shape, dtype=np.float32, name=t_name )
+        input_range=None
+        qmin, qmax = -100, 102
+        if quantized:
+            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=qmin, max=qmax, name= 'q' + t_name )
+            input_range = { inq_data.name.split(':')[0]: (qmin, qmax)}
+            out = math_op(input=inq_data, axis=axis)
+            compare_tflite_with_tvm([data], [inq_data.name], [inq_data], [out], quantized=True, input_range=input_range)
+        else:
+            out = math_op(input=in_data, axis=axis)
+            compare_tflite_with_tvm([data], [in_data.name], [in_data], [out])
+
+def test_forward_arg_min_max():
+    # Quantized (ARG_MAX only). NOTE(review): casting negative floats to uint8 is platform-dependent.
+    for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.uint8)]:
+        # Quantized ARG_MIN raises OpNotImplemented in the converter, so only argmax here
+        for axis in [None, 0, 1, -1]:
+            _test_arg_min_max(math_ops.argmax, data, axis, True)
+    # Float inputs (including negative values): both ARG_MAX and ARG_MIN.
+    for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.float32)]:
+        for axis in [None, 0, 1, -1]:
+            _test_arg_min_max(math_ops.argmax, data, axis)
+            _test_arg_min_max(math_ops.argmin, data, axis)
+
 
 #######################################################################
 # Select, Where
@@ -2834,6 +2867,7 @@ if __name__ == '__main__':
     test_forward_sparse_to_dense()
     test_forward_select()
     test_forward_quantize_dequantize()
+    test_forward_arg_min_max()
 
     # NN
     test_forward_convolution()