You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/06 22:43:09 UTC

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6839: [WIP][ONNX] NMS in ONNX

mbrookhart commented on a change in pull request #6839:
URL: https://github.com/apache/incubator-tvm/pull/6839#discussion_r519041432



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -519,14 +557,90 @@ def non_max_suppression(
             coord_start,
             id_index,
             score_index,
+            return_indices,
         ),
         dtype=[data.dtype, "int32"],
-        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
         name="nms",
         tag="nms",
     )
-    # TODO(yongwww): Update cuda nms to be consistent with cpu version
     if return_indices:
-        return box_indices
+        out_shape = box_indices.shape
+        valid_box_count_shape = [box_indices.shape[0], 1]
+        valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count")
+        output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
+        return te.extern(
+            [out_shape, valid_box_count_shape],
+            [box_indices],
+            lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
+            dtype="int32",
+            out_buffers=[output, valid_box_count],
+            name="rearrange_indices_out_gpu",
+            tag="rearrange_indices_out_gpu",
+        )
 
     return out
+
+
+def rearrange_indices_out_ir(data, output, valid_box_count):
+    """Hybrid routine to rearrange nms output to
+    move all valid entries to top.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor or numpy NDArray
+        NMS output. 3-D tensor with shape
+        [batch_size, num_anchors, 6] or
+        [batch_size, num_anchors, 5], or 2-D
+        tensor with shape [batch_size, num_anchors].
+
+    one: tvm.tir.const
+        Constant one with the same dtype as data.
+
+    batch_size: tvm.tir.IntImm or tvm.tir.Var
+        Batch size. We need to pass it in since hybrid script doesn't support
+        binding variable to symbolic dim.
+
+    num_anchors: tvm.tir.IntImm or tvm.tir.Var
+        Number of anchors.
+
+    Returns
+    -------
+    output : tvm.te.Tensor or numpy NDArray
+        2-D tensor with shape [batch_size, num_anchors].
+
+    valid_box_count : tvm.te.Tensor or numpy NDArray
+        Tensor with shape [batch_size, 1], indicates
+        the valid number of boxes.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    valid_box_count = ib.buffer_ptr(valid_box_count)
+    output = ib.buffer_ptr(output)
+
+    with ib.new_scope():
+        i = te.thread_axis("blockIdx.x")
+        ib.scope_attr(i, "thread_extent", batch_size)
+        valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
+        valid_idx[0] = 0
+        with ib.for_range(0, num_anchors, name="j") as j:
+            with ib.if_scope(data[i, j] >= 0):
+                with ib.if_scope(data[i, j] > num_anchors):
+                    output[i, valid_idx[0]] = 0
+                    valid_idx[0] = valid_idx[0] + 1
+                with ib.else_scope():
+                    output[i, valid_idx[0]] = data[i, j]
+                    valid_idx[0] = valid_idx[0] + 1
+            with ib.else_scope():
+                with ib.if_scope(data[i, j] < -num_anchors):
+                    output[i, valid_idx[0]] = 0
+                    valid_idx[0] = valid_idx[0] + 1
+            with ib.if_scope(j >= valid_idx[0]):
+                output[i, j] = -1
+        valid_box_count[i, 0] = valid_idx[0]
+
+    return ib.get()

Review comment:
       I've been attempting to get SSD-Mobilenet from the ONNX model zoo working, but I'm hitting other bugs. I'll post some perf metrics as soon as I can get that working.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org