You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/29 22:23:41 UTC

[incubator-mxnet] branch master updated: [MXNET-517] add sample ratio for ROI Align (#11145)

This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e892301  [MXNET-517] add sample ratio for ROI Align (#11145)
e892301 is described below

commit e8923011523f900b1f7e9f180feecb89c1a1d6e1
Author: Hang Zhang <80...@users.noreply.github.com>
AuthorDate: Fri Jun 29 16:23:33 2018 -0600

    [MXNET-517] add sample ratio for ROI Align (#11145)
    
    * add sample ratio
    
    * pylint
    
    * increase size limit for bilinearup
    
    * add test case
    
    * fix typo
    
    * rm comments and cpu back
---
 src/operator/contrib/bilinear_resize-inl.h |  4 ++--
 src/operator/contrib/roi_align-inl.h       |  3 +++
 src/operator/contrib/roi_align.cc          |  6 +++---
 src/operator/contrib/roi_align.cu          | 21 ++-------------------
 tests/python/unittest/test_operator.py     | 16 +++++++++-------
 5 files changed, 19 insertions(+), 31 deletions(-)

diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h
index b73ead9..c096f01 100644
--- a/src/operator/contrib/bilinear_resize-inl.h
+++ b/src/operator/contrib/bilinear_resize-inl.h
@@ -51,9 +51,9 @@ struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
   int height;
   int width;
   DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
-    DMLC_DECLARE_FIELD(height).set_range(1, 1000)
+    DMLC_DECLARE_FIELD(height).set_range(1, 10000)
     .describe("output height (required)");
-    DMLC_DECLARE_FIELD(width).set_range(1, 1000)
+    DMLC_DECLARE_FIELD(width).set_range(1, 10000)
     .describe("output width (required)");
   }
 };
diff --git a/src/operator/contrib/roi_align-inl.h b/src/operator/contrib/roi_align-inl.h
index 5ac420c..263f72a 100644
--- a/src/operator/contrib/roi_align-inl.h
+++ b/src/operator/contrib/roi_align-inl.h
@@ -47,6 +47,7 @@ enum ROIAlignOpOutputs {kOut};
 struct ROIAlignParam : public dmlc::Parameter<ROIAlignParam> {
   TShape pooled_size;
   float spatial_scale;
+  int sample_ratio;
   DMLC_DECLARE_PARAMETER(ROIAlignParam) {
     DMLC_DECLARE_FIELD(pooled_size)
     .set_expect_ndim(2).enforce_nonzero()
@@ -54,6 +55,8 @@ struct ROIAlignParam : public dmlc::Parameter<ROIAlignParam> {
     DMLC_DECLARE_FIELD(spatial_scale).set_range(0.0, 1.0)
     .describe("Ratio of input feature map height (or w) to raw image height (or w). "
     "Equals the reciprocal of total stride in convolutional layers");
+    DMLC_DECLARE_FIELD(sample_ratio).set_default(-1)
+    .describe("Optional sampling ratio of ROI align, using adaptive size by default.");
   }
 };
 
diff --git a/src/operator/contrib/roi_align.cc b/src/operator/contrib/roi_align.cc
index c2cb929..2261127 100644
--- a/src/operator/contrib/roi_align.cc
+++ b/src/operator/contrib/roi_align.cc
@@ -440,8 +440,8 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
     DType *top_data = out_data[roialign::kOut].dptr<DType>();
 
     ROIAlignForward<DType>(count, bottom_data, param.spatial_scale, channels,
-                           height, width, pooled_height, pooled_width, -1, bottom_rois,
-                           rois_cols, top_data);
+                           height, width, pooled_height, pooled_width, param.sample_ratio,
+                           bottom_rois, rois_cols, top_data);
   })
 }
 
