You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/04 18:41:45 UTC

[GitHub] [incubator-tvm] siju-samuel commented on a change in pull request #5992: Add support for tflite arg_min and arg_max

siju-samuel commented on a change in pull request #5992:
URL: https://github.com/apache/incubator-tvm/pull/5992#discussion_r449795348



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -157,6 +157,8 @@ def __init__(self, model, subgraph, exp_tab):
             'UNPACK': self.convert_unpack,
             'WHERE': self.convert_select,
             'ZEROS_LIKE': self.convert_zeros_like,
+            'ARG_MIN': self.convert_arg_min,
+            'ARG_MAX': self.convert_arg_max,

Review comment:
       Please keep these entries in alphabetical order, like the other ops.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1634,6 +1636,56 @@ def convert_reduce_sum(self, op):
     def convert_reduce_any(self, op):
         return self._convert_reduce(_op.reduce.any, op)
 
+    def _convert_arg_min_max(self, relay_op, op):
+        """Generic method to convert TFLite arg_min_max"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.ArgMaxOptions import ArgMaxOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)

Review comment:
       This is already checked in `convert_op_to_relay`, so this assert can be removed.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1634,6 +1636,56 @@ def convert_reduce_sum(self, op):
     def convert_reduce_any(self, op):
         return self._convert_reduce(_op.reduce.any, op)
 
+    def _convert_arg_min_max(self, relay_op, op):
+        """Generic method to convert TFLite arg_min_max"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.ArgMaxOptions import ArgMaxOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "one output tensor expected"

Review comment:
       Add a check for the `if self.is_quantized(op):` condition, as the other ops do.
   

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1634,6 +1636,56 @@ def convert_reduce_sum(self, op):
     def convert_reduce_any(self, op):
         return self._convert_reduce(_op.reduce.any, op)
 
+    def _convert_arg_min_max(self, relay_op, op):
+        """Generic method to convert TFLite arg_min_max"""
+        try:
+            from tflite.Operator import Operator

Review comment:
       Remove this; it is not required.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1634,6 +1636,56 @@ def convert_reduce_sum(self, op):
     def convert_reduce_any(self, op):
         return self._convert_reduce(_op.reduce.any, op)
 
+    def _convert_arg_min_max(self, relay_op, op):
+        """Generic method to convert TFLite arg_min_max"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.ArgMaxOptions import ArgMaxOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "one output tensor expected"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        axis_tensor = input_tensors[1]
+        # In Tensorflow, `axis` argument is a Tensor, not attribute. We
+        # support the case where it inputs from a scalar constant.
+        axis_value = self.get_tensor_value(axis_tensor)
+        assert axis_value.size == 1
+        axis_value = axis_value.item()
+
+        if op.BuiltinOptionsType() == BuiltinOptions.ArgMinOptions:
+            arg_min_max_options = ArgMinOptions()
+        elif op.BuiltinOptionsType() == BuiltinOptions.ArgMaxOptions:
+            arg_min_max_options = ArgMaxOptions()
+        op_options = op.BuiltinOptions()
+        arg_min_max_options.Init(op_options.Bytes, op_options.Pos)
+        output_dtype = arg_min_max_options.OutputType()
+
+        # set keepdims to True since tflite 1.13 removes all dims of size 1
+        # WARNING: all other versions of tflite > 1.13 need keepdims=False
+        out = relay_op(in_expr, axis=axis_value, keepdims=False, exclude=False)
+        # cast the output indices to the desired data type
+        casted_output = _op.cast(out, self.get_tensor_type_str(output_dtype))
+
+        return casted_output

Review comment:
       Why is this extra cast needed?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org