You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/19 02:27:46 UTC

[GitHub] anirudhacharya commented on a change in pull request #13609: [MXNET-1258]fix unittest for ROIAlign Operator

anirudhacharya commented on a change in pull request #13609: [MXNET-1258]fix unittest for ROIAlign Operator
URL: https://github.com/apache/incubator-mxnet/pull/13609#discussion_r242773831
 
 

 ##########
 File path: tests/python/unittest/test_operator.py
 ##########
 @@ -6891,54 +6894,70 @@ def bilinear_interpolate(bottom, height, width, y, x):
         w3 = ly * hx
         w4 = ly * lx
 
+        assert w1.dtype == T
+        assert w2.dtype == T
+        assert w3.dtype == T
+        assert w4.dtype == T
+
         val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
+        assert val.dtype == T
         grad = [(y_low, x_low, w1), (y_low, x_high, w2),
                 (y_high, x_low, w3), (y_high, x_high, w4)
-               ]
+                ]
         return val, grad
 
+
     def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_ratio, dy):
         N, C, H, W = data.shape
         R = rois.shape[0]
         PH, PW = pooled_size
         assert len(rois.shape) == 2
         assert rois.shape[1] == 5
+        assert data.dtype == T
+        assert rois.dtype == T
 
-        out = np.zeros((R, C, PH, PW))
+        out = np.zeros((R, C, PH, PW), dtype=T)
         dx = np.zeros_like(data)
         drois = np.zeros_like(rois)
 
         for r in range(R):
             batch_ind = int(rois[r, 0])
-            sw, sh, ew, eh = rois[r, 1:5] * spatial_scale
-            roi_w = max(ew - sw, 1.0)
-            roi_h = max(eh - sh, 1.0)
-            bin_h = roi_h * 1.0 / PH
-            bin_w = roi_w * 1.0 / PW
+            sw, sh, ew, eh = rois[r, 1:5] * T(spatial_scale)
+            roi_w = T(max(ew - sw, 1.0))
+            roi_h = T(max(eh - sh, 1.0))
+            bin_h = roi_h / T(PH)
+            bin_w = roi_w / T(PW)
             bdata = data[batch_ind]
             if sampling_ratio > 0:
                 roi_bin_grid_h = roi_bin_grid_w = sampling_ratio
             else:
-                roi_bin_grid_h = int(np.ceil(roi_h * 1.0 / PH))
-                roi_bin_grid_w = int(np.ceil(roi_w * 1.0 / PW))
-            count = roi_bin_grid_h * roi_bin_grid_w
+                roi_bin_grid_h = int(np.ceil(roi_h / T(PH)))
+                roi_bin_grid_w = int(np.ceil(roi_w / T(PW)))
+            count = T(roi_bin_grid_h * roi_bin_grid_w)
             for c in range(C):
                 for ph in range(PH):
                     for pw in range(PW):
-                        val = 0.0
+                        val = T(0.0)
                         for iy in range(roi_bin_grid_h):
-                            y = sh + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
+                            y = sh + T(ph) * bin_h + (T(iy) + T(0.5)) * \
+                                bin_h / T(roi_bin_grid_h)
                             for ix in range(roi_bin_grid_w):
-                                x = sw + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
+                                x = sw + T(pw) * bin_w + (T(ix) + T(0.5)) * \
+                                    bin_w / T(roi_bin_grid_w)
                                 v, g = bilinear_interpolate(bdata[c], H, W, y, x)
+                                assert v.dtype == T
                                 val += v
                                 # compute grad
                                 for qy, qx, qw in g:
-                                    dx[batch_ind, c, qy, qx] += dy[r, c, ph, pw] * qw * 1.0 / count
+                                    assert qw.dtype == T
+                                    dx[batch_ind, c, qy, qx] += dy[r,
+                                                                   c, ph, pw] * qw / count
 
-                        out[r, c, ph, pw] = val * 1.0 / count
+                        out[r, c, ph, pw] = val / count
+        assert out.dtype == T, out.dtype
 
 Review comment:
  Can you add a more descriptive error message?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services