You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/02/05 19:38:38 UTC
[incubator-mxnet] branch master updated: [MXNET-1258] fix unittest for ROIAlign Operator (#13609)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7c7af3a [MXNET-1258]fix unittest for ROIAlign Operator (#13609)
7c7af3a is described below
commit 7c7af3ab5794cd66d69192f20b4cf1ea2852afd9
Author: JackieWu <wk...@live.cn>
AuthorDate: Wed Feb 6 03:38:21 2019 +0800
[MXNET-1258]fix unittest for ROIAlign Operator (#13609)
* Fix the ROIAlign unit test
* Retrigger unit tests
* Add more test detail to the ROIAlign test
* Remove the URL from test_op_roi_align
* Remove a blank line in test_op_roi_align in test_operator.py
* Merge master
* Update test_operator.py
* Retrigger CI
---
tests/python/unittest/test_operator.py | 144 ++++++++++++++++++---------------
1 file changed, 78 insertions(+), 66 deletions(-)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index fffaf8e..7b5b9eb 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6937,139 +6937,150 @@ def test_context_num_gpus():
@with_seed()
def test_op_roi_align():
- # Adapted from https://github.com/wkcn/MobulaOP/blob/master/tests/test_roi_align_op.py
+ T = np.float32
+
+ def assert_same_dtype(dtype_a, dtype_b):
+ '''
+ Assert whether the two data type are the same
+ Parameters
+ ----------
+ dtype_a, dtype_b: type
+ Input data types to compare
+ '''
+ assert dtype_a == dtype_b,\
+ TypeError('Unmatched data types: %s vs %s' % (dtype_a, dtype_b))
+
def bilinear_interpolate(bottom, height, width, y, x):
if y < -1.0 or y > height or x < -1.0 or x > width:
- return 0.0, []
- x = max(0.0, x)
- y = max(0.0, y)
+ return T(0.0), []
+ x = T(max(0.0, x))
+ y = T(max(0.0, y))
x_low = int(x)
y_low = int(y)
if x_low >= width - 1:
x_low = x_high = width - 1
- x = x_low
+ x = T(x_low)
else:
x_high = x_low + 1
-
if y_low >= height - 1:
y_low = y_high = height - 1
- y = y_low
+ y = T(y_low)
else:
y_high = y_low + 1
-
- ly = y - y_low
- lx = x - x_low
- hy = 1.0 - ly
- hx = 1.0 - lx
-
+ ly = y - T(y_low)
+ lx = x - T(x_low)
+ hy = T(1.0) - ly
+ hx = T(1.0) - lx
v1 = bottom[y_low, x_low]
v2 = bottom[y_low, x_high]
v3 = bottom[y_high, x_low]
v4 = bottom[y_high, x_high]
-
- '''
- ----------->x
- |hx hy | lx hy
- |------+------
- |hx ly | lx ly
- V
- y
- v1|v2
- --+--
- v3|v4
- '''
w1 = hy * hx
w2 = hy * lx
w3 = ly * hx
w4 = ly * lx
-
+ assert_same_dtype(w1.dtype, T)
+ assert_same_dtype(w2.dtype, T)
+ assert_same_dtype(w3.dtype, T)
+ assert_same_dtype(w4.dtype, T)
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
+ assert_same_dtype(val.dtype, T)
grad = [(y_low, x_low, w1), (y_low, x_high, w2),
(y_high, x_low, w3), (y_high, x_high, w4)
- ]
+ ]
return val, grad
def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_ratio,
- position_sensitive, dy):
+ position_sensitive, dy):
N, C, H, W = data.shape
R = rois.shape[0]
PH, PW = pooled_size
- assert len(rois.shape) == 2
- assert rois.shape[1] == 5
+ assert rois.ndim == 2,\
+ ValueError(
+ 'The ndim of rois should be 2 rather than %d' % rois.ndim)
+ assert rois.shape[1] == 5,\
+ ValueError(
+ 'The length of the axis 1 of rois should be 5 rather than %d' % rois.shape[1])
+ assert_same_dtype(data.dtype, T)
+ assert_same_dtype(rois.dtype, T)
C_out = C // PH // PW if position_sensitive else C
- out = np.zeros((R, C_out, PH, PW))
+ out = np.zeros((R, C_out, PH, PW), dtype=T)
dx = np.zeros_like(data)
drois = np.zeros_like(rois)
for r in range(R):
batch_ind = int(rois[r, 0])
- sw, sh, ew, eh = rois[r, 1:5] * spatial_scale
- roi_w = max(ew - sw, 1.0)
- roi_h = max(eh - sh, 1.0)
- bin_h = roi_h * 1.0 / PH
- bin_w = roi_w * 1.0 / PW
+ sw, sh, ew, eh = rois[r, 1:5] * T(spatial_scale)
+ roi_w = T(max(ew - sw, 1.0))
+ roi_h = T(max(eh - sh, 1.0))
+ bin_h = roi_h / T(PH)
+ bin_w = roi_w / T(PW)
bdata = data[batch_ind]
if sampling_ratio > 0:
roi_bin_grid_h = roi_bin_grid_w = sampling_ratio
else:
- roi_bin_grid_h = int(np.ceil(roi_h * 1.0 / PH))
- roi_bin_grid_w = int(np.ceil(roi_w * 1.0 / PW))
- count = roi_bin_grid_h * roi_bin_grid_w
+ roi_bin_grid_h = int(np.ceil(roi_h / T(PH)))
+ roi_bin_grid_w = int(np.ceil(roi_w / T(PW)))
+ count = T(roi_bin_grid_h * roi_bin_grid_w)
for c in range(C_out):
for ph in range(PH):
for pw in range(PW):
- val = 0.0
+ val = T(0.0)
c_in = c * PH * PW + ph * PW + pw if position_sensitive else c
for iy in range(roi_bin_grid_h):
- y = sh + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
+ y = sh + T(ph) * bin_h + (T(iy) + T(0.5)) * \
+ bin_h / T(roi_bin_grid_h)
for ix in range(roi_bin_grid_w):
- x = sw + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
- v, g = bilinear_interpolate(bdata[c_in], H, W, y, x)
+ x = sw + T(pw) * bin_w + (T(ix) + T(0.5)) * \
+ bin_w / T(roi_bin_grid_w)
+ v, g = bilinear_interpolate(
+ bdata[c_in], H, W, y, x)
+ assert_same_dtype(v.dtype, T)
val += v
# compute grad
for qy, qx, qw in g:
- dx[batch_ind, c_in, qy, qx] += dy[r, c, ph, pw] * qw * 1.0 / count
-
- out[r, c, ph, pw] = val * 1.0 / count
+ assert_same_dtype(qw.dtype, T)
+ dx[batch_ind, c_in, qy, qx] += dy[r,
+ c, ph, pw] * qw / count
+ out[r, c, ph, pw] = val / count
+ assert_same_dtype(out.dtype, T)
return out, [dx, drois]
def test_roi_align_value(sampling_ratio=0, position_sensitive=False):
- ctx=default_context()
+ ctx = default_context()
dtype = np.float32
-
dlen = 224
N, C, H, W = 5, 3, 16, 16
- assert H == W
R = 7
pooled_size = (3, 4)
C = C * pooled_size[0] * pooled_size[1] if position_sensitive else C
-
spatial_scale = H * 1.0 / dlen
- data = mx.nd.array(np.arange(N*C*W*H).reshape((N,C,H,W)), ctx=ctx, dtype = dtype)
- # data = mx.nd.random.uniform(0, 1, (N, C, H, W), dtype = dtype)
- center_xy = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype = dtype)
- wh = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype = dtype)
- batch_ind = mx.nd.array(np.random.randint(0, N, size = (R,1)), ctx=ctx)
- pos = mx.nd.concat(center_xy - wh / 2, center_xy + wh / 2, dim = 1)
- rois = mx.nd.concat(batch_ind, pos, dim = 1)
+ data = mx.nd.array(
+ np.arange(N * C * W * H).reshape((N, C, H, W)), ctx=ctx, dtype=dtype)
+ center_xy = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype=dtype)
+ wh = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype=dtype)
+ batch_ind = mx.nd.array(np.random.randint(0, N, size=(R, 1)), ctx=ctx)
+ pos = mx.nd.concat(center_xy - wh / 2, center_xy + wh / 2, dim=1)
+ rois = mx.nd.concat(batch_ind, pos, dim=1)
data.attach_grad()
rois.attach_grad()
with mx.autograd.record():
output = mx.nd.contrib.ROIAlign(data, rois, pooled_size=pooled_size,
- spatial_scale=spatial_scale, sample_ratio=sampling_ratio,
- position_sensitive=position_sensitive)
+ spatial_scale=spatial_scale, sample_ratio=sampling_ratio,
+ position_sensitive=position_sensitive)
C_out = C // pooled_size[0] // pooled_size[1] if position_sensitive else C
- dy = mx.nd.random.uniform(-1, 1, (R, C_out) + pooled_size, ctx=ctx, dtype = dtype)
+ dy = mx.nd.random.uniform(-1, 1, (R, C_out) +
+ pooled_size, ctx=ctx, dtype=dtype)
output.backward(dy)
real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size,
spatial_scale, sampling_ratio,
position_sensitive, dy.asnumpy())
- assert np.allclose(output.asnumpy(), real_output)
- # It seems that the precision between Cfloat and Pyfloat is different.
- assert np.allclose(data.grad.asnumpy(), dx, atol = 1e-5), np.abs(data.grad.asnumpy() - dx).max()
- assert np.allclose(rois.grad.asnumpy(), drois)
+
+ assert_almost_equal(output.asnumpy(), real_output, atol=1e-3)
+ assert_almost_equal(data.grad.asnumpy(), dx, atol=1e-3)
+ assert_almost_equal(rois.grad.asnumpy(), drois, atol=1e-3)
# modified from test_roipooling()
def test_roi_align_autograd(sampling_ratio=0):
@@ -7084,10 +7095,10 @@ def test_op_roi_align():
[1, 3.1, 1.1, 5.2, 10.2]], dtype='float64')
check_numeric_gradient(sym=test, location=[x1, x2],
- grad_nodes={'data':'write', 'rois':'null'},
+ grad_nodes={'data': 'write', 'rois': 'null'},
numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx)
check_numeric_gradient(sym=test, location=[x1, x2],
- grad_nodes={'data':'add', 'rois':'null'},
+ grad_nodes={'data': 'add', 'rois': 'null'},
numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx)
test_roi_align_value()
@@ -7095,6 +7106,7 @@ def test_op_roi_align():
test_roi_align_value(position_sensitive=True)
test_roi_align_autograd()
+
@with_seed()
def test_diag():