You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/05/07 16:39:17 UTC

[incubator-mxnet] branch master updated: Fix sign bug in spatial transformer interpolation. (#10741)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cfd62e0  Fix sign bug in spatial transformer interpolation. (#10741)
cfd62e0 is described below

commit cfd62e03fcedabafed1691c74eec88e5d5d4adb8
Author: Martin Kiefel <mk...@nopw.de>
AuthorDate: Mon May 7 12:39:09 2018 -0400

    Fix sign bug in spatial transformer interpolation. (#10741)
    
    It's possible for the computed `data_index` to take negative values in
    cases when the interpolation lookup falls outside the source window. This is
    bad on 64-bit machines: the bounds check for the source window still passes,
    but once the 32-bit unsigned number is added to the 64-bit pointer things
    break with a segmentation fault.
---
 src/operator/spatial_transformer.cc    |  4 +--
 src/operator/spatial_transformer.cu    |  2 +-
 tests/python/unittest/test_operator.py | 47 ++++++++++++++++++++++++++++++++++
 3 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/src/operator/spatial_transformer.cc b/src/operator/spatial_transformer.cc
index 78f64a7..1393729 100644
--- a/src/operator/spatial_transformer.cc
+++ b/src/operator/spatial_transformer.cc
@@ -105,8 +105,8 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
           const DType top_left_x_w = 1.0 - (x_real - top_left_x);
           for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
             index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
-            index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
-                                 + top_left_x;
+            const int data_index = n * i_c * i_h * i_w + c * i_h * i_w +
+                                   top_left_y * i_w + top_left_x;
             // calc 4 vertex value in input data
             DType top_left_v = 0;
             DType top_right_v = 0;
diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu
index 27fe73e..4a39733 100644
--- a/src/operator/spatial_transformer.cu
+++ b/src/operator/spatial_transformer.cu
@@ -102,7 +102,7 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
     DType top_left_x_w = 1.0 - (x_real - top_left_x);
     for (index_t c = 0; c < o_c; ++c) {
       index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
-      index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
+      int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
       // calc 4 vertex value in input data
       DType top_left_v = 0;
       DType top_right_v = 0;
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7ee67dd..96dd0b2 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2136,6 +2136,53 @@ def test_stn():
                     assert_almost_equal(out_grad.asnumpy(), grad_grad[0].asnumpy()[:, :, h//4:h-h//4, w//4:w-w//4], rtol=1e-2, atol=1e-4)
 
 
+def test_stn_valid_sampling():
+    target_shape = (
+        28,
+        28,
+    )
+    src_shape = (
+        42,
+        42,
+    )
+
+    data = mx.sym.Variable(name="data")
+    loc = mx.sym.Variable(name="loc")
+
+    data_array = np.zeros((
+        1,
+        1,
+    ) + src_shape)
+    # Have an ever so slight rotation.
+    loc_array = np.array(
+        [[9.03887e-05, 1.00015, 0.00174931, 1.0003, 0.000311901,
+          -0.000919065]])
+
+    stn = mx.sym.SpatialTransformer(
+        data=data,
+        loc=loc,
+        target_shape=target_shape,
+        transform_type="affine",
+        sampler_type="bilinear")
+
+    grad_req = {k: 'write' for k in stn.list_arguments()}
+    grads = {
+        'data': mx.nd.array(np.zeros_like(data_array)),
+        'loc': mx.nd.array(np.zeros_like(loc_array))
+    }
+    executor = stn.bind(
+        ctx=default_context(),
+        args={'data': mx.nd.array(data_array),
+              'loc': mx.nd.array(loc_array)},
+        grad_req=grad_req,
+        args_grad=grads)
+    executor.forward(is_train=True)
+    executor.backward(mx.nd.ones((
+        1,
+        1,
+    ) + target_shape))
+
+
 # Seed set because the test is not robust enough to operate on random data
 @with_seed(1234)
 def test_dot():

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.