You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/05/07 16:39:17 UTC
[incubator-mxnet] branch master updated: Fix sign bug in spatial transformer interpolation. (#10741)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new cfd62e0 Fix sign bug in spatial transformer interpolation. (#10741)
cfd62e0 is described below
commit cfd62e03fcedabafed1691c74eec88e5d5d4adb8
Author: Martin Kiefel <mk...@nopw.de>
AuthorDate: Mon May 7 12:39:09 2018 -0400
Fix sign bug in spatial transformer interpolation. (#10741)
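The computed `data_index` can be negative when the interpolation lookup falls
outside the source window. Because `data_index` was declared as `index_t`, a
32-bit unsigned type, a negative value wraps around to a large positive number.
This is bad on 64-bit machines. The checks for hitting the source window still
pass for neighbouring samples that lie inside the window. However, once the
wrapped 32-bit unsigned number is added to the 64-bit data pointer, the result
points far outside the buffer and the read fails with a segmentation fault.
Declaring `data_index` as a signed `int` keeps the offset negative, so the
pointer arithmetic stays correct.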
---
src/operator/spatial_transformer.cc | 4 +--
src/operator/spatial_transformer.cu | 2 +-
tests/python/unittest/test_operator.py | 47 ++++++++++++++++++++++++++++++++++
3 files changed, 50 insertions(+), 3 deletions(-)
diff --git a/src/operator/spatial_transformer.cc b/src/operator/spatial_transformer.cc
index 78f64a7..1393729 100644
--- a/src/operator/spatial_transformer.cc
+++ b/src/operator/spatial_transformer.cc
@@ -105,8 +105,8 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
const DType top_left_x_w = 1.0 - (x_real - top_left_x);
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
- index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
- + top_left_x;
+ const int data_index = n * i_c * i_h * i_w + c * i_h * i_w +
+ top_left_y * i_w + top_left_x;
// calc 4 vertex value in input data
DType top_left_v = 0;
DType top_right_v = 0;
diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu
index 27fe73e..4a39733 100644
--- a/src/operator/spatial_transformer.cu
+++ b/src/operator/spatial_transformer.cu
@@ -102,7 +102,7 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
DType top_left_x_w = 1.0 - (x_real - top_left_x);
for (index_t c = 0; c < o_c; ++c) {
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
- index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
+ int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
// calc 4 vertex value in input data
DType top_left_v = 0;
DType top_right_v = 0;
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7ee67dd..96dd0b2 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2136,6 +2136,53 @@ def test_stn():
assert_almost_equal(out_grad.asnumpy(), grad_grad[0].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4)
+def test_stn_valid_sampling():
+ target_shape = (
+ 28,
+ 28,
+ )
+ src_shape = (
+ 42,
+ 42,
+ )
+
+ data = mx.sym.Variable(name="data")
+ loc = mx.sym.Variable(name="loc")
+
+ data_array = np.zeros((
+ 1,
+ 1,
+ ) + src_shape)
+ # Have an ever so slight rotation.
+ loc_array = np.array(
+ [[9.03887e-05, 1.00015, 0.00174931, 1.0003, 0.000311901,
+ -0.000919065]])
+
+ stn = mx.sym.SpatialTransformer(
+ data=data,
+ loc=loc,
+ target_shape=target_shape,
+ transform_type="affine",
+ sampler_type="bilinear")
+
+ grad_req = {k: 'write' for k in stn.list_arguments()}
+ grads = {
+ 'data': mx.nd.array(np.zeros_like(data_array)),
+ 'loc': mx.nd.array(np.zeros_like(loc_array))
+ }
+ executor = stn.bind(
+ ctx=default_context(),
+ args={'data': mx.nd.array(data_array),
+ 'loc': mx.nd.array(loc_array)},
+ grad_req=grad_req,
+ args_grad=grads)
+ executor.forward(is_train=True)
+ executor.backward(mx.nd.ones((
+ 1,
+ 1,
+ ) + target_shape))
+
+
# Seed set because the test is not robust enough to operate on random data
@with_seed(1234)
def test_dot():
--
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.