You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by di...@apache.org on 2021/04/27 01:57:42 UTC
[incubator-mxnet] branch master updated: Fix workspace of BoxNMS (#20212)
This is an automated email from the ASF dual-hosted git repository.
dickjc123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 915aa55 Fix workspace of BoxNMS (#20212)
915aa55 is described below
commit 915aa558d02e6a77e49f906425a14537b7af4209
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Mon Apr 26 18:56:10 2021 -0700
Fix workspace of BoxNMS (#20212)
---
src/operator/contrib/bounding_box.cu | 33 +++++++++++++++++----------------
1 file changed, 17 insertions(+), 16 deletions(-)
diff --git a/src/operator/contrib/bounding_box.cu b/src/operator/contrib/bounding_box.cu
index a3d9184..7915273 100644
--- a/src/operator/contrib/bounding_box.cu
+++ b/src/operator/contrib/bounding_box.cu
@@ -40,23 +40,23 @@ using mshadow::Stream;
template <typename DType>
struct TempWorkspace {
- index_t scores_temp_space;
+ size_t scores_temp_space;
DType* scores;
- index_t scratch_space;
+ size_t scratch_space;
uint8_t* scratch;
- index_t buffer_space;
+ size_t buffer_space;
DType* buffer;
- index_t nms_scratch_space;
+ size_t nms_scratch_space;
uint32_t* nms_scratch;
- index_t indices_temp_spaces;
+ size_t indices_temp_spaces;
index_t* indices;
};
-inline index_t ceil_div(index_t x, index_t y) {
+inline size_t ceil_div(size_t x, size_t y) {
return (x + y - 1) / y;
}
-inline index_t align(index_t x, index_t alignment) {
+inline size_t align(size_t x, size_t alignment) {
return ceil_div(x, alignment) * alignment;
}
@@ -150,7 +150,7 @@ void CompactData(const Tensor<gpu, 1, index_t>& indices,
const int score_index,
Stream<gpu>* s) {
const int n_threads = 512;
- const index_t max_blocks = 320;
+ const size_t max_blocks = 320;
index_t N = source.shape_.Size();
const auto blocks = std::min(ceil_div(N, n_threads), max_blocks);
if (topk > 0) {
@@ -175,9 +175,9 @@ void WorkspaceForSort(const index_t num_elem,
const index_t topk,
const int alignment,
TempWorkspace<DType>* workspace) {
- const index_t sort_scores_temp_space =
+ const size_t sort_scores_temp_space =
mxnet::op::SortByKeyWorkspaceSize<DType, index_t, gpu>(num_elem, false, false);
- const index_t sort_topk_scores_temp_space =
+ const size_t sort_topk_scores_temp_space =
mxnet::op::SortByKeyWorkspaceSize<DType, index_t, gpu>(topk, false, false);
workspace->scratch_space = align(std::max(sort_scores_temp_space, sort_topk_scores_temp_space),
alignment);
@@ -529,14 +529,15 @@ TempWorkspace<DType> GetWorkspace(const index_t num_batch,
workspace.nms_scratch_space = align(NMS<DType>::THRESHOLD / (sizeof(uint32_t) * 8) *
num_batch * topk * sizeof(uint32_t), alignment);
- const index_t workspace_size = workspace.scores_temp_space +
- workspace.scratch_space +
- workspace.nms_scratch_space +
- workspace.indices_temp_spaces;
+ const size_t workspace_size = workspace.scores_temp_space +
+ workspace.scratch_space +
+ workspace.buffer_space +
+ workspace.nms_scratch_space +
+ workspace.indices_temp_spaces;
// Obtain the memory for workspace
- Tensor<gpu, 1, uint8_t> scratch_memory = ctx.requested[box_nms_enum::kTempSpace]
- .get_space_typed<gpu, 1, uint8_t>(mshadow::Shape1(workspace_size), s);
+ Tensor<gpu, 1, double> scratch_memory = ctx.requested[box_nms_enum::kTempSpace]
+ .get_space_typed<gpu, 1, double>(mshadow::Shape1(ceil_div(workspace_size, sizeof(double))), s);
// Populate workspace pointers
workspace.scores = reinterpret_cast<DType*>(scratch_memory.dptr_);