You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/07 23:29:39 UTC

[GitHub] zhreshold closed pull request #11162: Add valid_thresh to contrib.box_nms

zhreshold closed pull request #11162: Add valid_thresh to contrib.box_nms
URL: https://github.com/apache/incubator-mxnet/pull/11162
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/contrib/bounding_box-inl.cuh b/src/operator/contrib/bounding_box-inl.cuh
new file mode 100644
index 00000000000..fb1dacc11f4
--- /dev/null
+++ b/src/operator/contrib/bounding_box-inl.cuh
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bounding_box-inl.cuh
+ * \brief bounding box CUDA operators
+ * \author Joshua Zhang
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_CUH_
+#define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_CUH_
+#include <mxnet/operator_util.h>
+#include <thrust/copy.h>
+#include <thrust/execution_policy.h>
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Thrust predicate that accepts values strictly greater than a threshold.
+ * \tparam DType the score data type
+ */
+template<typename DType>
+struct valid_score {
+  DType thresh;
+  explicit valid_score(DType _thresh) : thresh(_thresh) {}
+  // const-qualified so the functor is also callable through a const object/reference
+  __host__ __device__ bool operator()(const DType x) const {
+    return x > thresh;
+  }
+};
+
+/*!
+ * \brief GPU implementation: keep the (score, index) pairs whose score > valid_thresh and
+ *        pack them at the front of the outputs, preserving their relative order.
+ * \param out_scores output buffer for kept scores (must hold up to scores.MSize() values)
+ * \param out_sorted_index output buffer for the indices paired with the kept scores
+ * \param scores input scores; also the stencil used to filter sorted_index
+ * \param sorted_index input indices, one per score
+ * \param valid_thresh threshold, cast to DType before comparing
+ * \return number of pairs written to each output
+ * \note NOTE(review): thrust::device runs on the default CUDA stream rather than the
+ *       operator's mshadow stream; ordering relies on default-stream semantics (confirm).
+ */
+template<typename DType>
+int FilterScores(mshadow::Tensor<gpu, 1, DType> out_scores,
+                 mshadow::Tensor<gpu, 1, DType> out_sorted_index,
+                 mshadow::Tensor<gpu, 1, DType> scores,
+                 mshadow::Tensor<gpu, 1, DType> sorted_index,
+                 float valid_thresh) {
+  valid_score<DType> pred(static_cast<DType>(valid_thresh));
+  // compact the scores themselves
+  DType * end_scores = thrust::copy_if(thrust::device, scores.dptr_, scores.dptr_ + scores.MSize(),
+                                       out_scores.dptr_, pred);
+  // stencil overload: copy sorted_index[i] iff pred(scores[i]), so both outputs stay aligned
+  thrust::copy_if(thrust::device, sorted_index.dptr_, sorted_index.dptr_ + sorted_index.MSize(),
+                  scores.dptr_, out_sorted_index.dptr_, pred);
+  // pointer difference is ptrdiff_t; narrow explicitly to the int the caller expects
+  return static_cast<int>(end_scores - out_scores.dptr_);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_CUH_
diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h
index 40dbdd81669..f739dbc8a52 100644
--- a/src/operator/contrib/bounding_box-inl.h
+++ b/src/operator/contrib/bounding_box-inl.h
@@ -49,6 +49,7 @@ enum BoxNMSOpResource {kTempSpace};
 
 struct BoxNMSParam : public dmlc::Parameter<BoxNMSParam> {
   float overlap_thresh;
+  float valid_thresh;
   int topk;
   int coord_start;
   int score_index;
@@ -59,6 +60,8 @@ struct BoxNMSParam : public dmlc::Parameter<BoxNMSParam> {
   DMLC_DECLARE_PARAMETER(BoxNMSParam) {
     DMLC_DECLARE_FIELD(overlap_thresh).set_default(0.5)
     .describe("Overlapping(IoU) threshold to suppress object with smaller score.");
+    DMLC_DECLARE_FIELD(valid_thresh).set_default(0)
+    .describe("Filter input boxes to those whose scores greater than valid_thresh.");
     DMLC_DECLARE_FIELD(topk).set_default(-1)
     .describe("Apply nms to topk boxes with descending scores, -1 to no restriction.");
     DMLC_DECLARE_FIELD(coord_start).set_default(2)
@@ -145,6 +148,45 @@ inline uint32_t BoxNMSNumVisibleOutputs(const NodeAttrs& attrs) {
   return static_cast<uint32_t>(1);
 }
 
+/*!
+ * \brief CPU implementation: keep the (score, index) pairs whose score > valid_thresh and
+ *        pack them at the front of the outputs, preserving their relative order.
+ * \param out_scores output buffer for kept scores (must hold up to scores.size(0) values)
+ * \param out_sorted_index output buffer for the indices paired with the kept scores
+ * \param scores input scores
+ * \param sorted_index input indices, one per score
+ * \param valid_thresh threshold, cast to DType before comparing (as the GPU version does)
+ * \return number of pairs written to each output
+ */
+template<typename DType>
+int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
+                 mshadow::Tensor<cpu, 1, DType> out_sorted_index,
+                 mshadow::Tensor<cpu, 1, DType> scores,
+                 mshadow::Tensor<cpu, 1, DType> sorted_index,
+                 float valid_thresh) {
+  // convert once so the comparison happens in DType, matching the GPU predicate
+  const DType thresh = static_cast<DType>(valid_thresh);
+  index_t j = 0;
+  for (index_t i = 0; i < scores.size(0); ++i) {
+    if (scores[i] > thresh) {
+      out_scores[j] = scores[i];
+      out_sorted_index[j] = sorted_index[i];
+      ++j;
+    }
+  }
+  return static_cast<int>(j);
+}
+
+namespace mshadow_op {
+/*! \brief element-wise (a < b) cast to DType (1 or 0); summed to count boxes per batch */
+struct less_than : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a, DType b) {
+    return static_cast<DType>(a < b);
+  }
+};  // struct less_than
+}   // namespace mshadow_op
+
 struct corner_to_center {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType *data, int stride) {
@@ -198,15 +228,28 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {
   }
 }
 
-// compute areas specialized for nms to reduce computation
+/*!
+ * \brief compute areas specialized for nms to reduce computation
+ * 
+ * \param i the launched thread index (total thread num_batch * topk)
+ * \param out 1d array for areas (size num_batch * num_elem)
+ * \param in 1st coordinate of 1st box (buffer + coord_start)
+ * \param indices index to areas and in buffer (sorted_index)
+ * \param batch_start map (b, k) to compact index by indices[batch_start[b] + k]
+ * \param topk number of boxes handled per batch (threads per batch); i = b * topk + k
+ * \param stride should be width_elem (e.g. 6 including cls and scores)
+ * \param encode passed to BoxArea to compute area
+ */
 struct compute_area {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
-                                  const DType *indices, int topk, int num_elem,
-                                  int stride, int encode) {
+                                  const DType *indices, const DType *batch_start,
+                                  int topk, int num_elem, int stride, int encode) {
     int b = i / topk;
     int k = i % topk;
-    int index = static_cast<int>(indices[b * num_elem + k]);
+    int pos = static_cast<int>(batch_start[b]) + k;
+    if (pos >= static_cast<int>(batch_start[b + 1])) return;
+    int index = static_cast<int>(indices[pos]);
     int in_index = index * stride;
     out[index] = BoxArea(in + in_index, encode);
   }
@@ -243,6 +286,7 @@ MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
    *
    * \param i the launched thread index
    * \param index sorted index in descending order
+   * \param batch_start map (b, k) to compact index by indices[batch_start[b] + k]
    * \param input the input of nms op
    * \param areas pre-computed box areas
    * \param k nms topk number
@@ -254,20 +298,25 @@ MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
    * \param force force suppress regardless of class id
    * \param offset_id class id offset, used when force == false, usually 0
    * \param encode box encoding type, corner(0) or center(1)
-   * \tparam DType the data type
+   * \tparam DType the data type
    */
 struct nms_impl {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *index, const DType *input,
-                                  const DType *areas, int k, int ref, int num,
+  MSHADOW_XINLINE static void Map(int i, DType *index, const DType *batch_start,
+                                  const DType *input, const DType *areas,
+                                  int k, int ref, int num,
                                   int stride, int offset_box, int offset_id,
                                   float thresh, bool force, int encode) {
     int b = i / k;  // batch
     int pos = i % k + ref + 1;  // position
-    if (index[b * num + ref] < 0) return;  // reference has been suppressed
-    if (index[b * num + pos] < 0) return;  // self been suppressed
-    int ref_offset = static_cast<int>(index[b * num + ref]) * stride + offset_box;
-    int pos_offset = static_cast<int>(index[b * num + pos]) * stride + offset_box;
+    ref = static_cast<int>(batch_start[b]) + ref;
+    pos = static_cast<int>(batch_start[b]) + pos;
+    if (ref >= static_cast<int>(batch_start[b + 1])) return;
+    if (pos >= static_cast<int>(batch_start[b + 1])) return;
+    if (index[ref] < 0) return;  // reference has been suppressed
+    if (index[pos] < 0) return;  // self been suppressed
+    int ref_offset = static_cast<int>(index[ref]) * stride + offset_box;
+    int pos_offset = static_cast<int>(index[pos]) * stride + offset_box;
     if (!force && offset_id >=0) {
       int ref_id = static_cast<int>(input[ref_offset - offset_box + offset_id]);
       int pos_id = static_cast<int>(input[pos_offset - offset_box + offset_id]);
@@ -275,23 +324,38 @@ struct nms_impl {
     }
     DType intersect = Intersect(input + ref_offset, input + pos_offset, encode);
     intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode);
-    int ref_area_offset = static_cast<int>(index[b * num + ref]);
-    int pos_area_offset = static_cast<int>(index[b * num + pos]);
+    int ref_area_offset = static_cast<int>(index[ref]);
+    int pos_area_offset = static_cast<int>(index[pos]);
     DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] -
       intersect);
     if (iou > thresh) {
-      index[b * num + pos] = -1;
+      index[pos] = -1;
     }
   }
 };
 
