You are viewing a plain-text version of this content. (The canonical link for it was a hyperlink that is not preserved in this copy.)
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/03/05 00:22:30 UTC

[incubator-mxnet] branch master updated: Optimize NMS (#14290)

This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 780bddc  Optimize NMS (#14290)
780bddc is described below

commit 780bddcf1241c8de93c2b26b82a7f0ee093a8662
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Mon Mar 4 16:22:08 2019 -0800

    Optimize NMS (#14290)
    
    * Optimize NMS
    
    * Fix lint
---
 src/operator/contrib/bounding_box-common.h | 118 +++++++++++++++
 src/operator/contrib/bounding_box-inl.cuh  | 223 +++++++++++++++++++++++++++++
 src/operator/contrib/bounding_box-inl.h    | 115 ++++-----------
 3 files changed, 368 insertions(+), 88 deletions(-)

diff --git a/src/operator/contrib/bounding_box-common.h b/src/operator/contrib/bounding_box-common.h
new file mode 100644
index 0000000..70215ab
--- /dev/null
+++ b/src/operator/contrib/bounding_box-common.h
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bounding_box-common.h
+ * \brief bounding box util functions and operators commonly used by CPU and GPU implementations
+ * \author Joshua Zhang
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_
+#define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+
+namespace mxnet {
+namespace op {
+namespace box_common_enum {
+enum BoxType {kCorner, kCenter};  // 0: corners (x1,y1,x2,y2), 1: center (cx,cy,w,h)
+}  // namespace box_common_enum
+
+// 1-D overlap of boxes a and b on one axis (coords [0], [2]); pass a+1, b+1 for the other axis
+template<typename DType>
+MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
+  DType a1 = a[0];
+  DType a2 = a[2];
+  DType b1 = b[0];
+  DType b2 = b[2];
+  DType w;
+  if (box_common_enum::kCorner == encode) {  // [0], [2]: low and high edge
+    DType left = a1 > b1 ? a1 : b1;
+    DType right = a2 < b2 ? a2 : b2;
+    w = right - left;
+  } else {  // any other encode (kCenter): [0] = center, [2] = full extent
+    DType aw = a2 / 2;
+    DType bw = b2 / 2;
+    DType al = a1 - aw;
+    DType ar = a1 + aw;
+    DType bl = b1 - bw;
+    DType br = b1 + bw;
+    DType left = bl > al ? bl : al;
+    DType right = br < ar ? br : ar;
+    w = right - left;
+  }
+  return w > 0 ? w : DType(0);  // disjoint intervals overlap by 0
+}
+
+/*!
+ * \brief One NMS step: thread i compares position ref + 1 + i % k of batch i / k with ref.
+ *
+ * \param i the launched thread index (the CPU caller launches num_batch * k threads)
+ * \param index sorted index in descending order; suppressed entries are set to -1
+ * \param batch_start maps batch b, position j to index[batch_start[b] + j]
+ * \param input the input of nms op, `stride` values per box
+ * \param areas pre-computed box areas, indexed by the values stored in index
+ * \param k number of positions after ref checked per batch (topk - ref - 1 on CPU)
+ * \param ref compare reference position, relative to batch_start[b]
+ * \param num number of input boxes in each batch (not used by this kernel)
+ * \param stride input stride, usually 6 (id-score-x1-y1-x2-y2)
+ * \param offset_box box offset, usually 2
+ * \param offset_id class id offset, usually 0; the class check is skipped when < 0
+ * \param thresh nms threshold; pos is suppressed when IoU(ref, pos) > thresh
+ * \param force force suppress regardless of class id
+ * \param encode box encoding type, corner(0) or center(1)
+ * \tparam DType the data type
+ */
+struct nms_impl {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start,
+                                  const DType *input, const DType *areas,
+                                  int k, int ref, int num,
+                                  int stride, int offset_box, int offset_id,
+                                  float thresh, bool force, int encode) {
+    int b = i / k;  // batch
+    int pos = i % k + ref + 1;  // position
+    ref = static_cast<int>(batch_start[b]) + ref;
+    pos = static_cast<int>(batch_start[b]) + pos;
+    if (ref >= static_cast<int>(batch_start[b + 1])) return;
+    if (pos >= static_cast<int>(batch_start[b + 1])) return;
+    if (index[ref] < 0) return;  // reference has been suppressed
+    if (index[pos] < 0) return;  // self been suppressed
+    int ref_offset = static_cast<int>(index[ref]) * stride + offset_box;
+    int pos_offset = static_cast<int>(index[pos]) * stride + offset_box;
+    if (!force && offset_id >=0) {  // class-aware: different ids never suppress
+      int ref_id = static_cast<int>(input[ref_offset - offset_box + offset_id]);
+      int pos_id = static_cast<int>(input[pos_offset - offset_box + offset_id]);
+      if (ref_id != pos_id) return;  // different class
+    }
+    DType intersect = Intersect(input + ref_offset, input + pos_offset, encode);
+    intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode);
+    int ref_area_offset = static_cast<int>(index[ref]);
+    int pos_area_offset = static_cast<int>(index[pos]);
+    DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect);
+    if (iou > thresh) {
+      index[pos] = -1;  // mark pos as suppressed
+    }
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_
diff --git a/src/operator/contrib/bounding_box-inl.cuh b/src/operator/contrib/bounding_box-inl.cuh
index fd5e30b..4b7cf34 100644
--- a/src/operator/contrib/bounding_box-inl.cuh
+++ b/src/operator/contrib/bounding_box-inl.cuh
@@ -24,12 +24,15 @@
 */
 #ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_CUH_
 #define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_CUH_
+#include <cmath>
+#include <cstdio>
 #include <mxnet/operator_util.h>
 #include <thrust/copy.h>
 #include <thrust/execution_policy.h>
 #include "../mshadow_op.h"
 #include "../mxnet_op.h"
 #include "../operator_common.h"
+#include "./bounding_box-common.h"
 
 namespace mxnet {
 namespace op {
@@ -57,6 +60,226 @@ int FilterScores(mshadow::Tensor<gpu, 1, DType> out_scores,
   return end_scores - out_scores.dptr_;
 }
 
+// Same result as Intersect(a, b, encode), with b's coords passed by value: b1 = b[0], b2 = b[2]
+template<typename DType>
+MSHADOW_XINLINE DType Intersect2(const DType *a, const DType b1, const DType b2, int encode) {
+  const DType a1 = a[0];
+  const DType a2 = a[2];
+  DType left, right;
+  if (box_common_enum::kCorner == encode) {  // [0], [2]: low and high edge
+    left = a1 > b1 ? a1 : b1;
+    right = a2 < b2 ? a2 : b2;
+  } else {  // any other encode (kCenter): [0] = center, [2] = full extent
+    const DType aw = a2 / 2;
+    const DType bw = b2 / 2;
+    const DType al = a1 - aw;
+    const DType ar = a1 + aw;
+    const DType bl = b1 - bw;
+    const DType br = b1 + bw;
+    left = bl > al ? bl : al;
+    right = br < ar ? br : ar;
+  }
+  const DType w = right - left;
+  return w > 0 ? w : DType(0);  // disjoint intervals overlap by 0
+}
+
+template<typename DType, int N, bool check_class>  // NMS within one window of N*512 positions
+__launch_bounds__(512)
+__global__ void nms_apply_kernel(const int topk, int32_t *index,  // one block per batch
+                                 const int32_t *batch_start,
+                                 const DType *input,
+                                 const DType *areas,
+                                 const int num, const int stride,  // num: unused here
+                                 const int offset_box, const int offset_id,
+                                 const float thresh, const bool force,  // force: unused here
+                                 const int encode, const int start_offset) {  // window start
+  constexpr int block_size = 512;
+  const int start = static_cast<int>(batch_start[blockIdx.x]) + start_offset;
+  const int size_of_batch = static_cast<int>(batch_start[blockIdx.x + 1]) - start;
+  const int end = min(min(size_of_batch, topk - start_offset), N * block_size);  // window size
+  __shared__ int s_index[N * block_size];  // window of index; -1 = suppressed
+
+  for (int i = threadIdx.x; i < end; i += block_size) {  // load the window
+    s_index[i] = static_cast<int>(index[start + i]);
+  }
+
+  __syncthreads();  // window loaded before any ref is read
+  for (int ref = 0; ref < end; ++ref) {  // refs in index order
+    const int ref_area_offset = static_cast<int>(s_index[ref]);
+    if (ref_area_offset >= 0) {  // same value in all threads, so the sync below is uniform
+      const int ref_offset = ref_area_offset * stride + offset_box;
+      int ref_id = 0;
+      if (check_class) {
+        ref_id = static_cast<int>(input[ref_offset - offset_box + offset_id]);
+      }
+      for (int i = 0; i < N; ++i) {  // each thread owns N strided slots
+        const int my_pos = threadIdx.x + i * block_size;
+        if (my_pos > ref && my_pos < end && s_index[my_pos] >= 0) {  // later, live boxes only
+          const int pos_area_offset = static_cast<int>(s_index[my_pos]);
+          const int pos_offset = pos_area_offset * stride + offset_box;
+          if (check_class) {
+            const int pos_id = static_cast<int>(input[pos_offset - offset_box + offset_id]);
+            if (ref_id != pos_id) continue;  // different class
+          }
+          DType intersect = Intersect(input + ref_offset, input + pos_offset, encode);
+          intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode);
+          const DType iou = intersect /
+                            (areas[ref_area_offset] + areas[pos_area_offset] - intersect);
+          if (iou > thresh) {
+            s_index[my_pos] = -1;  // only the owning thread writes this slot
+          }
+        }
+      }
+      __syncthreads();  // publish suppressions before the next ref
+    }
+  }
+
+  for (int i = threadIdx.x; i < end; i += block_size) {  // write the window back
+    index[start + i] = s_index[i];
+  }
+}
+
+template<typename DType, int N, bool check_class>  // boxes past the window vs. its refs
+__launch_bounds__(512)
+__global__ void nms_apply_kernel_rest(const int topk, int32_t *index,  // one thread per box
+                                 const int32_t *batch_start,
+                                 const DType *input,
+                                 const DType *areas,
+                                 const int num, const int stride,  // num: unused here
+                                 const int offset_box, const int offset_id,
+                                 const float thresh, const bool force,  // force: unused here
+                                 const int encode, const int start_offset,  // window start
+                                 const int blocks_per_batch) {  // grid blocks per batch
+  constexpr int block_size = 512;
+  const int batch = blockIdx.x / blocks_per_batch;  // batch handled by this block
+  const int start_ref = static_cast<int>(batch_start[batch]) + start_offset;
+  const int block_offset = (N + blockIdx.x % blocks_per_batch) * block_size;  // past window
+  const int start = start_ref + block_offset;
+  // end <= 0 if this batch has no boxes at/after start; blocks_per_batch is batch-independent
+  const int size_of_batch = static_cast<int>(batch_start[batch + 1]) - start;
+  const int end = min(size_of_batch, topk - start_offset - block_offset);
+  const int my_pos = start + threadIdx.x;
+  if (static_cast<int>(threadIdx.x) < end && index[my_pos] >= 0) {  // signed: end may be < 0
+    const int pos_area_offset = static_cast<int>(index[my_pos]);
+    const int pos_offset = pos_area_offset * stride + offset_box;
+    DType my_box[4];  // this box's 4 coordinates, read once
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+      my_box[i] = input[pos_offset + i];
+    }
+    const DType my_area = areas[pos_area_offset];
+    int pos_id = 0;
+    if (check_class) {
+      pos_id = static_cast<int>(input[pos_offset - offset_box + offset_id]);
+    }
+
+    for (int ref = start_ref; ref < start_ref + N * block_size; ++ref) {  // all window refs
+      const int ref_area_offset = static_cast<int>(index[ref]);
+      if (ref_area_offset >= 0) {  // skip suppressed refs
+        const int ref_offset = ref_area_offset * stride + offset_box;
+        int ref_id = 0;
+        if (check_class) {
+          ref_id = static_cast<int>(input[ref_offset - offset_box + offset_id]);
+          if (ref_id != pos_id) continue;  // different class
+        }
+        DType intersect = Intersect2(input + ref_offset, my_box[0], my_box[2], encode);
+        intersect *= Intersect2(input + ref_offset + 1, my_box[1], my_box[3], encode);
+        const DType iou = intersect /
+          (areas[ref_area_offset] + my_area - intersect);
+        if (iou > thresh) {
+          index[my_pos] = -1;
+          break;  // suppressed; no need to check further refs
+        }
+      }
+    }
+  }
+}
+
+template<typename DType>  // GPU NMS: windows of THRESHOLD refs, up to two kernels each
+void NMSApply(mshadow::Stream<gpu> *s,
+              int num_batch, int topk,
+              mshadow::Tensor<gpu, 1, int32_t>* sorted_index,
+              mshadow::Tensor<gpu, 1, int32_t>* batch_start,
+              mshadow::Tensor<gpu, 3, DType>* buffer,
+              mshadow::Tensor<gpu, 1, DType>* areas,
+              int num_elem, int width_elem,
+              int coord_start, int id_index,
+              float threshold, bool force_suppress,
+              int in_format) {
+  using namespace mxnet_op;
+  constexpr int THRESHOLD = 1024;  // refs per window; N = THRESHOLD / block_size
+  for (int ref = 0; ref < topk; ref += THRESHOLD) {  // ref: window start
+    constexpr int block_size = 512;
+    constexpr int N = THRESHOLD / block_size;
+    auto stream = mshadow::Stream<gpu>::GetStream(s);  // one stream keeps launches ordered
+    if (!force_suppress && id_index >= 0) {  // class-aware: check_class = true
+      nms_apply_kernel<DType, N, true><<<num_batch, block_size, 0, stream>>>(topk,
+                                                                      sorted_index->dptr_,
+                                                                      batch_start->dptr_,
+                                                                      buffer->dptr_,
+                                                                      areas->dptr_,
+                                                                      num_elem,
+                                                                      width_elem,
+                                                                      coord_start,
+                                                                      id_index,
+                                                                      threshold,
+                                                                      force_suppress,
+                                                                      in_format,
+                                                                      ref);
+      int blocks_per_batch = (topk - ref - THRESHOLD + block_size - 1)/block_size;
+      int blocks = blocks_per_batch  * num_batch;
+      if (blocks > 0) {  // positions remain past the window
+        nms_apply_kernel_rest<DType, N, true><<<blocks, block_size, 0, stream>>>(topk,
+                                                                        sorted_index->dptr_,
+                                                                        batch_start->dptr_,
+                                                                        buffer->dptr_,
+                                                                        areas->dptr_,
+                                                                        num_elem,
+                                                                        width_elem,
+                                                                        coord_start,
+                                                                        id_index,
+                                                                        threshold,
+                                                                        force_suppress,
+                                                                        in_format,
+                                                                        ref,
+                                                                        blocks_per_batch);
+      }
+    } else {  // class-agnostic: same launches as above with check_class = false
+      nms_apply_kernel<DType, N, false><<<num_batch, block_size, 0, stream>>>(topk,
+                                                                       sorted_index->dptr_,
+                                                                       batch_start->dptr_,
+                                                                       buffer->dptr_,
+                                                                       areas->dptr_,
+                                                                       num_elem,
+                                                                       width_elem,
+                                                                       coord_start,
+                                                                       id_index,
+                                                                       threshold,
+                                                                       force_suppress,
+                                                                       in_format,
+                                                                       ref);
+      int blocks_per_batch = (topk - ref - THRESHOLD + block_size - 1)/block_size;
+      int blocks = blocks_per_batch  * num_batch;
+      if (blocks > 0) {  // positions remain past the window
+        nms_apply_kernel_rest<DType, N, false><<<blocks, block_size, 0, stream>>>(topk,
+                                                                        sorted_index->dptr_,
+                                                                        batch_start->dptr_,
+                                                                        buffer->dptr_,
+                                                                        areas->dptr_,
+                                                                        num_elem,
+                                                                        width_elem,
+                                                                        coord_start,
+                                                                        id_index,
+                                                                        threshold,
+                                                                        force_suppress,
+                                                                        in_format,
+                                                                        ref,
+                                                                        blocks_per_batch);
+      }
+    }
+  }  // NOTE(review): the kernel launches above are not followed by an error check
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h
index 650e58d..35ab19d 100644
--- a/src/operator/contrib/bounding_box-inl.h
+++ b/src/operator/contrib/bounding_box-inl.h
@@ -34,12 +34,10 @@
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../tensor/sort_op.h"
+#include "./bounding_box-common.h"
 
 namespace mxnet {
 namespace op {
-namespace box_common_enum {
-enum BoxType {kCorner, kCenter};
-}
 namespace box_nms_enum {
 enum BoxNMSOpInputs {kData};
 enum BoxNMSOpOutputs {kOut, kTemp};
@@ -254,85 +252,32 @@ struct compute_area {
   }
 };
 
-// compute line intersect along either height or width
 template<typename DType>
-MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
-  DType a1 = a[0];
-  DType a2 = a[2];
-  DType b1 = b[0];
-  DType b2 = b[2];
-  DType w;
-  if (box_common_enum::kCorner == encode) {
-    DType left = a1 > b1 ? a1 : b1;
-    DType right = a2 < b2 ? a2 : b2;
-    w = right - left;
-  } else {
-    DType aw = a2 / 2;
-    DType bw = b2 / 2;
-    DType al = a1 - aw;
-    DType ar = a1 + aw;
-    DType bl = b1 - bw;
-    DType br = b1 + bw;
-    DType left = bl > al ? bl : al;
-    DType right = br < ar ? br : ar;
-    w = right - left;
+void NMSApply(mshadow::Stream<cpu> *s,  // CPU path: at most one nms_impl launch per ref
+              int num_batch, int topk,
+              mshadow::Tensor<cpu, 1, int32_t>* sorted_index,
+              mshadow::Tensor<cpu, 1, int32_t>* batch_start,
+              mshadow::Tensor<cpu, 3, DType>* buffer,
+              mshadow::Tensor<cpu, 1, DType>* areas,
+              int num_elem, int width_elem,
+              int coord_start, int id_index,
+              float threshold, bool force_suppress,
+              int in_format) {
+  using namespace mxnet_op;
+  // go through each box as reference, suppress if overlap > threshold
+  // sorted_index with -1 is marked as suppressed
+  for (int ref = 0; ref < topk; ++ref) {
+    int num_worker = topk - ref - 1;
+    if (num_worker < 1) continue;  // no positions after the last ref
+    Kernel<nms_impl, cpu>::Launch(s, num_batch * num_worker,  // one thread per (batch, pos)
+      sorted_index->dptr_, batch_start->dptr_, buffer->dptr_, areas->dptr_,
+      num_worker, ref, num_elem,
+      width_elem, coord_start, id_index,
+      threshold, force_suppress, in_format);
   }
-  return w > 0 ? w : DType(0);
 }
 
 /*!
-   * \brief Implementation of the non-maximum suppression operation
-   *
-   * \param i the launched thread index
-   * \param index sorted index in descending order
-   * \param batch_start map (b, k) to compact index by indices[batch_start[b] + k]
-   * \param input the input of nms op
-   * \param areas pre-computed box areas
-   * \param k nms topk number
-   * \param ref compare reference position
-   * \param num number of input boxes in each batch
-   * \param stride input stride, usually 6 (id-score-x1-y1-x2-y2)
-   * \param offset_box box offset, usually 2
-   * \param thresh nms threshold
-   * \param force force suppress regardless of class id
-   * \param offset_id class id offset, used when force == false, usually 0
-   * \param encode box encoding type, corner(0) or center(1)
-   * \param DType the data type
-   */
-struct nms_impl {
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start,
-                                  const DType *input, const DType *areas,
-                                  int k, int ref, int num,
-                                  int stride, int offset_box, int offset_id,
-                                  float thresh, bool force, int encode) {
-    int b = i / k;  // batch
-    int pos = i % k + ref + 1;  // position
-    ref = static_cast<int>(batch_start[b]) + ref;
-    pos = static_cast<int>(batch_start[b]) + pos;
-    if (ref >= static_cast<int>(batch_start[b + 1])) return;
-    if (pos >= static_cast<int>(batch_start[b + 1])) return;
-    if (index[ref] < 0) return;  // reference has been suppressed
-    if (index[pos] < 0) return;  // self been suppressed
-    int ref_offset = static_cast<int>(index[ref]) * stride + offset_box;
-    int pos_offset = static_cast<int>(index[pos]) * stride + offset_box;
-    if (!force && offset_id >=0) {
-      int ref_id = static_cast<int>(input[ref_offset - offset_box + offset_id]);
-      int pos_id = static_cast<int>(input[pos_offset - offset_box + offset_id]);
-      if (ref_id != pos_id) return;  // different class
-    }
-    DType intersect = Intersect(input + ref_offset, input + pos_offset, encode);
-    intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode);
-    int ref_area_offset = static_cast<int>(index[ref]);
-    int pos_area_offset = static_cast<int>(index[pos]);
-    DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect);
-    if (iou > thresh) {
-      index[pos] = -1;
-    }
-  }
-};
-
-/*!
    * \brief Assign output of nms by indexing input
    *
    * \param i the launched thread index (total num_batch)
@@ -502,17 +447,11 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
      topk, num_elem, width_elem, param.in_format);
 
     // apply nms
-    // go through each box as reference, suppress if overlap > threshold
-    // sorted_index with -1 is marked as suppressed
-    for (int ref = 0; ref < topk; ++ref) {
-      int num_worker = topk - ref - 1;
-      if (num_worker < 1) continue;
-      Kernel<nms_impl, xpu>::Launch(s, num_batch * num_worker,
-        sorted_index.dptr_, batch_start.dptr_, buffer.dptr_, areas.dptr_,
-        num_worker, ref, num_elem,
-        width_elem, coord_start, id_index,
-        param.overlap_thresh, param.force_suppress, param.in_format);
-    }
+    mxnet::op::NMSApply(s, num_batch, topk, &sorted_index,
+                        &batch_start, &buffer, &areas,
+                        num_elem, width_elem, coord_start,
+                        id_index, param.overlap_thresh,
+                        param.force_suppress, param.in_format);
 
     // store the results to output, keep a record for backward
     record = -1;