You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by pt...@apache.org on 2022/10/26 22:09:30 UTC
[incubator-mxnet] branch v1.9.x updated: [BUGFIX] Fix nms kernel's out of range access issue (#21018)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.9.x by this push:
new 7acfb3a50e [BUGFIX] Fix nms kernel's out of range access issue (#21018)
7acfb3a50e is described below
commit 7acfb3a50e2cd737dc7b564b161de83e70016adf
Author: Triston <tr...@gmail.com>
AuthorDate: Wed Oct 26 15:09:13 2022 -0700
[BUGFIX] Fix nms kernel's out of range access issue (#21018)
* Fix the NMS kernel's out-of-range access issue
* Add static_assert error message for CalculateGreedyNMSResultsKernel
* Fix lint
Co-authored-by: Triston Cao <tr...@nvidia.com>
---
src/operator/contrib/bounding_box.cu | 61 +++++++++++++++++++++++-------------
1 file changed, 40 insertions(+), 21 deletions(-)
diff --git a/src/operator/contrib/bounding_box.cu b/src/operator/contrib/bounding_box.cu
index 4c373c22ef..db0cb54eaa 100644
--- a/src/operator/contrib/bounding_box.cu
+++ b/src/operator/contrib/bounding_box.cu
@@ -221,13 +221,12 @@ __global__ void ReduceNMSResultRestKernel(DType* data,
template <typename DType>
struct NMS {
static constexpr int THRESHOLD = 512;
-
+ static constexpr int n_threads = 512;
void operator()(Tensor<gpu, 3, DType>* data,
Tensor<gpu, 2, uint32_t>* scratch,
const index_t topk,
const BoxNMSParam& param,
Stream<gpu>* s) {
- const int n_threads = 512;
const index_t num_batches = data->shape_[0];
const index_t num_elements_per_batch = data->shape_[1];
const index_t element_width = data->shape_[2];
@@ -321,7 +320,7 @@ __device__ __forceinline__ DType calculate_intersection(const DType a0, const DT
}
template <int encode, typename DType>
-__launch_bounds__(512)
+__launch_bounds__(NMS<DType>::n_threads)
__global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* result,
const index_t current_start,
const index_t num_elems,
@@ -335,33 +334,54 @@ __global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* res
const int class_index,
const int score_index,
const float threshold) {
- constexpr int max_elem_width = 20;
constexpr int num_other_boxes = sizeof(uint32_t) * 8;
- __shared__ DType other_boxes[max_elem_width * num_other_boxes];
+ constexpr int n_index = 6; // 4x coord, class and score
+ __shared__ int indices[n_index];
+ __shared__ DType other_boxes[n_index * num_other_boxes];
__shared__ DType other_boxes_areas[num_other_boxes];
+ constexpr int local_coord_index = 0;
+ constexpr int local_class_index = 4;
+ constexpr int local_score_index = 5;
+ constexpr int index_stride = NMS<DType>::n_threads / n_index;
+ // Ensure that we only need 1 loop iteration to get all the data
+ static_assert(index_stride >= num_other_boxes,
+ "Kernel is launched with too small number of threads");
+ if (threadIdx.x == 0) {
+ indices[local_coord_index + 0] = coord_index + 0;
+ indices[local_coord_index + 1] = coord_index + 1;
+ indices[local_coord_index + 2] = coord_index + 2;
+ indices[local_coord_index + 3] = coord_index + 3;
+ // If class index is -1 load any value that we know exists
+ indices[local_class_index] = class_index != -1 ? class_index : coord_index;
+ indices[local_score_index] = score_index;
+ }
+ __syncthreads();
+
const index_t my_row = blockIdx.x / num_blocks_per_row;
const index_t my_block_offset_in_row = blockIdx.x % num_blocks_per_row;
const index_t my_block_offset_in_batch = my_block_offset_in_row % num_blocks_per_row_batch;
const index_t my_batch = (my_block_offset_in_row) / num_blocks_per_row_batch;
const index_t my_element_in_batch = my_block_offset_in_batch * blockDim.x +
current_start + threadIdx.x;
-
// Load other boxes
const index_t offset = (my_batch * num_elements_per_batch +
current_start + my_row * num_other_boxes) *
element_width;
- for (int i = threadIdx.x; i < element_width * num_other_boxes; i += blockDim.x) {
- other_boxes[i] = data[offset + i];
+ int my_index = threadIdx.x % n_index;
+ int my_element = threadIdx.x / n_index;
+ if (my_element < num_other_boxes) {
+ other_boxes[n_index * my_element + my_index] = data[offset + my_element * element_width +
+ indices[my_index]];
}
__syncthreads();
if (threadIdx.x < num_other_boxes) {
- const int other_boxes_offset = element_width * threadIdx.x;
+ const int other_boxes_offset = n_index * threadIdx.x;
const DType their_area = calculate_area<encode>(
- other_boxes[other_boxes_offset + coord_index + 0],
- other_boxes[other_boxes_offset + coord_index + 1],
- other_boxes[other_boxes_offset + coord_index + 2],
- other_boxes[other_boxes_offset + coord_index + 3]);
+ other_boxes[other_boxes_offset + local_coord_index + 0],
+ other_boxes[other_boxes_offset + local_coord_index + 1],
+ other_boxes[other_boxes_offset + local_coord_index + 2],
+ other_boxes[other_boxes_offset + local_coord_index + 3]);
other_boxes_areas[threadIdx.x] = their_area;
}
__syncthreads();
@@ -387,17 +407,16 @@ __global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* res
if (my_score != -1) {
#pragma unroll
for (int i = 0; i < num_other_boxes; ++i) {
- const int other_boxes_offset = element_width * i;
- if ((class_index == -1 || my_class == other_boxes[other_boxes_offset + class_index]) &&
- other_boxes[other_boxes_offset + score_index] != -1) {
+ const int other_boxes_offset = n_index * i;
+ if ((class_index == -1 || my_class == other_boxes[other_boxes_offset + local_class_index]) &&
+ other_boxes[other_boxes_offset + local_score_index] != -1) {
const DType their_area = other_boxes_areas[i];
-
const DType intersect = calculate_intersection<encode>(
my_box[0], my_box[1], my_box[2], my_box[3],
- other_boxes[other_boxes_offset + coord_index + 0],
- other_boxes[other_boxes_offset + coord_index + 1],
- other_boxes[other_boxes_offset + coord_index + 2],
- other_boxes[other_boxes_offset + coord_index + 3]);
+ other_boxes[other_boxes_offset + local_coord_index + 0],
+ other_boxes[other_boxes_offset + local_coord_index + 1],
+ other_boxes[other_boxes_offset + local_coord_index + 2],
+ other_boxes[other_boxes_offset + local_coord_index + 3]);
if (intersect > threshold * (my_area + their_area - intersect)) {
ret = ret | (1u << i);
}