Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/12 18:24:56 UTC

[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #6839: [WIP][ONNX] NMS in ONNX

jwfromm commented on a change in pull request #6839:
URL: https://github.com/apache/incubator-tvm/pull/6839#discussion_r522315510



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -2199,6 +2199,241 @@ def body_fn(*loop_inputs):
         return outputs
 
 
+class NonMaxSuppression(OnnxOpConverter):
+    """Operator converter for NonMaxSuppression."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        # Get parameter values
+        boxes = inputs[0]
+        scores = inputs[1]
+        max_output_boxes_per_class = inputs[2]
+        iou_threshold = inputs[3]
+        score_threshold = inputs[4]
+
+        dtype = infer_type(boxes).checked_type.dtype
+
+        if "center_point_box" in attr:
+            assert (
+                attr["center_point_box"] == 0
+            ), "Only support center_point_box = 0 in onnx importer right now"
+
+        if iou_threshold is None:
+            iou_threshold = _expr.const(0.0, dtype="float32")
+        if score_threshold is None:
+            score_threshold = _expr.const(0.0, dtype="float32")
+
+        def conditionally_squeeze_scalar(x):
+            rank = len(infer_shape(x))
+            assert rank <= 1, "nms thresholds must be scalars"
+            if rank == 1:
+                return _op.squeeze(x, [0])
+            return x
+
+        max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
+        iou_threshold = conditionally_squeeze_scalar(iou_threshold)
+        score_threshold = conditionally_squeeze_scalar(score_threshold)
+        zero = _op.const(np.array([0]), dtype="int64")
+        one = _op.const(np.array([1]), dtype="int64")
+        two = _op.const(np.array([2]), dtype="int64")
+        three = _op.const(np.array([3]), dtype="int64")
+        three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
+        four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")
+
+        # First Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
+        boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
+        max_output_boxes_per_class_var = _expr.var(
+            "max_output_boxes_per_class_var", shape=(), dtype="int64"
+        )
+        iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
+        score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        S = _expr.var("S", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")

Review comment:
       The comment and this variable's shape don't line up. Can you update the comment to describe all 4 dimensions of `onnx_out`?
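
For illustration, a sketch of how the comment could be reworded to cover all four dimensions (not the author's final wording; it assumes S here is the padded per-class selection count, while the trailing 3 is the (batch_index, class_index, box_index) triple the ONNX spec defines for selected_indices):

    # Outputs of first loop should be padded nms indices of shape
    # (B, C, S, 3): for each of B batches and C classes, up to S
    # selected boxes, each a (batch_index, class_index, box_index)
    # triple.
    onnx_out = _expr.var(
        "onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64"
    )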

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -2199,6 +2199,241 @@ def body_fn(*loop_inputs):
         return outputs
 
 
+class NonMaxSuppression(OnnxOpConverter):
+    """Operator converter for NonMaxSuppression."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        # Get parameter values
+        boxes = inputs[0]
+        scores = inputs[1]
+        max_output_boxes_per_class = inputs[2]
+        iou_threshold = inputs[3]
+        score_threshold = inputs[4]
+
+        dtype = infer_type(boxes).checked_type.dtype
+
+        if "center_point_box" in attr:
+            assert (
+                attr["center_point_box"] == 0
+            ), "Only support center_point_box = 0 in onnx importer right now"
+
+        if iou_threshold is None:
+            iou_threshold = _expr.const(0.0, dtype="float32")
+        if score_threshold is None:
+            score_threshold = _expr.const(0.0, dtype="float32")
+
+        def conditionally_squeeze_scalar(x):
+            rank = len(infer_shape(x))
+            assert rank <= 1, "nms thresholds must be scalars"
+            if rank == 1:
+                return _op.squeeze(x, [0])
+            return x
+
+        max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
+        iou_threshold = conditionally_squeeze_scalar(iou_threshold)
+        score_threshold = conditionally_squeeze_scalar(score_threshold)
+        zero = _op.const(np.array([0]), dtype="int64")
+        one = _op.const(np.array([1]), dtype="int64")
+        two = _op.const(np.array([2]), dtype="int64")
+        three = _op.const(np.array([3]), dtype="int64")
+        three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
+        four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")
+
+        # First Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
+        boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
+        max_output_boxes_per_class_var = _expr.var(
+            "max_output_boxes_per_class_var", shape=(), dtype="int64"
+        )
+        iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
+        score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        S = _expr.var("S", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+
+        def _first_cond(
+            i,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            onnx_out,
+            nms_size_out,
+        ):
+            return _op.min(_op.less(i, C))

Review comment:
       A comment explaining what this condition is checking would be helpful.
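
For example, a sketch of such a comment (it assumes C is the class count, following the B/C/S naming above):

    # Inside _first_cond: keep looping while the iteration counter i is
    # below C. Both i and C are shape-(1,) int64 tensors, so _op.less
    # yields a (1,) boolean tensor and _op.min reduces it to the scalar
    # condition that relay's while_loop expects.
    return _op.min(_op.less(i, C))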

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -2199,6 +2199,241 @@ def body_fn(*loop_inputs):
         return outputs
 
 
+class NonMaxSuppression(OnnxOpConverter):
+    """Operator converter for NonMaxSuppression."""
+
+    @classmethod

Review comment:
       A quick description of why we need three loops to implement ONNX NMS would be helpful for readers trying to make sense of what's going on here.
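
As a sketch of the kind of explanation being requested (inferred from the surrounding code, not the author's wording), a class docstring might read:

    class NonMaxSuppression(OnnxOpConverter):
        """Operator converter for NonMaxSuppression.

        ONNX NMS scores every (batch, class) pair and returns a dynamic
        number of (batch_index, class_index, box_index) triples, while
        relay's NMS takes one score per box per batch. Bridging the two
        plausibly takes three loops: one over batches, one over classes
        within each batch to run NMS on that slice, and a third to gather
        the valid padded per-slice results into the flat ONNX output.
        """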

##########
File path: tests/python/frontend/onnx/test_forward.py
##########
@@ -53,10 +53,15 @@ def get_tvm_output_with_vm(
     mod, params = relay.frontend.from_onnx(
         graph_def, shape_dict, opset=opset, freeze_params=freeze_params
     )
-    if convert_to_static:
-        from tvm.relay import transform
 
-        mod = transform.DynamicToStatic()(mod)
+    from tvm.relay import transform
+
+    # print(mod.astext(show_meta_data=True))
+    # self.mod = transform.AnnotateSpans()(mod)
+    # print(mod.astext(show_meta_data=False))

Review comment:
       Remove these commented lines before merging.
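
For reference, addressing this would restore the guarded form visible in the deleted lines and drop the debug prints:

    if convert_to_static:
        from tvm.relay import transform

        # Convert dynamic ops into static equivalents where possible.
        mod = transform.DynamicToStatic()(mod)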

##########
File path: tests/python/frontend/onnx/test_forward.py
##########
@@ -3889,78 +3986,7 @@ def test_if():
 
 
 if __name__ == "__main__":
-    test_flatten()
-    test_reshape()
-    test_shape()
-    test_expand()
-    test_power()
-    test_squeeze()
-    test_unsqueeze()
-    test_slice()
-    test_floor()
-    test_ceil()
-    test_round()
-    test_isinf()
-    test_isnan()
-    test_clip()
-    test_clip_min_max_as_inputs()
-    test_onehot()
-    test_matmul()
-    test_gather()
-    test_gatherelements()
-    test_gather_nd()
-    test_scatter()
-    test_lrn()
-    test_instance_norm()
-    test_upsample()
-    test_forward_min()
-    test_forward_max()
-    test_forward_mean()
-    test_forward_hardsigmoid()
-    test_forward_arg_min_max()
-    test_softmax()
-    test_constantofshape()
-    test_all_reduce_funcs()
-    test_pad()
-    test_split()
-    test_binary_ops()
-    test_single_ops()
-    test_leaky_relu()
-    test_elu()
-    test_selu()
-    test_prelu()
-    test_ThresholdedRelu()
-    test_ScaledTanh()
-    test_ParametricSoftplus()
-    test_Scale()
-    test_LogSoftmax()
-    test_resnet()
-    test_inception()
-    test_densenet()
-    test_sign()
-    test_not()
-    test_and()
-    test_tile()
-    test_erf()
-    test_where()
-    test_or()
-    test_depth_to_space()
-    test_space_to_depth()
-    test_batch_norm()
-    test_batch_norm_dynamic_subgraph()
-    test_conv()
-    test_convtranspose()
-    test_unsqueeze_constant()
-    test_pooling()
-    test_lppool()
-    test_lstm()
-    test_gru()
-    test_resize()
-    test_nonzero()
-    test_topk()
-    test_mod()
-    test_xor()
-    test_max_roi_pool()
-    test_roi_align()
-    test_range()
-    test_loop()
+    import sys
+    import pytest
+
+    pytest.main(sys.argv)

Review comment:
       What's the thought behind this change?
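
For context, `pytest.main(sys.argv)` lets `python test_forward.py` discover and run every test in the file, so the __main__ block no longer needs a hand-maintained call list that can silently fall out of date as tests are added. A minimal sketch of the resulting entry point:

    if __name__ == "__main__":
        import sys
        import pytest

        # sys.argv[0] is this file's path, which pytest.main treats as the
        # test target; any extra CLI flags (e.g. -k <pattern>) are passed
        # through to pytest unchanged.
        pytest.main(sys.argv)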




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org