You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/16 04:58:06 UTC

[incubator-mxnet] branch v1.x updated: refactor code (#19887)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 3b470d1  refactor code (#19887)
3b470d1 is described below

commit 3b470d134914c461f0b6be6b9cb828351f27486f
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Mon Feb 15 20:56:02 2021 -0800

    refactor code (#19887)
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 tests/python-pytest/onnx/test_onnxruntime.py | 86 +++++++++++++++++++++-------
 1 file changed, 65 insertions(+), 21 deletions(-)

diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index ce372c7..86e19fa 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -272,25 +272,62 @@ def obj_detection_test_images(tmpdir_factory):
     'center_net_resnet18_v1b_coco',
     'center_net_resnet50_v1b_coco',
     'center_net_resnet101_v1b_coco',
-    # the following models are failing due to onnxruntime errors
-    #'ssd_300_vgg16_atrous_voc',
-    #'ssd_512_vgg16_atrous_voc',
-    #'ssd_512_resnet50_v1_voc',
-    #'ssd_512_mobilenet1.0_voc',
-    #'faster_rcnn_resnet50_v1b_voc',
-    #'yolo3_darknet53_voc',
-    #'yolo3_mobilenet1.0_voc',
-    #'ssd_300_vgg16_atrous_coco',
-    #'ssd_512_vgg16_atrous_coco',
-    #'ssd_300_resnet34_v1b_coco',
-    #'ssd_512_resnet50_v1_coco',
-    #'ssd_512_mobilenet1.0_coco',
-    #'faster_rcnn_resnet50_v1b_coco',
-    #'faster_rcnn_resnet101_v1d_coco',
-    #'yolo3_darknet53_coco',
-    #'yolo3_mobilenet1.0_coco',
+    'ssd_300_vgg16_atrous_voc',
+    'ssd_512_vgg16_atrous_voc',
+    'ssd_512_resnet50_v1_voc',
+    'ssd_512_mobilenet1.0_voc',
+    'faster_rcnn_resnet50_v1b_voc',
+    'yolo3_darknet53_voc',
+    'yolo3_mobilenet1.0_voc',
+    'ssd_300_vgg16_atrous_coco',
+    'ssd_512_vgg16_atrous_coco',
+    # 'ssd_300_resnet34_v1b_coco', #cannot import
+    'ssd_512_resnet50_v1_coco',
+    'ssd_512_mobilenet1.0_coco',
+    'faster_rcnn_resnet50_v1b_coco',
+    'faster_rcnn_resnet101_v1d_coco',
+    'yolo3_darknet53_coco',
+    'yolo3_mobilenet1.0_coco',
 ])
 def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detection_test_images):
+    def assert_obj_detetion_result(mx_ids, mx_scores, mx_boxes,
+                                   onnx_ids, onnx_scores, onnx_boxes,
+                                   score_thresh=0.6, score_tol=1e-4):
+        def assert_bbox(mx_boxe, onnx_boxe, box_tol=1e-2):
+            def assert_scalar(a, b, tol=box_tol):
+                return np.abs(a-b) <= tol
+            return assert_scalar(mx_boxe[0], onnx_boxe[0]) and assert_scalar(mx_boxe[1], onnx_boxe[1]) \
+                      and assert_scalar(mx_boxe[2], onnx_boxe[2]) and assert_scalar(mx_boxe[3], onnx_boxe[3])
+
+        found_match = False
+        for i in range(len(onnx_ids)):
+            onnx_id = onnx_ids[i][0]
+            onnx_score = onnx_scores[i][0]
+            onnx_boxe = onnx_boxes[i]
+
+            if onnx_score < score_thresh:
+                break
+            for j in range(len(mx_ids)):
+                mx_id = mx_ids[j].asnumpy()[0]
+                mx_score = mx_scores[j].asnumpy()[0]
+                mx_boxe = mx_boxes[j].asnumpy()
+                # check socre 
+                if onnx_score < mx_score - score_tol:
+                    continue
+                if onnx_score > mx_score + score_tol:
+                    return False
+                # check id
+                if onnx_id != mx_id:
+                    continue
+                # check bounding box
+                if assert_bbox(mx_boxe, onnx_boxe):
+                    found_match = True
+                    break
+            if not found_match:
+                return False
+            found_match = False
+        return True
+
     def normalize_image(imgfile):
         img = mx.image.imread(imgfile)
         img, _ = mx.image.center_crop(img, size=(512, 512))
@@ -310,10 +347,17 @@ def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detectio
         for img in obj_detection_test_images:
             img_data = normalize_image(img)
             mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
-            onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
-            assert_almost_equal(mx_class_ids, onnx_class_ids)
-            assert_almost_equal(mx_scores, onnx_scores)
-            assert_almost_equal(mx_boxes, onnx_boxes)
+            # center_net_resnet models have different output format
+            if 'center_net_resnet' in model:
+                onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
+                assert_almost_equal(mx_class_ids, onnx_class_ids)
+                assert_almost_equal(mx_scores, onnx_scores)
+                assert_almost_equal(mx_boxes, onnx_boxes)
+            else:
+                onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
+                if not assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], \
+                        onnx_class_ids[0], onnx_scores[0], onnx_boxes[0]):
+                    raise AssertionError("Assertion error on model: " + model)
 
     finally:
         shutil.rmtree(tmp_path)