You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2020/04/29 16:13:28 UTC

[incubator-tvm] branch master updated: [TFLITE] Match TFLite shape for SSD custom op (#5473)

This is an automated email from the ASF dual-hosted git repository.

anijain2305 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 95a816c  [TFLITE] Match TFLite shape for SSD custom op (#5473)
95a816c is described below

commit 95a816c9078c5cc7cb08d354a069a15f5d18951c
Author: mbaret <55...@users.noreply.github.com>
AuthorDate: Wed Apr 29 17:13:16 2020 +0100

    [TFLITE] Match TFLite shape for SSD custom op (#5473)
    
    This patch ensures that the output shape from TVM's
    Detection_PostProcess is the same as TFLite's and
    expands the unit test to confirm this.
    
    Change-Id: If5db95741533f131241dfebbaa7708dbd528fe70
---
 python/tvm/relay/frontend/tflite.py          | 13 +++++++++----
 tests/python/frontend/tflite/test_forward.py |  7 +++++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index b9a1657..66d0ff3 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -2257,6 +2257,7 @@ class OperatorConverter(object):
         assert len(inputs) == 3, "inputs length should be 3"
         cls_pred = self.get_expr(inputs[1].tensor_idx)
         loc_prob = self.get_expr(inputs[0].tensor_idx)
+        batch_size = inputs[1].tensor.Shape(0)
         anchor_values = self.get_tensor_value(inputs[2])
         anchor_boxes = len(anchor_values)
         anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
@@ -2284,7 +2285,7 @@ class OperatorConverter(object):
         loc_prob = _op.concatenate(
             [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
         )
-        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+        loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes*4])
 
         # anchor coords are in yxhw format
         # need to convert to ltrb
@@ -2327,10 +2328,14 @@ class OperatorConverter(object):
         ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs)
         ret = _op.vision.get_valid_counts(ret, 0)
         valid_count = ret[0]
+        # keep only the top 'max_detections' rows
+        ret = _op.strided_slice(ret[1],
+                                [0, 0, 0],
+                                [batch_size, custom_options["max_detections"], anchor_boxes])
         # the output needs some reshaping to match tflite
-        ret = _op.split(ret[1], 6, axis=2)
-        cls_ids = ret[0]
-        scores = ret[1]
+        ret = _op.split(ret, 6, axis=2)
+        cls_ids = _op.reshape(ret[0], [batch_size, -1])
+        scores = _op.reshape(ret[1], [batch_size, -1])
         boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2)
         ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
         return ret
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7ff4c31..bc3f32a 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1731,7 +1731,14 @@ def test_detection_postprocess():
                                ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4)
     # check valid count is the same
     assert tvm_output[3] == tflite_output[3]
+    # check all the output shapes are the same
+    assert tvm_output[0].shape == tflite_output[0].shape
+    assert tvm_output[1].shape == tflite_output[1].shape
+    assert tvm_output[2].shape == tflite_output[2].shape
     valid_count = tvm_output[3][0]
+    # only check that the valid detections are the same:
+    # tvm and tflite use different conventions for invalid detections - tvm fills them
+    # with -1s, whereas tflite appears to leave arbitrary data in those slots
     tvm_boxes = tvm_output[0][0][:valid_count]
     tvm_classes = tvm_output[1][0][:valid_count]
     tvm_scores = tvm_output[2][0][:valid_count]