+/*!
+   * \brief Assign output of nms by indexing input
+   * 
+   * \param i the launched thread index (total num_batch)
+   * \param out output array [cls, conf, b0, b1, b2, b3]
+   * \param record book keeping the selected index for backward
+   * \param index compact sorted_index, use batch_start to access
+   * \param batch_start map(b, k) to compact index by index[batch_start[b] + k]
+   * \param k nms topk number
+   * \param num number of input boxes in each batch
+   * \param stride input stride, usually 6 (id-score-x1-y1-x2-y2)
+   */
 struct nms_assign {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType *out, DType *record, const DType *input,
-                                  const DType *index, int k, int num, int stride) {
+                                  const DType *index, const DType *batch_start,
+                                  int k, int num, int stride) {
     int count = 0;
     for (int j = 0; j < k; ++j) {
-      int location = static_cast<int>(index[i * num + j]);
+      int pos = static_cast<int>(batch_start[i]) + j;
+      if (pos >= static_cast<int>(batch_start[i + 1])) return;
+      int location = static_cast<int>(index[pos]);
       if (location >= 0) {
         // copy to output
         int out_location = (i * num + count) * stride;
@@ -352,6 +416,8 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
     Shape<1> sort_index_shape = Shape1(num_batch * num_elem);
     Shape<3> buffer_shape = Shape3(num_batch, num_elem, width_elem);
     index_t workspace_size = 4 * sort_index_shape.Size();
+    Shape<1> batch_start_shape = Shape1(num_batch + 1);
+    workspace_size += batch_start_shape.Size();
     if (req[0] == kWriteInplace) {
       workspace_size += buffer_shape.Size();
     }
@@ -363,10 +429,11 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, DType> batch_id(scores.dptr_ + scores.MSize(), sort_index_shape,
       s);
     Tensor<xpu, 1, DType> areas(batch_id.dptr_ + batch_id.MSize(), sort_index_shape, s);
+    Tensor<xpu, 1, DType> batch_start(areas.dptr_ + areas.MSize(), batch_start_shape, s);
     Tensor<xpu, 3, DType> buffer = data;
     if (req[0] == kWriteInplace) {
       // make copy
-      buffer = Tensor<xpu, 3, DType>(areas.dptr_ + areas.MSize(), buffer_shape, s);
+      buffer = Tensor<xpu, 3, DType>(batch_start.dptr_ + batch_start.MSize(), buffer_shape, s);
       buffer = F<mshadow_op::identity>(data);
     }
 
@@ -382,19 +449,51 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
       record = reshape(range<DType>(0, num_batch * num_elem), record.shape_);
       return;
     }
-    scores = reshape(slice<2>(buffer, score_index, score_index + 1), scores.shape_);
-    sorted_index = range<DType>(0, num_batch * num_elem);
-    mxnet::op::SortByKey(scores, sorted_index, false);
-    batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(num_elem));
-    mxnet::op::SortByKey(batch_id, scores, true);
-    batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(num_elem));
-    mxnet::op::SortByKey(batch_id, sorted_index, true);
+
+    // use batch_id and areas as temporary storage
+    Tensor<xpu, 1, DType> all_scores = batch_id;
+    Tensor<xpu, 1, DType> all_sorted_index = areas;
+    all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_);
+    all_sorted_index = range<DType>(0, num_batch * num_elem);
+
+    // filter scores but keep original sorted_index value
+    // move valid score and index to the front, return valid size
+    int num_valid = mxnet::op::FilterScores(scores, sorted_index, all_scores, all_sorted_index,
+                                            param.valid_thresh);
+    // if everything is filtered, output -1
+    if (num_valid == 0) {
+      record = -1;
+      out = -1;
+      return;
+    }
+    // mark the invalid boxes before nms
+    if (num_valid < num_batch * num_elem) {
+      slice<0>(sorted_index, num_valid, num_batch * num_elem) = -1;
+    }
+
+    // only sort the valid scores and batch_id
+    Shape<1> valid_score_shape = Shape1(num_valid);
+    Tensor<xpu, 1, DType> valid_scores(scores.dptr_, valid_score_shape, s);
+    Tensor<xpu, 1, DType> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
+    Tensor<xpu, 1, DType> valid_batch_id(batch_id.dptr_, valid_score_shape, s);
+
+    // sort index by batch_id then score (stable sort)
+    mxnet::op::SortByKey(valid_scores, valid_sorted_index, false);
+    valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
+    mxnet::op::SortByKey(valid_batch_id, valid_sorted_index, true);
+
+    // calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
+    valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
+    for (int b = 0; b < num_batch + 1; b++) {
+      slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
+        F<mshadow_op::less_than>(valid_batch_id, ScalarExp<DType>(b)), 0);
+    }
 
     // pre-compute areas of candidates
     areas = 0;
