You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2020/07/13 23:09:05 UTC
[incubator-tvm] branch master updated: Add support for tflite
arg_min and arg_max (#5992)
This is an automated email from the ASF dual-hosted git repository.
anijain2305 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 712c82f Add support for tflite arg_min and arg_max (#5992)
712c82f is described below
commit 712c82fb38ec2beea5a72662fb00899ab9bc0a08
Author: Dmitriy Smirnov <sm...@gmail.com>
AuthorDate: Tue Jul 14 00:08:56 2020 +0100
Add support for tflite arg_min and arg_max (#5992)
* [Relay][Frontend][TFLite] Add parser support for arg_min_max
* this implementation supports only the case when the axis is a scalar
* tflite 1.13 removes all dims of size 1, Relay doesn't do this
* WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE
* Migrated to tflite 2.1.0
keepdims set to False and added some checks
Note: the unit tests emitted the following warning:
/workspace/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_multiply_red_temp, 0x53f5050)
* linter
* Removed quantized argmin
Removed quantized argmin due to the inability to provide a proper test case
* added negative ranges
* re-trigger CI
Co-authored-by: Ina_Dobreva <In...@arm.com>
---
python/tvm/relay/frontend/tflite.py | 50 ++++++++++++++++++++++++++++
tests/python/frontend/tflite/test_forward.py | 34 +++++++++++++++++++
2 files changed, 84 insertions(+)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 36221b7..1ec8237 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -67,6 +67,8 @@ class OperatorConverter(object):
'ABS': self.convert_abs,
'ADD': self.convert_add,
'ADD_N': self.convert_add_n,
+ 'ARG_MAX': self.convert_arg_max,
+ 'ARG_MIN': self.convert_arg_min,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'CAST': self.convert_cast,
@@ -1634,6 +1636,54 @@ class OperatorConverter(object):
def convert_reduce_any(self, op):
return self._convert_reduce(_op.reduce.any, op)
+    def _convert_arg_min_max(self, relay_op, op):
+        """Generic method converting TFLite ARG_MIN / ARG_MAX.
+
+        Parameters
+        ----------
+        relay_op : callable
+            Relay reduction to emit (``_op.argmin`` or ``_op.argmax``).
+        op : tflite.Operator
+            Operator being converted. It must have exactly two inputs
+            (data, axis) and exactly one output.
+
+        Returns
+        -------
+        out : relay.Expr
+            ``relay_op`` applied along the single axis given by the second
+            input, with the reduced dimension removed.
+        """
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.ArgMaxOptions import ArgMaxOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "one output tensor expected"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        axis_tensor = input_tensors[1]
+        # In Tensorflow, `axis` argument is a Tensor, not attribute. We
+        # support the case where it inputs from a scalar constant.
+        axis_value = self.get_tensor_value(axis_tensor)
+        assert axis_value.size == 1, "only a single (scalar) axis is supported"
+        axis_value = axis_value.item()
+
+        options_type = op.BuiltinOptionsType()
+        if options_type == BuiltinOptions.ArgMinOptions:
+            arg_min_max_options = ArgMinOptions()
+        elif options_type == BuiltinOptions.ArgMaxOptions:
+            arg_min_max_options = ArgMaxOptions()
+        else:
+            # Any other options type would leave the options object unbound;
+            # fail with a descriptive error instead of a NameError below.
+            raise tvm.error.OpAttributeInvalid(
+                'Unexpected builtin options type {} for TFLite ARG_MIN/ARG_MAX.'
+                .format(options_type))
+        op_options = op.BuiltinOptions()
+        arg_min_max_options.Init(op_options.Bytes, op_options.Pos)
+        # NOTE(review): the OutputType carried by these options is not used
+        # below; confirm the Relay result dtype matches the dtype of the
+        # TFLite output tensor.
+
+        # keepdims=False: the reduced axis is removed, matching the output
+        # shape of TF/TFLite arg_min/arg_max (tflite 2.1 is targeted here).
+        out = relay_op(in_expr, axis=axis_value, keepdims=False, exclude=False)
+
+        return out
+
+    def convert_arg_min(self, op):
+        """Convert TFLite ARG_MIN.
+
+        Non-quantized operators are lowered through the shared
+        arg-min/arg-max helper using ``_op.argmin``; quantized ARG_MIN is
+        rejected with ``OpNotImplemented``.
+        """
+        if not self.is_quantized(op):
+            return self._convert_arg_min_max(_op.argmin, op)
+        raise tvm.error.OpNotImplemented(
+            'TFlite quantized ARG_MIN operator is not supported yet.')
+
+    def convert_arg_max(self, op):
+        """Convert TFLite ARG_MAX by delegating to the shared helper.
+
+        Unlike ARG_MIN, no quantization check is made here, so quantized
+        operators are passed through as well.
+        """
+        reduce_fn = _op.argmax
+        return self._convert_arg_min_max(reduce_fn, op)
+
def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 52491b2..5118467 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1755,6 +1755,39 @@ def test_all_reduce():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_forward_reduce(_test_reduce_any, dtype="bool")
+#######################################################################
+# Arg_min_max
+# -----------
+
+def _test_arg_min_max(math_op, data, axis, quantized=False):
+    """Run a single arg_min/arg_max comparison between TFLite and TVM.
+
+    ``math_op`` is applied along ``axis`` to a float32 placeholder shaped
+    like ``data``. When ``quantized`` is set, the placeholder is first fed
+    through a fake-quant node with range [-100, 102].
+    """
+    with tf.Graph().as_default():
+        t_name = "in"
+        in_data = array_ops.placeholder(shape=data.shape, dtype=np.float32, name=t_name)
+
+        if not quantized:
+            out = math_op(input=in_data, axis=axis)
+            compare_tflite_with_tvm([data], [in_data.name], [in_data], [out])
+            return
+
+        qmin, qmax = -100, 102
+        inq_data = tf.quantization.fake_quant_with_min_max_args(
+            in_data, min=qmin, max=qmax, name='q' + t_name)
+        # Range is keyed by the fake-quant node name (tensor name without
+        # its ':<index>' output suffix).
+        input_range = {inq_data.name.split(':')[0]: (qmin, qmax)}
+        out = math_op(input=inq_data, axis=axis)
+        compare_tflite_with_tvm([data], [inq_data.name], [inq_data], [out],
+                                quantized=True, input_range=input_range)
+
+def test_forward_arg_min_max():
+    """Compare TFLite ARG_MIN / ARG_MAX with TVM on float and quantized input."""
+    # Quantized: sample valid uint8 values directly. Casting negative floats
+    # (e.g. from uniform(-100, 100)) to uint8 is undefined/platform-dependent
+    # in NumPy and warns on recent versions.
+    for data in [np.random.randint(0, 256, size=(3, 4), dtype=np.uint8)]:
+        # There is no quantized version of ArgMin
+        for axis in [None, 0, 1, -1]:
+            _test_arg_min_max(math_ops.argmax, data, axis, True)
+
+    for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.float32)]:
+        for axis in [None, 0, 1, -1]:
+            _test_arg_min_max(math_ops.argmax, data, axis)
+            _test_arg_min_max(math_ops.argmin, data, axis)
+
#######################################################################
# Select, Where
@@ -2834,6 +2867,7 @@ if __name__ == '__main__':
test_forward_sparse_to_dense()
test_forward_select()
test_forward_quantize_dequantize()
+ test_forward_arg_min_max()
# NN
test_forward_convolution()