You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/16 04:58:06 UTC
[incubator-mxnet] branch v1.x updated: refactor code (#19887)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 3b470d1 refactor code (#19887)
3b470d1 is described below
commit 3b470d134914c461f0b6be6b9cb828351f27486f
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Mon Feb 15 20:56:02 2021 -0800
refactor code (#19887)
Co-authored-by: Wei Chu <we...@amazon.com>
---
tests/python-pytest/onnx/test_onnxruntime.py | 86 +++++++++++++++++++++-------
1 file changed, 65 insertions(+), 21 deletions(-)
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index ce372c7..86e19fa 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -272,25 +272,62 @@ def obj_detection_test_images(tmpdir_factory):
'center_net_resnet18_v1b_coco',
'center_net_resnet50_v1b_coco',
'center_net_resnet101_v1b_coco',
- # the following models are failing due to onnxruntime errors
- #'ssd_300_vgg16_atrous_voc',
- #'ssd_512_vgg16_atrous_voc',
- #'ssd_512_resnet50_v1_voc',
- #'ssd_512_mobilenet1.0_voc',
- #'faster_rcnn_resnet50_v1b_voc',
- #'yolo3_darknet53_voc',
- #'yolo3_mobilenet1.0_voc',
- #'ssd_300_vgg16_atrous_coco',
- #'ssd_512_vgg16_atrous_coco',
- #'ssd_300_resnet34_v1b_coco',
- #'ssd_512_resnet50_v1_coco',
- #'ssd_512_mobilenet1.0_coco',
- #'faster_rcnn_resnet50_v1b_coco',
- #'faster_rcnn_resnet101_v1d_coco',
- #'yolo3_darknet53_coco',
- #'yolo3_mobilenet1.0_coco',
+ 'ssd_300_vgg16_atrous_voc',
+ 'ssd_512_vgg16_atrous_voc',
+ 'ssd_512_resnet50_v1_voc',
+ 'ssd_512_mobilenet1.0_voc',
+ 'faster_rcnn_resnet50_v1b_voc',
+ 'yolo3_darknet53_voc',
+ 'yolo3_mobilenet1.0_voc',
+ 'ssd_300_vgg16_atrous_coco',
+ 'ssd_512_vgg16_atrous_coco',
+ # 'ssd_300_resnet34_v1b_coco', #cannot import
+ 'ssd_512_resnet50_v1_coco',
+ 'ssd_512_mobilenet1.0_coco',
+ 'faster_rcnn_resnet50_v1b_coco',
+ 'faster_rcnn_resnet101_v1d_coco',
+ 'yolo3_darknet53_coco',
+ 'yolo3_mobilenet1.0_coco',
])
def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detection_test_images):
+    def assert_obj_detetion_result(mx_ids, mx_scores, mx_boxes,
+                                   onnx_ids, onnx_scores, onnx_boxes,
+                                   score_thresh=0.6, score_tol=1e-4, box_tol=1e-2):
+        """Return True when MXNet and ONNX Runtime agree on their confident detections.
+
+        ids and scores are read as column vectors (element [i][0]) and boxes
+        presumably as (N, 4) rows; NDArray or numpy inputs are accepted. A detection
+        scoring >= score_thresh matches one with the same class id, a score within
+        score_tol and every box coordinate within box_tol. The check runs in both
+        directions because checking only ONNX -> MXNet passes vacuously whenever
+        ONNX Runtime yields no detection above score_thresh. Both outputs are
+        assumed sorted by descending score (the early exits rely on it) -- TODO
+        confirm against the model output contract.
+        """
+        def to_np(arrs):  # convert once, not asnumpy() per element pair
+            return [a.asnumpy() if hasattr(a, 'asnumpy') else np.asarray(a) for a in arrs]
+
+        def all_matched(src, ref):
+            src_ids, src_scores, src_boxes = src
+            ref_ids, ref_scores, ref_boxes = ref
+            for s_id, s_score, s_box in zip(src_ids[:, 0], src_scores[:, 0], src_boxes):
+                if s_score < score_thresh:
+                    break  # sorted by score, so the rest are below threshold too
+                for r_id, r_score, r_box in zip(ref_ids[:, 0], ref_scores[:, 0], ref_boxes):
+                    if s_score < r_score - score_tol:
+                        continue  # ref detection scores higher; keep scanning
+                    if s_score > r_score + score_tol:
+                        return False  # sorted ref fell below s_score: no match left
+                    if s_id == r_id and np.all(np.abs(s_box - r_box) <= box_tol):
+                        break
+                else:
+                    return False  # scanned all of ref without a counterpart
+            return True
+
+        mx = to_np((mx_ids, mx_scores, mx_boxes))
+        onnx = to_np((onnx_ids, onnx_scores, onnx_boxes))
+        return all_matched(onnx, mx) and all_matched(mx, onnx)
+
def normalize_image(imgfile):
img = mx.image.imread(imgfile)
img, _ = mx.image.center_crop(img, size=(512, 512))
@@ -310,10 +347,17 @@ def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detectio
for img in obj_detection_test_images:
img_data = normalize_image(img)
mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
- onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
- assert_almost_equal(mx_class_ids, onnx_class_ids)
- assert_almost_equal(mx_scores, onnx_scores)
- assert_almost_equal(mx_boxes, onnx_boxes)
+ # center_net_resnet models have different output format
+ if 'center_net_resnet' in model:
+ onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
+ assert_almost_equal(mx_class_ids, onnx_class_ids)
+ assert_almost_equal(mx_scores, onnx_scores)
+ assert_almost_equal(mx_boxes, onnx_boxes)
+ else:
+ onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
+ if not assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], \
+ onnx_class_ids[0], onnx_scores[0], onnx_boxes[0]):
+ raise AssertionError("Assertion error on model: " + model)
finally:
shutil.rmtree(tmp_path)