You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/04/22 15:45:18 UTC
[tvm] branch main updated: [ONNX] Support NMS Center Box (#7900)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0e2d5ea [ONNX] Support NMS Center Box (#7900)
0e2d5ea is described below
commit 0e2d5ea4791405045be422adf728064cb91f004e
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Thu Apr 22 09:45:03 2021 -0600
[ONNX] Support NMS Center Box (#7900)
* [ONNX] Support NMS Center Box
* fix silly mistake in conditional
---
python/tvm/relay/frontend/onnx.py | 16 +++++++++++-----
tests/python/frontend/onnx/test_forward.py | 1 -
2 files changed, 11 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 4b159a5..fa0eac9 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2543,11 +2543,17 @@ class NonMaxSuppression(OnnxOpConverter):
iou_threshold = inputs[3]
score_threshold = inputs[4]
- if "center_point_box" in attr:
- if attr["center_point_box"] != 0:
- raise NotImplementedError(
- "Only support center_point_box = 0 in ONNX NonMaxSuprresion"
- )
+ boxes_dtype = infer_type(boxes).checked_type.dtype
+
+ if attr.get("center_point_box", 0) != 0:
+ xc, yc, w, h = _op.split(boxes, 4, axis=2)
+ half_w = w / _expr.const(2.0, boxes_dtype)
+ half_h = h / _expr.const(2.0, boxes_dtype)
+ x1 = xc - half_w
+ x2 = xc + half_w
+ y1 = yc - half_h
+ y2 = yc + half_h
+ boxes = _op.concatenate([y1, x1, y2, x2], axis=2)
if iou_threshold is None:
iou_threshold = _expr.const(0.0, dtype="float32")
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0a702c5..1eaae6f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4215,7 +4215,6 @@ unsupported_onnx_tests = [
"test_maxpool_with_argmax_2d_precomputed_strides/",
"test_maxunpool_export_with_output_shape/",
"test_mvn/",
- "test_nonmaxsuppression_center_point_box_format/",
"test_qlinearconv/",
"test_qlinearmatmul_2D/",
"test_qlinearmatmul_3D/",