You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/04/22 15:45:18 UTC

[tvm] branch main updated: [ONNX] Support NMS Center Box (#7900)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e2d5ea  [ONNX] Support NMS Center Box (#7900)
0e2d5ea is described below

commit 0e2d5ea4791405045be422adf728064cb91f004e
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Thu Apr 22 09:45:03 2021 -0600

    [ONNX] Support NMS Center Box (#7900)
    
    * [ONNX] Support NMS Center Box
    
    * fix silly mistake in conditional
---
 python/tvm/relay/frontend/onnx.py          | 16 +++++++++++-----
 tests/python/frontend/onnx/test_forward.py |  1 -
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 4b159a5..fa0eac9 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2543,11 +2543,17 @@ class NonMaxSuppression(OnnxOpConverter):
         iou_threshold = inputs[3]
         score_threshold = inputs[4]
 
-        if "center_point_box" in attr:
-            if attr["center_point_box"] != 0:
-                raise NotImplementedError(
-                    "Only support center_point_box = 0 in ONNX NonMaxSuprresion"
-                )
+        boxes_dtype = infer_type(boxes).checked_type.dtype
+
+        if attr.get("center_point_box", 0) != 0:
+            xc, yc, w, h = _op.split(boxes, 4, axis=2)
+            half_w = w / _expr.const(2.0, boxes_dtype)
+            half_h = h / _expr.const(2.0, boxes_dtype)
+            x1 = xc - half_w
+            x2 = xc + half_w
+            y1 = yc - half_h
+            y2 = yc + half_h
+            boxes = _op.concatenate([y1, x1, y2, x2], axis=2)
 
         if iou_threshold is None:
             iou_threshold = _expr.const(0.0, dtype="float32")
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0a702c5..1eaae6f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4215,7 +4215,6 @@ unsupported_onnx_tests = [
     "test_maxpool_with_argmax_2d_precomputed_strides/",
     "test_maxunpool_export_with_output_shape/",
     "test_mvn/",
-    "test_nonmaxsuppression_center_point_box_format/",
     "test_qlinearconv/",
     "test_qlinearmatmul_2D/",
     "test_qlinearmatmul_3D/",