You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/12/30 09:00:39 UTC

[tvm] branch main updated: [TOPI] Parallelize GPU NMS inner loop (#7172)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 66e123f  [TOPI] Parallelize GPU NMS inner loop (#7172)
66e123f is described below

commit 66e123ff7ce4f5524b3f51ccd95bd4010b7af2c6
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Dec 30 18:00:22 2020 +0900

    [TOPI] Parallelize GPU NMS inner loop (#7172)
    
    * make NMS inner loop parallel
    
    * use one block to avoid global sync issue
    
    * temp disable write by only thread 0
    
    * leave a TODO on write by only one thread
    
    * add some comments, remove the check on negative class id
    
    * minor improvement when topk is available
    
    * fix write by a single thread
---
 python/tvm/topi/cuda/nms.py | 50 +++++++++++++++++++++++++++++----------------
 1 file changed, 32 insertions(+), 18 deletions(-)

diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 020cf9b..dd9d3f8 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -512,26 +512,44 @@ def nms_ir(
 
     with ib.new_scope():
         nthread_by = batch_size
+        nthread_tx = max_threads
+
         by = te.thread_axis("blockIdx.y")
+        tx = te.thread_axis("threadIdx.x")
         ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+
         i = by
+
         base_idx = i * num_anchors * box_data_length
         num_valid_boxes_local = ib.allocate(
             "int32", (1,), name="num_valid_boxes_local", scope="local"
         )
         num_valid_boxes_local[0] = 0
+        nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i])
 
         def nms_inner_loop(ib, j):
+            # The box j is valid, invalidate other boxes that overlap with j above iou_threshold
+
+            # When return_indices is False, no need to populate box_indices
+            if return_indices:
+                with ib.if_scope(tx + 0 == 0):
+                    orig_idx = sorted_index[i * num_anchors + j]
+                    box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
+
+            num_valid_boxes_local[0] += 1
+
             offset_j = j * box_data_length
+            num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)
 
-            with ib.for_range(0, j) as k:
+            with ib.for_range(0, num_iter_per_thread) as _k:
+                k = j + 1 + _k * nthread_tx + tx
                 offset_k = k * box_data_length
 
                 with ib.if_scope(
                     tvm.tir.all(
-                        out[base_idx + offset_j + score_index] > -1.0,  # if already surpressed
-                        out[base_idx + offset_k + score_index] > 0,
-                        tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
+                        k < nkeep,
+                        out[base_idx + offset_k + score_index] > 0,  # is the box k still valid?
                         tvm.tir.any(
                             force_suppress > 0,
                             id_index < 0,
@@ -546,27 +564,22 @@ def nms_ir(
                         base_idx + offset_k + coord_start,
                     )
                     with ib.if_scope(iou >= iou_threshold):
-                        out[base_idx + offset_j + score_index] = -1.0
+                        # invalidate the box k
+                        out[base_idx + offset_k + score_index] = -1.0
                         with ib.if_scope(id_index >= 0):
-                            out[base_idx + offset_j + id_index] = -1.0
+                            out[base_idx + offset_k + id_index] = -1.0
 
-            # Has the box j survived IOU tests?
-            with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
-                # When return_indices is False, no need to populate box_indices
-                if return_indices:
-                    orig_idx = sorted_index[i * num_anchors + j]
-                    box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
-                num_valid_boxes_local[0] += 1
+                # Make sure to do the next loop in a lock step
+                ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
 
         if isinstance(max_output_size, int):
             max_output_size = tvm.tir.const(max_output_size)
 
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Apply nms
-            with ib.for_range(0, valid_count[i]) as j:
-                with ib.if_scope(
-                    tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
-                ):
+            with ib.for_range(0, nkeep) as j:
+                # Proceed to the inner loop if the box j is still valid
+                with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
                     with ib.if_scope(max_output_size > 0):
                         # No need to do more iteration if we already reach max_output_size boxes
                         with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
@@ -574,7 +587,8 @@ def nms_ir(
                     with ib.else_scope():
                         nms_inner_loop(ib, j)
 
-            num_valid_boxes[i] = num_valid_boxes_local[0]
+            with ib.if_scope(tx + 0 == 0):
+                num_valid_boxes[i] = num_valid_boxes_local[0]
 
         with ib.else_scope():
             num_valid_boxes[i] = 0