-    Kernel<compute_area, xpu>::Launch(s, num_batch * topk, areas.dptr_,
-     buffer.dptr_ + coord_start, sorted_index.dptr_, topk, num_elem, width_elem,
-     param.in_format);
+    Kernel<compute_area, xpu>::Launch(s, num_batch * topk,
+     areas.dptr_, buffer.dptr_ + coord_start, sorted_index.dptr_, batch_start.dptr_,
+     topk, num_elem, width_elem, param.in_format);
 
     // apply nms
     // go through each box as reference, suppress if overlap > threshold
@@ -402,16 +501,19 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
     for (int ref = 0; ref < topk; ++ref) {
       int num_worker = topk - ref - 1;
       if (num_worker < 1) continue;
-      Kernel<nms_impl, xpu>::Launch(s, num_batch * num_worker, sorted_index.dptr_,
-        buffer.dptr_, areas.dptr_, num_worker, ref, num_elem, width_elem,
-        coord_start, id_index, param.overlap_thresh, param.force_suppress, param.in_format);
+      Kernel<nms_impl, xpu>::Launch(s, num_batch * num_worker,
+        sorted_index.dptr_, batch_start.dptr_, buffer.dptr_, areas.dptr_,
+        num_worker, ref, num_elem,
+        width_elem, coord_start, id_index,
+        param.overlap_thresh, param.force_suppress, param.in_format);
     }
 
     // store the results to output, keep a record for backward
     record = -1;
     out = -1;
