You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/06 22:34:59 UTC

[GitHub] [incubator-tvm] Laurawly commented on a change in pull request #6839: [WIP][ONNX] NMS in ONNX

Laurawly commented on a change in pull request #6839:
URL: https://github.com/apache/incubator-tvm/pull/6839#discussion_r519038726



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -519,14 +557,90 @@ def non_max_suppression(
             coord_start,
             id_index,
             score_index,
+            return_indices,
         ),
         dtype=[data.dtype, "int32"],
-        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
         name="nms",
         tag="nms",
     )
-    # TODO(yongwww): Update cuda nms to be consistent with cpu version
     if return_indices:
-        return box_indices
+        out_shape = box_indices.shape
+        valid_box_count_shape = [box_indices.shape[0], 1]
+        valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count")
+        output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
+        return te.extern(
+            [out_shape, valid_box_count_shape],
+            [box_indices],
+            lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
+            dtype="int32",
+            out_buffers=[output, valid_box_count],
+            name="rearrange_indices_out_gpu",
+            tag="rearrange_indices_out_gpu",
+        )
 
     return out
+
+
+def rearrange_indices_out_ir(data, output, valid_box_count):
+    """Low level IR routine to rearrange nms output to
+    move all valid entries to top.
+
+    Parameters
+    ----------
+    data : Buffer
+        NMS output indices. 2-D tensor with shape
+        [batch_size, num_anchors].
+
+    output : Buffer
+        Buffer to hold the rearranged indices, with shape
+        [batch_size, num_anchors]. Entries past the valid
+        count are filled with -1.
+
+    valid_box_count : Buffer
+        Buffer with shape [batch_size, 1], indicates
+        the valid number of boxes.
+
+    Returns
+    -------
+    stmt : Stmt
+        The IR statement performing the rearrangement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    valid_box_count = ib.buffer_ptr(valid_box_count)
+    output = ib.buffer_ptr(output)
+
+    with ib.new_scope():
+        i = te.thread_axis("blockIdx.x")
+        ib.scope_attr(i, "thread_extent", batch_size)
+        valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
+        valid_idx[0] = 0
+        with ib.for_range(0, num_anchors, name="j") as j:
+            with ib.if_scope(data[i, j] >= 0):
+                with ib.if_scope(data[i, j] > num_anchors):
+                    output[i, valid_idx[0]] = 0
+                    valid_idx[0] = valid_idx[0] + 1
+                with ib.else_scope():
+                    output[i, valid_idx[0]] = data[i, j]
+                    valid_idx[0] = valid_idx[0] + 1
+            with ib.else_scope():
+                with ib.if_scope(data[i, j] < -num_anchors):
+                    output[i, valid_idx[0]] = 0
+                    valid_idx[0] = valid_idx[0] + 1
+            with ib.if_scope(j >= valid_idx[0]):
+                output[i, j] = -1
+        valid_box_count[i, 0] = valid_idx[0]
+
+    return ib.get()

Review comment:
       Could you show the performance benchmark on some popular OD model workloads with the modified nms.py?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org