Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/12/18 17:02:52 UTC

[GitHub] [incubator-tvm] mbarrett97 opened a new pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

mbarrett97 opened a new pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543
 
 
   This adds support for the custom operator TFLite_Detection_PostProcess, which is commonly used in object detection networks such as SSD MobileNet. It only supports the case where use_regular_nms = False.
   
   This implementation makes use of the existing multibox_transform_loc and non_max_suppression operators. Their design is closely coupled to the MXNet version of this custom operator, so a number of transformations have to be added before and after them so that they accept the TFLite format for this operator.
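As a rough illustration of those transformations, here is a NumPy sketch modelled on the converter code quoted later in this thread. It is not TVM code: `to_multibox_layout` and its helpers are hypothetical names, and the sketch assumes box encodings of shape (1, N, 4) in (y, x, h, w) order, class predictions of shape (1, N, C), and anchors of shape (N, 4) in (y, x, h, w) centre-size form.

```python
import numpy as np


def yxhw_to_xywh(loc):
    """Reorder the last axis of box encodings from (y, x, h, w) to (x, y, w, h)."""
    return loc[..., [1, 0, 3, 2]]


def center_to_ltrb(anchors):
    """Convert (N, 4) centre-size anchors (y, x, h, w) to corners (l, t, r, b)."""
    y, x, h, w = np.split(anchors, 4, axis=1)  # each piece has shape (N, 1)
    return np.concatenate([x - 0.5 * w, y - 0.5 * h, x + 0.5 * w, y + 0.5 * h], axis=1)


def to_multibox_layout(cls_pred, loc, anchors):
    """Hypothetical helper: rearrange TFLite-style tensors into the layouts
    multibox_transform_loc expects, mirroring the Relay ops in the PR."""
    num_anchors = anchors.shape[0]
    cls = np.transpose(cls_pred, (0, 2, 1))               # (1, N, C) -> (1, C, N)
    loc = yxhw_to_xywh(loc).reshape(1, num_anchors * 4)   # (1, N, 4) -> (1, 4N)
    anc = np.expand_dims(center_to_ltrb(anchors), 0)      # (N, 4) -> (1, N, 4)
    return cls, loc, anc
```

For example, an anchor centred at (y, x) = (0.5, 0.5) with h = 0.2 and w = 0.4 becomes the corner box (0.3, 0.4, 0.7, 0.6) in (l, t, r, b) order.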

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r366438454
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   Do you mean the issue of quantized rounding here? https://github.com/apache/incubator-tvm/pull/3900#discussion_r334334418

[GitHub] [incubator-tvm] yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r363440816
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1494,6 +1499,112 @@ def convert_transpose_conv(self, op):
 
         return out
 
+    def _convert_detection_postprocess(self, op):
+        """Convert TFLite_Detection_PostProcess"""
+        _option_names = [
+            "w_scale",
+            "max_detections",
+            "_output_quantized",
+            "detections_per_class",
+            "x_scale",
+            "nms_score_threshold",
+            "num_classes",
+            "max_classes_per_detection",
+            "use_regular_nms",
+            "y_scale",
+            "h_scale",
+            "_support_output_type_float_in_quantized_op",
+            "nms_iou_threshold"
+        ]
+
+        custom_options = get_custom_options(op, _option_names)
+        if custom_options["use_regular_nms"]:
+            raise tvm.error.OpAttributeUnImplemented(
+                "use_regular_nms=True is not yet supported for operator {}."
+                .format("TFLite_Detection_PostProcess")
+            )
+
+        inputs = self.get_input_tensors(op)
+        cls_pred = self.get_expr(inputs[1].tensor_idx)
+        loc_prob = self.get_expr(inputs[0].tensor_idx)
+        anchor_values = self.get_tensor_value(inputs[2])
+        anchor_boxes = len(anchor_values)
+        anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
+        anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
+
+        if inputs[0].qnn_params:
+            loc_prob = _qnn.op.dequantize(data=loc_prob,
+                                          input_scale=inputs[0].qnn_params['scale'],
+                                          input_zero_point=inputs[0].qnn_params['zero_point'])
+        if inputs[1].qnn_params:
+            cls_pred = _qnn.op.dequantize(data=cls_pred,
+                                          input_scale=inputs[1].qnn_params['scale'],
+                                          input_zero_point=inputs[1].qnn_params['zero_point'])
+        if inputs[2].qnn_params:
+            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
+                                             input_scale=inputs[2].qnn_params['scale'],
+                                             input_zero_point=inputs[2].qnn_params['zero_point'])
+
+        # reshape the cls_pred and loc_prob tensors so
+        # they can be consumed by multibox_transform_loc
+        cls_pred = _op.transpose(cls_pred, [0, 2, 1])
+        # loc_prob coords are in yxhw format
+        # need to convert to xywh
+        loc_coords = _op.split(loc_prob, 4, axis=2)
+        loc_prob = _op.concatenate(
+            [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
+        )
+        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+
+        # anchor coords are in yxhw format
+        # need to convert to ltrb
+        anchor_coords = _op.split(anchor_expr, 4, axis=1)
+        anchor_y = anchor_coords[0]
+        anchor_x = anchor_coords[1]
+        anchor_h = anchor_coords[2]
+        anchor_w = anchor_coords[3]
+        plus_half = _expr.const(0.5, dtype='float32')
+        minus_half = _expr.const(-0.5, dtype='float32')
+        anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
+        anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
+        anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
+        anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
+        anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
+        anchor_expr = _op.expand_dims(anchor_expr, 0)
+
+        # attributes for multibox_transform_loc
+        new_attrs0 = {}
+        new_attrs0["clip"] = False
+        new_attrs0["threshold"] = custom_options["nms_score_threshold"]
+        new_attrs0["variances"] = (
+            1/custom_options["x_scale"],
+            1/custom_options["y_scale"],
+            1/custom_options["w_scale"],
+            1/custom_options["h_scale"],
+        )
+
+        # attributes for non_max_suppression
+        new_attrs1 = {}
+        new_attrs1["return_indices"] = False
 
 Review comment:
   Thanks. Cool, then the Relay VM could be used to support dynamism.
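The variances set in this hunk are the reciprocals of TFLite's x/y/w/h scale options. The sketch below is written from my understanding of TFLite's centre-size box decoding (it is not taken from this PR) and shows why that works: dividing an encoding by a scale is the same as multiplying it by the corresponding variance. `decode_center_size` is a hypothetical name, and the scales (10, 10, 5, 5) are assumed example values, not values read from this model.

```python
import numpy as np


def decode_center_size(encodings, anchors, y_scale, x_scale, h_scale, w_scale):
    """Decode (N, 4) box encodings against (N, 4) anchors, both in (y, x, h, w)
    order, into (ymin, xmin, ymax, xmax) corners."""
    ty, tx, th, tw = encodings.T
    ya, xa, ha, wa = anchors.T
    # Centre offsets are scaled by the anchor size; sizes go through exp().
    y_center = ty / y_scale * ha + ya
    x_center = tx / x_scale * wa + xa
    h = np.exp(th / h_scale) * ha
    w = np.exp(tw / w_scale) * wa
    return np.stack([y_center - h / 2, x_center - w / 2,
                     y_center + h / 2, x_center + w / 2], axis=1)


# Assumed example scales; the converter passes their reciprocals as variances.
scales = {"x_scale": 10.0, "y_scale": 10.0, "w_scale": 5.0, "h_scale": 5.0}
variances = tuple(1 / scales[k] for k in ("x_scale", "y_scale", "w_scale", "h_scale"))
```

With all-zero encodings the decoded box is simply the anchor's own corner box, which is a convenient sanity check when comparing against the Relay lowering.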

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371281677
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   If we look at the TOCO source code, we may be able to find out how to construct detection_postprocess. Please refer to our `test_prelu` comment, where I described the pattern from which TFLite produces PRELU. However, the current approach is also acceptable in my opinion.

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r362410775
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1565,6 +1676,91 @@ def get_tensor_name(subgraph, tensor_idx):
     return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
 
 
+def get_custom_options(op, option_names):
+    """Get the options of a custom operator.
+
+    This implements partial flexbuffer deserialization to be able
+    to read custom options. It is not intended to be a general
+    purpose flexbuffer deserializer and as such only supports a
+    limited number of types and assumes the data is a flat map.
+
+    Parameters
+    ----------
+    op:
+        A custom TFlite operator.
+    option_names: list
+        A complete list of the custom option names.
+
+    Returns
+    -------
+    options: dict
+        A dictionary of the custom options.
+
+    """
+    import struct
+    from enum import IntEnum
+
+    class _FlexBufferType(IntEnum):
+        """Flexbuffer type schema from flexbuffers.h"""
+        FBT_NULL = 0
+        FBT_INT = 1
+        FBT_UINT = 2
+        FBT_FLOAT = 3
+        # Types above stored inline, types below store an offset.
+        FBT_KEY = 4
+        FBT_STRING = 5
+        FBT_INDIRECT_INT = 6
+        FBT_INDIRECT_UINT = 7
+        FBT_INDIRECT_FLOAT = 8
+        FBT_MAP = 9
+        FBT_VECTOR = 10 # Untyped.
+        FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
+        FBT_VECTOR_UINT = 12
+        FBT_VECTOR_FLOAT = 13
+        FBT_VECTOR_KEY = 14
+        FBT_VECTOR_STRING = 15
+        FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
+        FBT_VECTOR_UINT2 = 17
+        FBT_VECTOR_FLOAT2 = 18
+        FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
+        FBT_VECTOR_UINT3 = 20
+        FBT_VECTOR_FLOAT3 = 21
+        FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
+        FBT_VECTOR_UINT4 = 23
+        FBT_VECTOR_FLOAT4 = 24
+        FBT_BLOB = 25
+        FBT_BOOL = 26
+        FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type
+
+    buffer = op.CustomOptionsAsNumpy().tobytes()
+    value_vector_offset = buffer[-3]
+    buffer = buffer[:-3]
+    num_bytes = 4 # Assume all values are stored in 32 bit width
+    value_vector_size = struct.unpack(
+        "<i", buffer[-value_vector_offset - num_bytes:-value_vector_offset]
+    )[0]
+    type_offset = value_vector_size
+    types = buffer[-type_offset:]
+    values = []
+    for i, t in enumerate(types):
+        flex_type = _FlexBufferType(t >> 2)
+        value_offset = -value_vector_offset + i*num_bytes
+        value_bytes = buffer[value_offset:value_offset+num_bytes]
+        if flex_type == _FlexBufferType.FBT_BOOL:
+            value = True if value_bytes[0] else False
+        if flex_type == _FlexBufferType.FBT_INT:
 
 Review comment:
   This whole function is essentially a quick hack to work around the lack of Python bindings for FlexBuffers (the format in which custom options are stored). For the majority of cases I think we only need the types I've implemented, but I've included the full schema in case we need to support additional types for other operators later.
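For readers unfamiliar with the format: in FlexBuffers each type-table entry is a packed byte whose upper six bits hold the type and whose lower two bits hold a bit-width code, which is why the code above uses `t >> 2`. The sketch below decodes a single 4-byte value slot under the same 32-bit assumption as the PR. `unpack_type` and `decode_slot` are illustrative names, and only a subset of the type enum is included (types outside it raise `ValueError`).

```python
import struct
from enum import IntEnum


class FlexBufferType(IntEnum):
    """Subset of the FlexBuffers type enum (values as listed in flexbuffers.h)."""
    FBT_NULL = 0
    FBT_INT = 1
    FBT_UINT = 2
    FBT_FLOAT = 3
    FBT_BOOL = 26


def unpack_type(packed):
    """Split a packed type byte into (type, bit-width code).

    Upper 6 bits: type. Lower 2 bits: width code (0: 8, 1: 16, 2: 32, 3: 64 bits).
    """
    return FlexBufferType(packed >> 2), packed & 0x3


def decode_slot(packed, slot):
    """Decode one little-endian 4-byte value slot (assumes 32-bit storage)."""
    flex_type, _ = unpack_type(packed)
    if flex_type == FlexBufferType.FBT_BOOL:
        return bool(slot[0])
    if flex_type == FlexBufferType.FBT_INT:
        return struct.unpack("<i", slot)[0]
    if flex_type == FlexBufferType.FBT_UINT:
        return struct.unpack("<I", slot)[0]
    if flex_type == FlexBufferType.FBT_FLOAT:
        return struct.unpack("<f", slot)[0]
    raise NotImplementedError(flex_type.name)
```

A fuller reader would also honour the width code instead of assuming 4-byte slots; the PR deliberately skips that.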

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r365486236
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -98,6 +98,7 @@ def __init__(self, model, subgraph, exp_tab):
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
             'PRELU': self.convert_prelu,
             'TRANSPOSE_CONV': self.convert_transpose_conv,
 
 Review comment:
   Please change to `self.convert_detection_postprocess`! We should keep the same code style as the other convert functions in the dictionary.

[GitHub] [incubator-tvm] mbarrett97 commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#issuecomment-579987271
 
 
   @FrozenGene are there any other changes you want?

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r372971382
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1662,6 +1667,112 @@ def convert_transpose_conv(self, op):
 
         return out
 
+    def convert_detection_postprocess(self, op):
+        """Convert TFLite_Detection_PostProcess"""
+        _option_names = [
+            "w_scale",
+            "max_detections",
+            "_output_quantized",
+            "detections_per_class",
+            "x_scale",
+            "nms_score_threshold",
+            "num_classes",
+            "max_classes_per_detection",
+            "use_regular_nms",
+            "y_scale",
+            "h_scale",
+            "_support_output_type_float_in_quantized_op",
+            "nms_iou_threshold"
+        ]
+
+        custom_options = get_custom_options(op, _option_names)
+        if custom_options["use_regular_nms"]:
+            raise tvm.error.OpAttributeUnImplemented(
+                "use_regular_nms=True is not yet supported for operator {}."
+                .format("TFLite_Detection_PostProcess")
+            )
+
+        inputs = self.get_input_tensors(op)
 
 Review comment:
   Does it make sense to add an assertion, `assert len(inputs) == 3`?
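A minimal sketch of the suggested check, extended with a shape-consistency test. `check_detection_inputs` is a hypothetical helper that works on NumPy arrays rather than the converter's tensor wrappers; it assumes the input order the converter uses (box encodings, class predictions, anchors) and that the first two inputs are shaped (batch, N, ...) with one row per anchor.

```python
import numpy as np


def check_detection_inputs(inputs):
    """Validate the three TFLite_Detection_PostProcess inputs (hypothetical helper)."""
    assert len(inputs) == 3, "expected 3 inputs, got {}".format(len(inputs))
    box_encodings, class_predictions, anchors = inputs
    num_anchors = anchors.shape[0]
    # The converter reshapes box encodings to [1, num_anchors * 4], so every
    # per-anchor tensor must agree on the number of anchors.
    assert box_encodings.shape[1] == num_anchors, "box encodings/anchors mismatch"
    assert class_predictions.shape[1] == num_anchors, "class predictions/anchors mismatch"
    return box_encodings, class_predictions, anchors
```

Failing early with a clear message is friendlier than the reshape error the converter would otherwise hit further down.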

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r365487145
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1494,6 +1499,112 @@ def convert_transpose_conv(self, op):
 
         return out
 
+    def _convert_detection_postprocess(self, op):
+        """Convert TFLite_Detection_PostProcess"""
+        _option_names = [
+            "w_scale",
+            "max_detections",
+            "_output_quantized",
+            "detections_per_class",
+            "x_scale",
+            "nms_score_threshold",
+            "num_classes",
+            "max_classes_per_detection",
+            "use_regular_nms",
+            "y_scale",
+            "h_scale",
+            "_support_output_type_float_in_quantized_op",
+            "nms_iou_threshold"
+        ]
+
+        custom_options = get_custom_options(op, _option_names)
+        if custom_options["use_regular_nms"]:
+            raise tvm.error.OpAttributeUnImplemented(
+                "use_regular_nms=True is not yet supported for operator {}."
+                .format("TFLite_Detection_PostProcess")
+            )
+
+        inputs = self.get_input_tensors(op)
+        cls_pred = self.get_expr(inputs[1].tensor_idx)
+        loc_prob = self.get_expr(inputs[0].tensor_idx)
+        anchor_values = self.get_tensor_value(inputs[2])
+        anchor_boxes = len(anchor_values)
+        anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
+        anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
+
+        if inputs[0].qnn_params:
+            loc_prob = _qnn.op.dequantize(data=loc_prob,
+                                          input_scale=inputs[0].qnn_params['scale'],
+                                          input_zero_point=inputs[0].qnn_params['zero_point'])
+        if inputs[1].qnn_params:
+            cls_pred = _qnn.op.dequantize(data=cls_pred,
+                                          input_scale=inputs[1].qnn_params['scale'],
+                                          input_zero_point=inputs[1].qnn_params['zero_point'])
+        if inputs[2].qnn_params:
+            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
+                                             input_scale=inputs[2].qnn_params['scale'],
+                                             input_zero_point=inputs[2].qnn_params['zero_point'])
+
+        # reshape the cls_pred and loc_prob tensors so
+        # they can be consumed by multibox_transform_loc
+        cls_pred = _op.transpose(cls_pred, [0, 2, 1])
+        # loc_prob coords are in yxhw format
+        # need to convert to xywh
+        loc_coords = _op.split(loc_prob, 4, axis=2)
+        loc_prob = _op.concatenate(
+            [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
+        )
+        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+
+        # anchor coords are in yxhw format
+        # need to convert to ltrb
+        anchor_coords = _op.split(anchor_expr, 4, axis=1)
+        anchor_y = anchor_coords[0]
+        anchor_x = anchor_coords[1]
+        anchor_h = anchor_coords[2]
+        anchor_w = anchor_coords[3]
+        plus_half = _expr.const(0.5, dtype='float32')
+        minus_half = _expr.const(-0.5, dtype='float32')
+        anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
+        anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
+        anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
+        anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
+        anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
+        anchor_expr = _op.expand_dims(anchor_expr, 0)
+
+        # attributes for multibox_transform_loc
+        new_attrs0 = {}
+        new_attrs0["clip"] = False
+        new_attrs0["threshold"] = custom_options["nms_score_threshold"]
+        new_attrs0["variances"] = (
+            1 / custom_options["x_scale"],
+            1 / custom_options["y_scale"],
+            1 / custom_options["w_scale"],
+            1 / custom_options["h_scale"],
+        )
+
+        # attributes for non_max_suppression
+        new_attrs1 = {}
 
 Review comment:
   change to `non_max_suppression_attrs`
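For context on what these attributes control, here is a plain NumPy sketch of greedy, class-agnostic non-maximum suppression over corner-format (l, t, r, b) boxes. It is a textbook version written for illustration, not TVM's `non_max_suppression` implementation; the parameter names only loosely follow the TFLite options (`nms_iou_threshold`, `nms_score_threshold`, `max_detections`).

```python
import numpy as np


def iou(box, boxes):
    """Intersection-over-union of one (l, t, r, b) box against an (M, 4) array."""
    l = np.maximum(box[0], boxes[:, 0])
    t = np.maximum(box[1], boxes[:, 1])
    r = np.minimum(box[2], boxes[:, 2])
    b = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(r - l, 0, None) * np.clip(b - t, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def greedy_nms(boxes, scores, iou_threshold, score_threshold, max_detections):
    """Return indices of kept boxes, highest score first."""
    order = np.argsort(-scores)
    order = order[scores[order] > score_threshold]  # drop low-scoring boxes
    keep = []
    while order.size > 0 and len(keep) < max_detections:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        # Suppress remaining boxes that overlap the chosen one too much.
        order = rest[iou(boxes[best], boxes[rest]) <= iou_threshold]
    return keep
```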

[GitHub] [incubator-tvm] tqchen commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
tqchen commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#issuecomment-585287746
 
 
   @FrozenGene please approve or request changes explicitly: https://docs.tvm.ai/contribute/code_review.html#approve-and-request-changes-explicitly

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371276909
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   This test is a bit misleading because it doesn't actually run SSD MobileNet; it just tests the postprocess op. I couldn't find a way to create the op using the TFLite Python API, so instead I took a model that contains it and ran it through the TFLite converter, with the converter inputs set to the inputs of the postprocess op rather than the input to the network.
   
   This has the net effect of producing a single postprocess op, so this should already be a unit test (and it passes). I can add the end-to-end tests if/when we resolve the QNN accuracy issue. I'll open an RFC shortly to describe why rounding is particularly significant in the case of this operator.
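As a hedged illustration of why rounding can matter here (this is not the content of the RFC): the converter dequantizes each input as `scale * (q - zero_point)`, so any quantized arithmetic upstream of this op carries whatever rounding convention produced `q`. The NumPy sketch below shows two common conventions disagreeing by one unit at exact ties. As I understand TFLite's box decoding, the height and width encodings pass through an exponential, so a one-unit difference there changes the decoded box size multiplicatively rather than by a fixed offset. The function names are illustrative.

```python
import numpy as np


def dequantize(q, scale, zero_point):
    """Affine dequantization: real = scale * (q - zero_point)."""
    return scale * (np.asarray(q, dtype=np.int32) - zero_point)


def round_half_to_even(x):
    """NumPy's default rounding: ties go to the nearest even integer."""
    return np.round(x)


def round_half_away_from_zero(x):
    """Alternative convention: ties go away from zero."""
    return np.sign(x) * np.floor(np.abs(x) + 0.5)


ties = np.array([0.5, 1.5, 2.5, -2.5])
even = round_half_to_even(ties)         # 0, 2, 2, -2
away = round_half_away_from_zero(ties)  # 1, 2, 3, -3
```

Bit-exact agreement with TFLite therefore depends on matching its rounding convention, not just its scales and zero points.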

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371868583
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   I've written a discuss post here: [5528](https://discuss.tvm.ai/t/supporting-bit-exact-tflite-qnn-inference/5528).

[GitHub] [incubator-tvm] sjoshi30 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
sjoshi30 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r382285883
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   @mbaret How did you set the converter inputs to the inputs of the postprocess op? When I do that, it gives me this error:
   tensorflow/lite/toco/model_cmdline_flags.cc:263] Check failed: mean_values.size() == model_flags->input_arrays_size()
   
   The postprocess op has more than one input ('raw_outputs/box_encodings', 'raw_outputs/class_predictions').
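The failing check compares the number of mean values with the number of input arrays, so when the converter inputs are moved to the two postprocess inputs, one mean value (and, correspondingly, one standard-deviation value) has to be supplied per input array. Below is a hedged sketch of assembling such comma-separated flag values. `per_input` is a hypothetical helper, the tensor names come from the comment above, and 128 is a placeholder rather than a statistic calibrated for these tensors. Check the exact flag names against the converter version in use.

```python
def per_input(values_by_name, input_arrays):
    """Join one value per input array, in input_arrays order (hypothetical helper)."""
    missing = [name for name in input_arrays if name not in values_by_name]
    if missing:
        raise ValueError("no value for inputs: {}".format(missing))
    return ",".join(str(values_by_name[name]) for name in input_arrays)


input_arrays = ["raw_outputs/box_encodings", "raw_outputs/class_predictions"]
mean_values = {name: 128 for name in input_arrays}  # placeholder statistics
flags = [
    "--input_arrays=" + ",".join(input_arrays),
    # One mean value per input array satisfies the size check in the error.
    "--mean_values=" + per_input(mean_values, input_arrays),
]
```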

[GitHub] [incubator-tvm] tqchen commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
tqchen commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#issuecomment-573238646
 
 
   @mbarrett97 please update as per the review comments. @FrozenGene please approve or request changes explicitly: https://docs.tvm.ai/contribute/code_review.html#approve-and-request-changes-explicitly

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r365486808
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -98,6 +98,7 @@ def __init__(self, model, subgraph, exp_tab):
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
             'PRELU': self.convert_prelu,
             'TRANSPOSE_CONV': self.convert_transpose_conv,
+            'DETECTION_POSTPROCESS': self._convert_detection_postprocess
 
 Review comment:
   Please change to self._convert_detection_postprocess to match the code style of the other convert functions. We should keep the same code style.

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371205476
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   I can do that, but can we continue it as a separate conversation? I'm only asking to clarify, as I don't think that issue affects the correctness of this operator, which is already tested by 'test_detection_postprocess'.

[GitHub] [incubator-tvm] yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r360490550
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
 
 Review comment:
   ```suggestion
   # ----------------
   ```

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r363251705
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   I can do that, but where would you like me to pull it from? I see that SSD MobileNet v1 without the post-process op is hosted under "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/". Would it be possible to host the version with the post-process op there as well?

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r363254599
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1494,6 +1499,112 @@ def convert_transpose_conv(self, op):
 
         return out
 
+    def _convert_detection_postprocess(self, op):
+        """Convert TFLite_Detection_PostProcess"""
+        _option_names = [
+            "w_scale",
+            "max_detections",
+            "_output_quantized",
+            "detections_per_class",
+            "x_scale",
+            "nms_score_threshold",
+            "num_classes",
+            "max_classes_per_detection",
+            "use_regular_nms",
+            "y_scale",
+            "h_scale",
+            "_support_output_type_float_in_quantized_op",
+            "nms_iou_threshold"
+        ]
+
+        custom_options = get_custom_options(op, _option_names)
+        if custom_options["use_regular_nms"]:
+            raise tvm.error.OpAttributeUnImplemented(
+                "use_regular_nms=True is not yet supported for operator {}."
+                .format("TFLite_Detection_PostProcess")
+            )
+
+        inputs = self.get_input_tensors(op)
+        cls_pred = self.get_expr(inputs[1].tensor_idx)
+        loc_prob = self.get_expr(inputs[0].tensor_idx)
+        anchor_values = self.get_tensor_value(inputs[2])
+        anchor_boxes = len(anchor_values)
+        anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
+        anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
+
+        if inputs[0].qnn_params:
+            loc_prob = _qnn.op.dequantize(data=loc_prob,
+                                          input_scale=inputs[0].qnn_params['scale'],
+                                          input_zero_point=inputs[0].qnn_params['zero_point'])
+        if inputs[1].qnn_params:
+            cls_pred = _qnn.op.dequantize(data=cls_pred,
+                                          input_scale=inputs[1].qnn_params['scale'],
+                                          input_zero_point=inputs[1].qnn_params['zero_point'])
+        if inputs[2].qnn_params:
+            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
+                                             input_scale=inputs[2].qnn_params['scale'],
+                                             input_zero_point=inputs[2].qnn_params['zero_point'])
+
+        # reshape the cls_pred and loc_prob tensors so
+        # they can be consumed by multibox_transform_loc
+        cls_pred = _op.transpose(cls_pred, [0, 2, 1])
+        # loc_prob coords are in yxhw format
+        # need to convert to xywh
+        loc_coords = _op.split(loc_prob, 4, axis=2)
+        loc_prob = _op.concatenate(
+            [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
+        )
+        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+
+        # anchor coords are in yxhw format
+        # need to convert to ltrb
+        anchor_coords = _op.split(anchor_expr, 4, axis=1)
+        anchor_y = anchor_coords[0]
+        anchor_x = anchor_coords[1]
+        anchor_h = anchor_coords[2]
+        anchor_w = anchor_coords[3]
+        plus_half = _expr.const(0.5, dtype='float32')
+        minus_half = _expr.const(-0.5, dtype='float32')
+        anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
+        anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
+        anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
+        anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
+        anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
+        anchor_expr = _op.expand_dims(anchor_expr, 0)
+
+        # attributes for multibox_transform_loc
+        new_attrs0 = {}
+        new_attrs0["clip"] = False
+        new_attrs0["threshold"] = custom_options["nms_score_threshold"]
+        new_attrs0["variances"] = (
+            1/custom_options["x_scale"],
+            1/custom_options["y_scale"],
+            1/custom_options["w_scale"],
+            1/custom_options["h_scale"],
+        )
+
+        # attributes for non_max_suppression
+        new_attrs1 = {}
+        new_attrs1["return_indices"] = False
 
 Review comment:
   The output from TFLite always has a dynamic shape. However, because we're using the graph runtime, the TVM output must have a fixed shape. In practice, this means the TVM version always outputs a tensor large enough to hold the maximum number of detections, and only the first 'n' elements of that tensor are valid. The value of 'n' is also an output of the network (for both TFLite and TVM).
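A minimal sketch of how a caller might trim the fixed-size TVM outputs described above down to the valid detections. It assumes the four outputs come in the usual TFLite_Detection_PostProcess order (boxes, classes, scores, num_detections), that the batch dimension has already been dropped, and that the count arrives as a float; the helper name `trim_detections` is hypothetical.

```python
from typing import List, Sequence, Tuple


def trim_detections(boxes: Sequence[Sequence[float]],
                    classes: Sequence[float],
                    scores: Sequence[float],
                    num_detections: float) -> Tuple[List, List, List]:
    """Keep only the first n rows of padded detection outputs.

    Assumes TFLite_Detection_PostProcess output order with the batch
    dimension removed; rows at index >= n are padding.
    """
    n = int(num_detections)  # the count is assumed to be a float tensor value
    return list(boxes[:n]), list(classes[:n]), list(scores[:n])


# Padded to max_detections = 4, but only 2 detections are valid.
padded_boxes = [[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6], [0.0] * 4, [0.0] * 4]
valid_boxes, valid_classes, valid_scores = trim_detections(
    padded_boxes, [1.0, 3.0, 0.0, 0.0], [0.9, 0.8, 0.0, 0.0], 2.0)
```

Slicing by the reported count, rather than filtering on score, follows the semantics described above: the contents of the padded rows are simply not meaningful.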

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371193247
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   I think we should resolve the rounding issue in TVM. Would you mind opening an RFC to describe it? We could then discuss and resolve it there. This case is a good example of why we need to match TFLite's rounding behavior when we parse TFLite quantized models.
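To illustrate why rounding matters here: when a real value divided by the scale lands exactly on a .5 tie, two common tie-breaking rules disagree by one quantization step. This sketch is illustrative only; which rule TFLite or TVM applies at each step is not stated in this thread, and the helper names are hypothetical.

```python
import math
from typing import Callable


def round_half_away_from_zero(x: float) -> int:
    """Ties go away from zero: 2.5 -> 3, -2.5 -> -3."""
    return int(math.copysign(math.floor(abs(x) + 0.5), x))


def round_half_to_even(x: float) -> int:
    """Ties go to the nearest even integer (Python's built-in round)."""
    return round(x)


def quantize(real: float, scale: float, zero_point: int,
             rounder: Callable[[float], int]) -> int:
    """Affine quantization q = round(real / scale) + zero_point (no clamping)."""
    return rounder(real / scale) + zero_point


# 1.25 / 0.5 is exactly 2.5, so the two rules differ by one step.
q_away = quantize(1.25, 0.5, 128, round_half_away_from_zero)  # 131
q_even = quantize(1.25, 0.5, 128, round_half_to_even)         # 130
```

A one-step difference like this is harmless for most element-wise comparisons, but it is enough to change downstream decisions in an operator that sorts its inputs.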

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371281677
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   If we could look at the TOCO source code, we might find out how to construct detection_postprocess directly. Please refer to our `_test_prelu` comment, where I wrote down the pattern from which TFLite produces PRELU. However, the current approach is also acceptable in my opinion.

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r365487118
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1494,6 +1499,112 @@ def convert_transpose_conv(self, op):
 
         return out
 
+    def _convert_detection_postprocess(self, op):
+        """Convert TFLite_Detection_PostProcess"""
+        _option_names = [
+            "w_scale",
+            "max_detections",
+            "_output_quantized",
+            "detections_per_class",
+            "x_scale",
+            "nms_score_threshold",
+            "num_classes",
+            "max_classes_per_detection",
+            "use_regular_nms",
+            "y_scale",
+            "h_scale",
+            "_support_output_type_float_in_quantized_op",
+            "nms_iou_threshold"
+        ]
+
+        custom_options = get_custom_options(op, _option_names)
+        if custom_options["use_regular_nms"]:
+            raise tvm.error.OpAttributeUnImplemented(
+                "use_regular_nms=True is not yet supported for operator {}."
+                .format("TFLite_Detection_PostProcess")
+            )
+
+        inputs = self.get_input_tensors(op)
+        cls_pred = self.get_expr(inputs[1].tensor_idx)
+        loc_prob = self.get_expr(inputs[0].tensor_idx)
+        anchor_values = self.get_tensor_value(inputs[2])
+        anchor_boxes = len(anchor_values)
+        anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
+        anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
+
+        if inputs[0].qnn_params:
+            loc_prob = _qnn.op.dequantize(data=loc_prob,
+                                          input_scale=inputs[0].qnn_params['scale'],
+                                          input_zero_point=inputs[0].qnn_params['zero_point'])
+        if inputs[1].qnn_params:
+            cls_pred = _qnn.op.dequantize(data=cls_pred,
+                                          input_scale=inputs[1].qnn_params['scale'],
+                                          input_zero_point=inputs[1].qnn_params['zero_point'])
+        if inputs[2].qnn_params:
+            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
+                                             input_scale=inputs[2].qnn_params['scale'],
+                                             input_zero_point=inputs[2].qnn_params['zero_point'])
+
+        # reshape the cls_pred and loc_prob tensors so
+        # they can be consumed by multibox_transform_loc
+        cls_pred = _op.transpose(cls_pred, [0, 2, 1])
+        # loc_prob coords are in yxhw format
+        # need to convert to xywh
+        loc_coords = _op.split(loc_prob, 4, axis=2)
+        loc_prob = _op.concatenate(
+            [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
+        )
+        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+
+        # anchor coords are in yxhw format
+        # need to convert to ltrb
+        anchor_coords = _op.split(anchor_expr, 4, axis=1)
+        anchor_y = anchor_coords[0]
+        anchor_x = anchor_coords[1]
+        anchor_h = anchor_coords[2]
+        anchor_w = anchor_coords[3]
+        plus_half = _expr.const(0.5, dtype='float32')
+        minus_half = _expr.const(-0.5, dtype='float32')
+        anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
+        anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
+        anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
+        anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
+        anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
+        anchor_expr = _op.expand_dims(anchor_expr, 0)
+
+        # attributes for multibox_transform_loc
+        new_attrs0 = {}
 
 Review comment:
   Please rename this to `multibox_transform_loc_attrs`; `new_attrs0` is not a descriptive name.
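A sketch of the suggested rename, with the `multibox_transform_loc` attributes from the diff pulled into a small helper. The keys and the reciprocal variances are taken from the diff; the function name is hypothetical. As we understand it, the reciprocals are needed because TFLite divides the box encodings by the x/y/w/h scales, whereas `multibox_transform_loc` multiplies them by its variances.

```python
def make_multibox_transform_loc_attrs(custom_options: dict) -> dict:
    """Build multibox_transform_loc attributes from TFLite custom options."""
    multibox_transform_loc_attrs = {
        "clip": False,
        "threshold": custom_options["nms_score_threshold"],
        # TFLite divides encodings by the scales; the variances multiply them.
        "variances": (
            1 / custom_options["x_scale"],
            1 / custom_options["y_scale"],
            1 / custom_options["w_scale"],
            1 / custom_options["h_scale"],
        ),
    }
    return multibox_transform_loc_attrs


# Example option values (illustrative, not read from a real model).
attrs = make_multibox_transform_loc_attrs({
    "nms_score_threshold": 1e-8,
    "x_scale": 10.0, "y_scale": 10.0, "w_scale": 5.0, "h_scale": 5.0,
})
```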

[GitHub] [incubator-tvm] yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r363439474
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   If possible, we'd like to pull the model from the relevant official website, for example https://www.tensorflow.org/lite/models/object_detection/overview for SSD MobileNet v1.

[GitHub] [incubator-tvm] u99127 commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#issuecomment-585199370
 
 
   Have all the changes been made? Is anything left before this can be merged?

[GitHub] [incubator-tvm] tqchen commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#issuecomment-582721889
 
 
   @mbarrett97 please rebase, @FrozenGene please followup :)

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r365486547
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -98,6 +98,7 @@ def __init__(self, model, subgraph, exp_tab):
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
             'PRELU': self.convert_prelu,
             'TRANSPOSE_CONV': self.convert_transpose_conv,
+            'DETECTION_POSTPROCESS': self._convert_detection_postprocess
 
 Review comment:
   Please change this to `self.convert_detection_postprocess` to match the code style of the other convert functions. We should keep the code style consistent.

[GitHub] [incubator-tvm] yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r360491057
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   Would you mind adding more models, such as ssd_mobilenetv1?

[GitHub] [incubator-tvm] sjoshi30 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r382285883
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   @mbaret How did you set the converter inputs to be the inputs of the post-process op? When I do that, I get this error:
   tensorflow/lite/toco/model_cmdline_flags.cc:263] Check failed: mean_values.size() == model_flags->input_arrays_size()
   
   The post-process op has more than one input ('raw_outputs/box_encodings', 'raw_outputs/class_predictions'), in addition to the anchors constant.
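The failed check compares the number of `--mean_values` entries with the number of `--input_arrays`, so with two input arrays the converter presumably needs two mean values (and, for consistency, two `--std_dev_values`). A sketch that builds a TF 1.x `tflite_convert` command line with matching counts; the helper name, file names, output name and mean/std numbers are placeholders rather than values from this PR, and other flags (inference type, input shapes) are omitted.

```python
from typing import List, Sequence


def build_tflite_convert_args(graph_def_file: str, output_file: str,
                              input_arrays: Sequence[str],
                              output_arrays: Sequence[str],
                              mean_values: Sequence[float],
                              std_dev_values: Sequence[float]) -> List[str]:
    """Assemble tflite_convert flags with one mean/std value per input array."""
    # The TOCO check in the error above requires len(mean_values) == len(input_arrays).
    if not len(input_arrays) == len(mean_values) == len(std_dev_values):
        raise ValueError("need one mean value and one std dev value per input array")
    return [
        "tflite_convert",
        "--graph_def_file=" + graph_def_file,
        "--output_file=" + output_file,
        "--input_arrays=" + ",".join(input_arrays),
        "--output_arrays=" + ",".join(output_arrays),
        "--mean_values=" + ",".join(str(v) for v in mean_values),
        "--std_dev_values=" + ",".join(str(v) for v in std_dev_values),
        "--allow_custom_ops",  # TFLite_Detection_PostProcess is a custom op
    ]


args = build_tflite_convert_args(
    "postprocess.pb", "postprocess.tflite",  # placeholder file names
    ["raw_outputs/box_encodings", "raw_outputs/class_predictions"],
    ["TFLite_Detection_PostProcess"],        # placeholder output name
    mean_values=[128, 128], std_dev_values=[128, 128])
```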

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r364290080
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   OK. I did see that model, but oddly it is distributed as a .zip rather than a tarball like most other hosted models. I'll see if I can open another PR to extend get_workload_official to support zips, and then I'll add the test.
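One way the zip support mentioned here could look: dispatch on the archive type using only the standard library. This is a sketch of a hypothetical helper, not the actual `get_workload_official` implementation, which is not shown in this thread.

```python
import tarfile
import zipfile
from typing import List


def extract_model_archive(archive_path: str, dest_dir: str) -> List[str]:
    """Extract a .zip or .tar(.gz) model archive and return its member names."""
    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path) as zip_file:
            zip_file.extractall(dest_dir)
            return zip_file.namelist()
    if tarfile.is_tarfile(archive_path):
        # Mode "r" (the default) auto-detects gzip/bz2/xz compression.
        with tarfile.open(archive_path) as tar_file:
            tar_file.extractall(dest_dir)
            return tar_file.getnames()
    raise ValueError("unsupported archive format: " + archive_path)
```

Checking for a zip first keeps the existing tarball path unchanged while adding the new case.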

[GitHub] [incubator-tvm] yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r360489709
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1565,6 +1676,91 @@ def get_tensor_name(subgraph, tensor_idx):
     return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
 
 
+def get_custom_options(op, option_names):
+    """Get the options of a custom operator.
+
+    This implements partial flexbuffer deserialization to be able
+    to read custom options. It is not intended to be a general
+    purpose flexbuffer deserializer and as such only supports a
+    limited number of types and assumes the data is a flat map.
+
+    Parameters
+    ----------
+    op:
+        A custom TFlite operator.
+    option_names: list
+        A complete list of the custom option names.
+
+    Returns
+    -------
+    options: dict
+        A dictionary of the custom options.
+
+    """
+    import struct
+    from enum import IntEnum
+
+    class _FlexBufferType(IntEnum):
+        """Flexbuffer type schema from flexbuffers.h"""
+        FBT_NULL = 0
+        FBT_INT = 1
+        FBT_UINT = 2
+        FBT_FLOAT = 3
+        # Types above stored inline, types below store an offset.
+        FBT_KEY = 4
+        FBT_STRING = 5
+        FBT_INDIRECT_INT = 6
+        FBT_INDIRECT_UINT = 7
+        FBT_INDIRECT_FLOAT = 8
+        FBT_MAP = 9
+        FBT_VECTOR = 10 # Untyped.
+        FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
+        FBT_VECTOR_UINT = 12
+        FBT_VECTOR_FLOAT = 13
+        FBT_VECTOR_KEY = 14
+        FBT_VECTOR_STRING = 15
+        FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
+        FBT_VECTOR_UINT2 = 17
+        FBT_VECTOR_FLOAT2 = 18
+        FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
+        FBT_VECTOR_UINT3 = 20
+        FBT_VECTOR_FLOAT3 = 21
+        FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
+        FBT_VECTOR_UINT4 = 23
+        FBT_VECTOR_FLOAT4 = 24
+        FBT_BLOB = 25
+        FBT_BOOL = 26
+        FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type
+
+    buffer = op.CustomOptionsAsNumpy().tobytes()
+    value_vector_offset = buffer[-3]
+    buffer = buffer[:-3]
+    num_bytes = 4 # Assume all values are stored in 32 bit width
+    value_vector_size = struct.unpack(
+        "<i", buffer[-value_vector_offset - num_bytes:-value_vector_offset]
+    )[0]
+    type_offset = value_vector_size
+    types = buffer[-type_offset:]
+    values = []
+    for i, t in enumerate(types):
+        flex_type = _FlexBufferType(t >> 2)
+        value_offset = -value_vector_offset + i*num_bytes
+        value_bytes = buffer[value_offset:value_offset+num_bytes]
+        if flex_type == _FlexBufferType.FBT_BOOL:
+            value = True if value_bytes[0] else False
+        if flex_type == _FlexBufferType.FBT_INT:
 
 Review comment:
   Should this be `elif`? It also seems that not all of the types in the schema above are handled here.
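A sketch of the reviewer's suggestion: make the type checks mutually exclusive with `elif` and fail loudly on types the parser does not handle, so an unsupported type cannot slip through unnoticed. Like the diff, it assumes every map value is stored inline in 32 bits and that the packed type byte holds the type in its upper 6 bits; the function name is hypothetical and only a few schema types are covered.

```python
import struct
from enum import IntEnum


class FlexBufferType(IntEnum):
    """Subset of the flexbuffer type schema (values from flexbuffers.h)."""
    FBT_INT = 1
    FBT_UINT = 2
    FBT_FLOAT = 3
    FBT_BOOL = 26


def decode_inline_value(packed_type: int, value_bytes: bytes):
    """Decode one 32-bit inline value from a flexbuffer map."""
    flex_type = packed_type >> 2  # the low 2 bits encode the bit width
    if flex_type == FlexBufferType.FBT_BOOL:
        value = bool(value_bytes[0])
    elif flex_type == FlexBufferType.FBT_INT:
        value = struct.unpack("<i", value_bytes)[0]
    elif flex_type == FlexBufferType.FBT_UINT:
        value = struct.unpack("<I", value_bytes)[0]
    elif flex_type == FlexBufferType.FBT_FLOAT:
        value = struct.unpack("<f", value_bytes)[0]
    else:
        raise NotImplementedError(
            "flexbuffer type {} is not supported".format(flex_type))
    return value
```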

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371269023
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   Alright, we could remove the SSD MobileNet model test because of this limitation, but we should still keep the unit test of detection postprocess. After we resolve the limitation, we can add the SSD MobileNet test back. Moreover, we could then remove the atol=1 from test_qconv2d and similar tests, because we would get exactly the same results as TFLite. Does that make sense to you?

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r371170921
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   Apologies for the delayed response. Yes, that's probably the source of the error. Normally it could be worked around simply by increasing the error tolerances, but that doesn't work in this case because of the sorting and clipping that the operator performs.
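A toy illustration of why a larger tolerance does not help here: if a couple of quantized scores differ by one step, every element still passes an `atol=1` comparison, yet the sort inside the post-process op can select different detections, and clipping to `max_detections` then keeps different boxes. The scores are made up and the helper is hypothetical.

```python
from typing import List, Sequence


def top_k_indices(scores: Sequence[int], k: int) -> List[int]:
    """Indices of the k highest scores; ties keep their original order."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


reference = [200, 180, 179, 50]  # scores from one runtime (made up)
candidate = [200, 179, 180, 50]  # off by one step at indices 1 and 2

# Element-wise, the two agree within atol=1 ...
within_tolerance = all(abs(a - b) <= 1 for a, b in zip(reference, candidate))
# ... but keeping the top 2 selects a different second detection.
kept_reference = top_k_indices(reference, 2)  # [0, 1]
kept_candidate = top_k_indices(candidate, 2)  # [0, 2]
```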

[GitHub] [incubator-tvm] yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r360491655
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1494,6 +1499,112 @@ def convert_transpose_conv(self, op):
 
         return out
 
+    def _convert_detection_postprocess(self, op):
+        """Convert TFLite_Detection_PostProcess"""
+        _option_names = [
+            "w_scale",
+            "max_detections",
+            "_output_quantized",
+            "detections_per_class",
+            "x_scale",
+            "nms_score_threshold",
+            "num_classes",
+            "max_classes_per_detection",
+            "use_regular_nms",
+            "y_scale",
+            "h_scale",
+            "_support_output_type_float_in_quantized_op",
+            "nms_iou_threshold"
+        ]
+
+        custom_options = get_custom_options(op, _option_names)
+        if custom_options["use_regular_nms"]:
+            raise tvm.error.OpAttributeUnImplemented(
+                "use_regular_nms=True is not yet supported for operator {}."
+                .format("TFLite_Detection_PostProcess")
+            )
+
+        inputs = self.get_input_tensors(op)
+        cls_pred = self.get_expr(inputs[1].tensor_idx)
+        loc_prob = self.get_expr(inputs[0].tensor_idx)
+        anchor_values = self.get_tensor_value(inputs[2])
+        anchor_boxes = len(anchor_values)
+        anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
+        anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
+
+        if inputs[0].qnn_params:
+            loc_prob = _qnn.op.dequantize(data=loc_prob,
+                                          input_scale=inputs[0].qnn_params['scale'],
+                                          input_zero_point=inputs[0].qnn_params['zero_point'])
+        if inputs[1].qnn_params:
+            cls_pred = _qnn.op.dequantize(data=cls_pred,
+                                          input_scale=inputs[1].qnn_params['scale'],
+                                          input_zero_point=inputs[1].qnn_params['zero_point'])
+        if inputs[2].qnn_params:
+            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
+                                             input_scale=inputs[2].qnn_params['scale'],
+                                             input_zero_point=inputs[2].qnn_params['zero_point'])
+
+        # reshape the cls_pred and loc_prob tensors so
+        # they can be consumed by multibox_transform_loc
+        cls_pred = _op.transpose(cls_pred, [0, 2, 1])
+        # loc_prob coords are in yxhw format
+        # need to convert to xywh
+        loc_coords = _op.split(loc_prob, 4, axis=2)
+        loc_prob = _op.concatenate(
+            [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
+        )
+        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+
+        # anchor coords are in yxhw format
+        # need to convert to ltrb
+        anchor_coords = _op.split(anchor_expr, 4, axis=1)
+        anchor_y = anchor_coords[0]
+        anchor_x = anchor_coords[1]
+        anchor_h = anchor_coords[2]
+        anchor_w = anchor_coords[3]
+        plus_half = _expr.const(0.5, dtype='float32')
+        minus_half = _expr.const(-0.5, dtype='float32')
+        anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
+        anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
+        anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
+        anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
+        anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
+        anchor_expr = _op.expand_dims(anchor_expr, 0)
+
+        # attributes for multibox_transform_loc
+        new_attrs0 = {}
+        new_attrs0["clip"] = False
+        new_attrs0["threshold"] = custom_options["nms_score_threshold"]
+        new_attrs0["variances"] = (
+            1/custom_options["x_scale"],
+            1/custom_options["y_scale"],
+            1/custom_options["w_scale"],
+            1/custom_options["h_scale"],
+        )
+
+        # attributes for non_max_suppression
+        new_attrs1 = {}
+        new_attrs1["return_indices"] = False
 
 Review comment:
   I'm curious whether the output has a dynamic shape when `return_indices=True`.
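The converter in the diff above reorders TFLite's (y, x, h, w) box offsets into (x, y, w, h), converts the (y_center, x_center, h, w) anchors into corner form (left, top, right, bottom), and passes `1/x_scale`, `1/y_scale`, `1/w_scale`, `1/h_scale` as variances. The NumPy sketch below checks that this bookkeeping is self-consistent. It assumes center-size decoding in which TFLite divides the offsets by the scales and the multibox-style decoder multiplies them by the variances. The function names are hypothetical, and this is an illustration, not TVM's or TFLite's implementation.

```python
import numpy as np

def anchors_yxhw_to_ltrb(anchors):
    """Convert (N, 4) anchors from (y_center, x_center, h, w) to (left, top, right, bottom),
    mirroring the split/add/concatenate sequence in the diff."""
    y, x, h, w = np.split(anchors, 4, axis=1)
    return np.concatenate([x - 0.5 * w, y - 0.5 * h, x + 0.5 * w, y + 0.5 * h], axis=1)

def decode_tflite_style(rel_yxhw, anchors_yxhw, y_scale, x_scale, h_scale, w_scale):
    """Assumed center-size decoding that divides the raw offsets by the scale factors.
    Returns corner-form boxes as (left, top, right, bottom)."""
    ty, tx, th, tw = rel_yxhw.T
    ya, xa, ha, wa = anchors_yxhw.T
    yc = ty / y_scale * ha + ya
    xc = tx / x_scale * wa + xa
    half_h = 0.5 * np.exp(th / h_scale) * ha
    half_w = 0.5 * np.exp(tw / w_scale) * wa
    return np.stack([xc - half_w, yc - half_h, xc + half_w, yc + half_h], axis=1)

def decode_with_variances(rel_xywh, anchors_ltrb, variances):
    """Assumed multibox-style decoding: offsets are multiplied by per-coordinate
    variances (vx, vy, vw, vh) and anchors are given in corner form."""
    vx, vy, vw, vh = variances
    al, at, ar, ab = anchors_ltrb.T
    aw, ah = ar - al, ab - at
    ax, ay = (al + ar) / 2, (at + ab) / 2
    px, py, pw, ph = rel_xywh.T
    ox, oy = px * vx * aw + ax, py * vy * ah + ay
    half_w, half_h = np.exp(pw * vw) * aw / 2, np.exp(ph * vh) * ah / 2
    return np.stack([ox - half_w, oy - half_h, ox + half_w, oy + half_h], axis=1)

# Made-up anchors and offsets; scales of 10/10/5/5 are the values commonly used by SSD box coders.
rng = np.random.default_rng(0)
anchors = np.column_stack([rng.uniform(0.2, 0.8, (5, 2)), rng.uniform(0.05, 0.3, (5, 2))])
rel = rng.normal(size=(5, 4))
ys, xs, hs, ws = 10.0, 10.0, 5.0, 5.0

expected = decode_tflite_style(rel, anchors, ys, xs, hs, ws)
# Reorder offsets yxhw -> xywh and use variances = 1/scale, as the converter does.
actual = decode_with_variances(rel[:, [1, 0, 3, 2]], anchors_yxhw_to_ltrb(anchors),
                               (1 / xs, 1 / ys, 1 / ws, 1 / hs))
assert np.allclose(expected, actual)
```

Taking the reciprocal of each scale is what lets a decoder that multiplies by variances reproduce one that divides by scales.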

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r366410943
 
 

 ##########
 File path: tests/python/frontend/tflite/test_forward.py
 ##########
 @@ -1113,6 +1113,49 @@ def test_forward_fully_connected():
     _test_fully_connected([5, 1, 1, 150], [150, 100], [100])
 
 
+#######################################################################
+# Custom Operators
+# -------
+
+def test_detection_postprocess():
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
 
 Review comment:
   The test looks non-trivial to add because quite a small difference in the convolutional part of the network can significantly change the ordering of the output tensor (e.g. we might see a different detection at the cut-off threshold). I'm not sure of the best way to proceed; do you have any thoughts?
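A toy illustration of the ordering problem described above follows. It assumes a simplified sort-and-truncate stage, and `top_k_above_threshold` is a hypothetical stand-in, not TVM's or TFLite's code. Two score vectors that agree to within 0.002 keep different boxes once only the top two are retained. As a result, no reasonable elementwise tolerance on the box output makes them match.

```python
import numpy as np

def top_k_above_threshold(scores, boxes, k, threshold):
    """Hypothetical sort-and-truncate stage: keep the k highest-scoring boxes whose
    score exceeds threshold, padding unused slots with -1 to keep a fixed shape."""
    order = np.argsort(-scores, kind="stable")
    keep = [i for i in order if scores[i] > threshold][:k]
    out_scores = np.full(k, -1.0)
    out_boxes = np.full((k, 4), -1.0)
    out_scores[:len(keep)] = scores[keep]
    out_boxes[:len(keep)] = boxes[keep]
    return out_scores, out_boxes

boxes = np.array([
    [0.1, 0.1, 0.3, 0.3],
    [0.5, 0.5, 0.9, 0.9],
    [0.0, 0.6, 0.2, 0.8],
    [0.4, 0.0, 0.6, 0.2],
])
# Reference and candidate scores differ by at most 0.002 (e.g. from slightly different
# convolution arithmetic), but the second- and third-ranked detections swap places.
ref_in = np.array([0.90, 0.601, 0.600, 0.30])
cand_in = np.array([0.90, 0.599, 0.601, 0.30])

ref_scores, ref_boxes = top_k_above_threshold(ref_in, boxes, k=2, threshold=0.5)
cand_scores, cand_boxes = top_k_above_threshold(cand_in, boxes, k=2, threshold=0.5)

print(np.allclose(ref_scores, cand_scores))          # prints True
print(np.allclose(ref_boxes, cand_boxes, atol=0.1))  # prints False
```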

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r365486547
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -98,6 +98,7 @@ def __init__(self, model, subgraph, exp_tab):
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
             'PRELU': self.convert_prelu,
             'TRANSPOSE_CONV': self.convert_transpose_conv,
+            'DETECTION_POSTPROCESS': self._convert_detection_postprocess
 
 Review comment:
   Please change to `self.convert_detection_postprocess` to match the style of the other convert functions. We should keep the same code style.

[GitHub] [incubator-tvm] FrozenGene commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on issue #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#issuecomment-585513267
 
 
   Thanks everyone, merged now.

[GitHub] [incubator-tvm] FrozenGene merged pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

Posted by GitBox <gi...@apache.org>.
FrozenGene merged pull request #4543: [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543
 
 
   
