You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/04 16:19:43 UTC
[GitHub] TaoLv closed pull request #13088: make ROIAlign support
position-sensitive pooling
TaoLv closed pull request #13088: make ROIAlign support position-sensitive pooling
URL: https://github.com/apache/incubator-mxnet/pull/13088
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
Because this pull request comes from a fork, GitHub does not show its diff
after the merge, so the full diff is reproduced below:
diff --git a/src/operator/contrib/roi_align-inl.h b/src/operator/contrib/roi_align-inl.h
index 263f72a6abc..9f4d7ce4882 100644
--- a/src/operator/contrib/roi_align-inl.h
+++ b/src/operator/contrib/roi_align-inl.h
@@ -20,7 +20,7 @@
* Copyright (c) 2018 by Contributors
* \file roi_align-inl.h
* \brief roi align operator and symbol
- * \author Hang Zhang
+ * \author Hang Zhang, Shesung
* modified from Caffe2
*/
#ifndef MXNET_OPERATOR_CONTRIB_ROI_ALIGN_INL_H_
@@ -35,7 +35,6 @@
namespace mxnet {
namespace op {
-
// Declare enumeration of input order to make code more intuitive.
// These enums are only visible within this header
namespace roialign {
@@ -48,6 +47,7 @@ struct ROIAlignParam : public dmlc::Parameter<ROIAlignParam> {
TShape pooled_size;
float spatial_scale;
int sample_ratio;
+ bool position_sensitive;
DMLC_DECLARE_PARAMETER(ROIAlignParam) {
DMLC_DECLARE_FIELD(pooled_size)
.set_expect_ndim(2).enforce_nonzero()
@@ -57,6 +57,10 @@ struct ROIAlignParam : public dmlc::Parameter<ROIAlignParam> {
"Equals the reciprocal of total stride in convolutional layers");
DMLC_DECLARE_FIELD(sample_ratio).set_default(-1)
.describe("Optional sampling ratio of ROI align, using adaptive size by default.");
+ DMLC_DECLARE_FIELD(position_sensitive).set_default(false)
+ .describe("Whether to perform position-sensitive RoI pooling. PSRoIPooling is "
+ "first proposaled by R-FCN and it can reduce the input channels by ph*pw times, "
+ "where (ph, pw) is the pooled_size");
}
};
diff --git a/src/operator/contrib/roi_align.cc b/src/operator/contrib/roi_align.cc
index 76675677fa0..e584ea30325 100644
--- a/src/operator/contrib/roi_align.cc
+++ b/src/operator/contrib/roi_align.cc
@@ -20,7 +20,7 @@
* Copyright (c) 2018 by Contributors
* \file roi_align.cc
* \brief roi align operator
- * \author Hang Zhang
+ * \author Hang Zhang, Shesung
* Adapted from Caffe2
*/
#include "./roi_align-inl.h"
@@ -142,6 +142,7 @@ void ROIAlignForward(
const int nthreads,
const T* bottom_data,
const T& spatial_scale,
+ const bool position_sensitive,
const int channels,
const int height,
const int width,
@@ -156,6 +157,8 @@ void ROIAlignForward(
int n_rois = nthreads / channels / pooled_width / pooled_height;
// (n, c, ph, pw) is an element in the pooled output
// can be parallelized using omp
+#pragma omp parallel for \
+num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (int n = 0; n < n_rois; n++) {
int index_n = n * channels * pooled_width * pooled_height;
@@ -208,19 +211,23 @@ void ROIAlignForward(
roi_bin_grid_w,
&pre_calc);
- int c;
-#pragma omp parallel for private(c) \
-num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
- for (c = 0; c < channels; c++) {
+ for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
- const T* offset_bottom_data =
- bottom_data + (roi_batch_ind * channels + c) * height * width;
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
int index = index_n_c + ph * pooled_width + pw;
+ int c_unpooled = c;
+ int channels_unpooled = channels;
+ if (position_sensitive) {
+ c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
+ channels_unpooled = channels * pooled_height * pooled_width;
+ }
+ const T* offset_bottom_data =
+ bottom_data + (roi_batch_ind * channels_unpooled + c_unpooled)
+ * height * width;
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
@@ -310,6 +317,7 @@ void ROIAlignBackward(
const T* top_diff,
const int /*num_rois*/,
const T& spatial_scale,
+ const bool position_sensitive,
const int channels,
const int height,
const int width,
@@ -347,8 +355,15 @@ void ROIAlignBackward(
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+ int c_unpooled = c;
+ int channels_unpooled = channels;
+ if (position_sensitive) {
+ c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
+ channels_unpooled = channels * pooled_height * pooled_width;
+ }
T* offset_bottom_diff =
- bottom_diff + (roi_batch_ind * channels + c) * height * width;
+ bottom_diff + (roi_batch_ind * channels_unpooled + c_unpooled)
+ * height * width;
int top_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_top_diff = top_diff + top_offset;
@@ -426,7 +441,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
const int count = out_data[roialign::kOut].Size();
// const int num_rois = in_data[roialign::kBox].size(0);
- const int channels = in_data[roialign::kData].size(1);
+ const int channels = out_data[roialign::kOut].size(1); // channels of pooled output
const int height = in_data[roialign::kData].size(2);
const int width = in_data[roialign::kData].size(3);
const int pooled_height = out_data[roialign::kOut].size(2);
@@ -439,9 +454,9 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
const DType *bottom_rois = in_data[roialign::kBox].dptr<DType>();
DType *top_data = out_data[roialign::kOut].dptr<DType>();
- ROIAlignForward<DType>(count, bottom_data, param.spatial_scale, channels,
- height, width, pooled_height, pooled_width, param.sample_ratio,
- bottom_rois, rois_cols, top_data);
+ ROIAlignForward<DType>(count, bottom_data, param.spatial_scale, param.position_sensitive,
+ channels, height, width, pooled_height, pooled_width,
+ param.sample_ratio, bottom_rois, rois_cols, top_data);
})
}
@@ -470,7 +485,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
const int count = out_grad[0].Size();
const int num_rois = in_data[0].size(0);
- const int channels = outputs[0].size(1);
+ const int channels = out_grad[0].size(1); // channels of pooled output
const int height = outputs[0].size(2);
const int width = outputs[0].size(3);
const int pooled_height = out_grad[0].size(2);
@@ -489,8 +504,9 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
Fill<false>(s, outputs[0], kWriteTo, static_cast<DType>(0));
}
ROIAlignBackward<DType>(count, top_diff, num_rois, param.spatial_scale,
- channels, height, width, pooled_height, pooled_width,
- param.sample_ratio, grad_in, bottom_rois, rois_cols);
+ param.position_sensitive, channels, height, width,
+ pooled_height, pooled_width, param.sample_ratio, grad_in,
+ bottom_rois, rois_cols);
}
if (kWriteTo == req[roialign::kBox]) {
Fill<false>(s, outputs[1], kWriteTo, static_cast<DType>(0));
@@ -545,8 +561,17 @@ He, Kaiming, et al. "Mask R-CNN." ICCV, 2017
CHECK_EQ(bshape[1], 5) << "bbox should be a 2D tensor of shape [batch, 5]";
// out: [num_rois, c, pooled_h, pooled_w]
out_shape->clear();
- out_shape->push_back(
- Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1]));
+ if (param.position_sensitive) {
+ CHECK_EQ(dshape[1] % (param.pooled_size[0]*param.pooled_size[1]), 0) <<
+ "Input channels should be divided by pooled_size[0]*pooled_size[1]"
+ "when position_sensitive is true.";
+ out_shape->push_back(
+ Shape4(bshape[0], dshape[1]/param.pooled_size[0]/param.pooled_size[1],
+ param.pooled_size[0], param.pooled_size[1]));
+ } else {
+ out_shape->push_back(
+ Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1]));
+ }
return true;
})
.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/contrib/roi_align.cu b/src/operator/contrib/roi_align.cu
index d3db70b73b1..38b461d5f58 100644
--- a/src/operator/contrib/roi_align.cu
+++ b/src/operator/contrib/roi_align.cu
@@ -20,7 +20,7 @@
* Copyright (c) 2018 by Contributors
* \file roi_align.cu
* \brief roi align operator
- * \author Hang Zhang
+ * \author Hang Zhang, Shesung
* Adapted from Caffe2
*/
#include "./roi_align-inl.h"
@@ -111,6 +111,7 @@ __global__ void RoIAlignForwardKernel(
const int nthreads,
const T* bottom_data,
const T spatial_scale,
+ const bool position_sensitive,
const int channels,
const int height,
const int width,
@@ -145,8 +146,15 @@ __global__ void RoIAlignForwardKernel(
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+ int c_unpooled = c;
+ int channels_unpooled = channels;
+ if (position_sensitive) {
+ c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
+ channels_unpooled = channels * pooled_height * pooled_width;
+ }
const T* offset_bottom_data =
- bottom_data + (roi_batch_ind * channels + c) * height * width;
+ bottom_data + (roi_batch_ind * channels_unpooled + c_unpooled)
+ * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
@@ -242,6 +250,7 @@ __global__ void RoIAlignBackwardKernel(
const T* top_diff,
const int num_rois,
const T spatial_scale,
+ const bool position_sensitive,
const int channels,
const int height,
const int width,
@@ -276,8 +285,15 @@ __global__ void RoIAlignBackwardKernel(
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+ int c_unpooled = c;
+ int channels_unpooled = channels;
+ if (position_sensitive) {
+ c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw;
+ channels_unpooled = channels * pooled_height * pooled_width;
+ }
T* offset_bottom_diff =
- bottom_diff + (roi_batch_ind * channels + c) * height * width;
+ bottom_diff + (roi_batch_ind * channels_unpooled + c_unpooled)
+ * height * width;
int top_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_top_diff = top_diff + top_offset;
@@ -357,7 +373,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
const int count = out_data[roialign::kOut].Size();
const int num_rois = in_data[roialign::kBox].size(0);
- const int channels = in_data[roialign::kData].size(1);
+ const int channels = out_data[roialign::kOut].size(1); // channels of pooled output
const int height = in_data[roialign::kData].size(2);
const int width = in_data[roialign::kData].size(3);
const int pooled_height = out_data[roialign::kOut].size(2);
@@ -377,6 +393,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
count,
bottom_data,
param.spatial_scale,
+ param.position_sensitive,
channels,
height,
width,
@@ -414,7 +431,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
const int count = out_grad[0].Size();
const int num_rois = in_data[0].size(0);
- const int channels = outputs[0].size(1);
+ const int channels = out_grad[0].size(1); // channels of pooled output
const int height = outputs[0].size(2);
const int width = outputs[0].size(3);
const int pooled_height = out_grad[0].size(2);
@@ -445,6 +462,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
top_diff,
num_rois,
param.spatial_scale,
+ param.position_sensitive,
channels,
height,
width,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index a895594ce28..d8e80d7d693 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -17,6 +17,7 @@
# pylint: skip-file
from __future__ import print_function
+from __future__ import division
import numpy as np
import mxnet as mx
import copy
@@ -6899,14 +6900,16 @@ def bilinear_interpolate(bottom, height, width, y, x):
]
return val, grad
- def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_ratio, dy):
+ def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_ratio,
+ position_sensitive, dy):
N, C, H, W = data.shape
R = rois.shape[0]
PH, PW = pooled_size
assert len(rois.shape) == 2
assert rois.shape[1] == 5
- out = np.zeros((R, C, PH, PW))
+ C_out = C // PH // PW if position_sensitive else C
+ out = np.zeros((R, C_out, PH, PW))
dx = np.zeros_like(data)
drois = np.zeros_like(rois)
@@ -6924,24 +6927,25 @@ def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_r
roi_bin_grid_h = int(np.ceil(roi_h * 1.0 / PH))
roi_bin_grid_w = int(np.ceil(roi_w * 1.0 / PW))
count = roi_bin_grid_h * roi_bin_grid_w
- for c in range(C):
+ for c in range(C_out):
for ph in range(PH):
for pw in range(PW):
val = 0.0
+ c_in = c * PH * PW + ph * PW + pw if position_sensitive else c
for iy in range(roi_bin_grid_h):
y = sh + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
for ix in range(roi_bin_grid_w):
x = sw + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
- v, g = bilinear_interpolate(bdata[c], H, W, y, x)
+ v, g = bilinear_interpolate(bdata[c_in], H, W, y, x)
val += v
# compute grad
for qy, qx, qw in g:
- dx[batch_ind, c, qy, qx] += dy[r, c, ph, pw] * qw * 1.0 / count
+ dx[batch_ind, c_in, qy, qx] += dy[r, c, ph, pw] * qw * 1.0 / count
out[r, c, ph, pw] = val * 1.0 / count
return out, [dx, drois]
- def test_roi_align_value(sampling_ratio=0):
+ def test_roi_align_value(sampling_ratio=0, position_sensitive=False):
ctx=default_context()
dtype = np.float32
@@ -6950,6 +6954,7 @@ def test_roi_align_value(sampling_ratio=0):
assert H == W
R = 7
pooled_size = (3, 4)
+ C = C * pooled_size[0] * pooled_size[1] if position_sensitive else C
spatial_scale = H * 1.0 / dlen
data = mx.nd.array(np.arange(N*C*W*H).reshape((N,C,H,W)), ctx=ctx, dtype = dtype)
@@ -6964,11 +6969,14 @@ def test_roi_align_value(sampling_ratio=0):
rois.attach_grad()
with mx.autograd.record():
output = mx.nd.contrib.ROIAlign(data, rois, pooled_size=pooled_size,
- spatial_scale=spatial_scale, sample_ratio=sampling_ratio)
- dy = mx.nd.random.uniform(-1, 1, (R, C) + pooled_size, ctx=ctx, dtype = dtype)
+ spatial_scale=spatial_scale, sample_ratio=sampling_ratio,
+ position_sensitive=position_sensitive)
+ C_out = C // pooled_size[0] // pooled_size[1] if position_sensitive else C
+ dy = mx.nd.random.uniform(-1, 1, (R, C_out) + pooled_size, ctx=ctx, dtype = dtype)
output.backward(dy)
real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size,
- spatial_scale, sampling_ratio, dy.asnumpy())
+ spatial_scale, sampling_ratio,
+ position_sensitive, dy.asnumpy())
assert np.allclose(output.asnumpy(), real_output)
# It seems that the precision between Cfloat and Pyfloat is different.
assert np.allclose(data.grad.asnumpy(), dx, atol = 1e-5), np.abs(data.grad.asnumpy() - dx).max()
@@ -6994,7 +7002,8 @@ def test_roi_align_autograd(sampling_ratio=0):
numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx)
test_roi_align_value()
- test_roi_align_value(2)
+ test_roi_align_value(sampling_ratio=2)
+ test_roi_align_value(position_sensitive=True)
test_roi_align_autograd()
@with_seed()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services