@@ -490,7 +490,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
       }
       ROIAlignBackward<DType>(count, top_diff, num_rois, param.spatial_scale,
                      channels, height, width, pooled_height, pooled_width,
-                     -1, grad_in, bottom_rois, rois_cols);
+                     param.sample_ratio, grad_in, bottom_rois, rois_cols);
     }
     if (kWriteTo == req[roialign::kBox]) {
       Fill<false>(s, outputs[1], kWriteTo, static_cast<DType>(0));
diff --git a/src/operator/contrib/roi_align.cu b/src/operator/contrib/roi_align.cu
index 21066ea..d3db70b 100644
--- a/src/operator/contrib/roi_align.cu
+++ b/src/operator/contrib/roi_align.cu
@@ -231,13 +231,6 @@ __device__ void bilinear_interpolate_gradient(
   T lx = x - *x_low;
   T hy = 1. - ly, hx = 1. - lx;
 
-  // reference in forward
-  // T v1 = bottom_data[*y_low * width + *x_low];
-  // T v2 = bottom_data[*y_low * width + *x_high];
-  // T v3 = bottom_data[*y_high * width + *x_low];
-  // T v4 = bottom_data[*y_high * width + *x_high];
-  // T val = (w1 * v1 + *w2 * v2 + *w3 * v3 + *w4 * v4);
-
   *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
 
   return;
@@ -341,16 +334,6 @@ __global__ void RoIAlignBackwardKernel(
               offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
           atomicAdd(
               offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
-          /*
-          gpu_atomic_add(
-              static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
-          gpu_atomic_add(
-              static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
-          gpu_atomic_add(
-              static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
-          gpu_atomic_add(
-              static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
-          */
         }  // if
       }  // ix
     }  // iy
@@ -399,7 +382,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
           width,
           pooled_height,
           pooled_width,
-          -1,
+          param.sample_ratio,
           bottom_rois,
           top_data);
   })
@@ -467,7 +450,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
         width,
         pooled_height,
         pooled_width,
-        -1,
+        param.sample_ratio,
         grad_in,
         bottom_rois);
   })
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3deb1b9..0fa31de 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6376,7 +6376,7 @@ def test_op_roi_align():
                         out[r, c, ph, pw] = val * 1.0 / count
         return out, [dx, drois]
 
-    def test_roi_align_value():
+    def test_roi_align_value(sampling_ratio=0):
         ctx=default_context()
         dtype = np.float32
 
@@ -6387,7 +6387,6 @@ def test_op_roi_align():
         pooled_size = (3, 4)
 
         spatial_scale = H * 1.0 / dlen
-        sampling_ratio = 0
         data = mx.nd.array(np.arange(N*C*W*H).reshape((N,C,H,W)), ctx=ctx, dtype = dtype)
         # data = mx.nd.random.uniform(0, 1, (N, C, H, W), dtype = dtype)
         center_xy = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype = dtype)
@@ -6400,21 +6399,23 @@ def test_op_roi_align():
         rois.attach_grad()
         with mx.autograd.record():
             output = mx.nd.contrib.ROIAlign(data, rois, pooled_size=pooled_size,
-                    spatial_scale=spatial_scale)
+                    spatial_scale=spatial_scale, sample_ratio=sampling_ratio)
         dy = mx.nd.random.uniform(-1, 1, (R, C) + pooled_size, ctx=ctx, dtype = dtype)
         output.backward(dy)
-        real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size, spatial_scale, sampling_ratio, dy.asnumpy())
+        real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size,
+                                                             spatial_scale, sampling_ratio, dy.asnumpy())
         assert np.allclose(output.asnumpy(), real_output)
         # It seems that the precision between Cfloat and Pyfloat is different.
         assert np.allclose(data.grad.asnumpy(), dx, atol = 1e-5), np.abs(data.grad.asnumpy() - dx).max()
         assert np.allclose(rois.grad.asnumpy(), drois)
 
     # modified from test_roipooling()
-    def test_roi_align_autograd():
-        ctx=default_context()
+    def test_roi_align_autograd(sampling_ratio=0):
+        ctx = default_context()
         data = mx.symbol.Variable(name='data')
         rois = mx.symbol.Variable(name='rois')
-        test = mx.symbol.contrib.ROIAlign(data=data, rois=rois, pooled_size=(4, 4), spatial_scale=1)
+        test = mx.symbol.contrib.ROIAlign(data=data, rois=rois, pooled_size=(4, 4), spatial_scale=1,
+                                          sample_ratio=sampling_ratio)
 
         x1 = np.random.rand(4, 1, 12, 12).astype('float64')
         x2 = np.array([[0, 1.1, 1.1, 6.2, 6.2], [2, 6.1, 2.1, 8.2, 11.2],
@@ -6428,6 +6429,7 @@ def test_op_roi_align():
                                numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx)
 
     test_roi_align_value()
+    test_roi_align_value(2)
     test_roi_align_autograd()