-    Kernel<nms_assign, xpu>::Launch(s, num_batch, out.dptr_, record.dptr_,
-      buffer.dptr_, sorted_index.dptr_, topk, num_elem, width_elem);
+    Kernel<nms_assign, xpu>::Launch(s, num_batch,
+      out.dptr_, record.dptr_, buffer.dptr_, sorted_index.dptr_, batch_start.dptr_,
+      topk, num_elem, width_elem);
 
     // convert encoding
     if (param.in_format != param.out_format) {
diff --git a/src/operator/contrib/bounding_box.cu b/src/operator/contrib/bounding_box.cu
index 6662d932700..2677d2f7947 100644
--- a/src/operator/contrib/bounding_box.cu
+++ b/src/operator/contrib/bounding_box.cu
@@ -24,6 +24,7 @@
   * \author Joshua Zhang
   */
 
+#include "./bounding_box-inl.cuh"
 #include "./bounding_box-inl.h"
 #include "../elemwise_op_common.h"
 
diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py
index 5618e11a040..a220f08d20d 100644
--- a/tests/python/unittest/test_contrib_operator.py
+++ b/tests/python/unittest/test_contrib_operator.py
@@ -26,20 +26,20 @@
 import unittest
 
 def test_box_nms_op():
-    def test_box_nms_forward(data, expected, thresh=0.5, topk=-1, coord=2, score=1, cid=0,
+    def test_box_nms_forward(data, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1, cid=0,
                          force=False, in_format='corner', out_format='corner'):
         data = mx.nd.array(data)
-        out = mx.contrib.nd.box_nms(data, overlap_thresh=thresh, topk=topk,
+        out = mx.contrib.nd.box_nms(data, overlap_thresh=thresh, valid_thresh=valid, topk=topk,
                                 coord_start=coord, score_index=score, id_index=cid,
                                 force_suppress=force, in_format=in_format, out_format=out_format)
         assert_almost_equal(out.asnumpy(), expected)
 
-    def test_box_nms_backward(data, grad, expected, thresh=0.5, topk=-1, coord=2, score=1,
+    def test_box_nms_backward(data, grad, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1,
                           cid=0, force=False, in_format='corner', out_format='corner'):
         in_var = mx.sym.Variable('data')
         arr_data = mx.nd.array(data)
         arr_grad = mx.nd.empty(arr_data.shape)
-        op = mx.contrib.sym.box_nms(in_var, overlap_thresh=thresh, topk=topk,
+        op = mx.contrib.sym.box_nms(in_var, overlap_thresh=thresh, valid_thresh=valid, topk=topk,
                                 coord_start=coord, score_index=score, id_index=cid,
                                 force_suppress=force, in_format=in_format, out_format=out_format)
         exe = op.bind(ctx=default_context(), args=[arr_data], args_grad=[arr_grad])
@@ -158,6 +158,23 @@ def swap_position(data, expected, coord=2, score=1, cid=0, new_col=0):
     thresh = 0.5
     test_box_nms_forward(np.array(boxes), np.array(expected), force=force, thresh=thresh, cid=-1)
 
+    # case8: multi-batch thresh + topk
+    boxes8 = [[[1, 1, 0, 0, 10, 10], [1, 0.4, 0, 0, 10, 10], [1, 0.3, 0, 0, 10, 10]],
+              [[2, 1, 0, 0, 10, 10], [2, 0.4, 0, 0, 10, 10], [2, 0.3, 0, 0, 10, 10]],
+              [[3, 1, 0, 0, 10, 10], [3, 0.4, 0, 0, 10, 10], [3, 0.3, 0, 0, 10, 10]]]
+    expected8 = [[[1, 1, 0, 0, 10, 10], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]],
+                 [[2, 1, 0, 0, 10, 10], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]],
+                 [[3, 1, 0, 0, 10, 10], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]
+    grad8 = np.random.rand(3, 3, 6)
+    expected_in_grad8 = np.zeros((3, 3, 6))
+    expected_in_grad8[(0, 1, 2), (0, 0, 0), :] = grad8[(0, 1, 2), (0, 0, 0), :]
+    force = False
+    thresh = 0.5
+    valid = 0.5
+    topk = 2
+    test_box_nms_forward(np.array(boxes8), np.array(expected8), force=force, thresh=thresh, valid=valid, topk=topk)
+    test_box_nms_backward(np.array(boxes8), grad8, expected_in_grad8, force=force, thresh=thresh, valid=valid, topk=topk)
+
 def test_box_iou_op():
     def numpy_box_iou(a, b, fmt='corner'):
         def area(left, top, right, bottom):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services