You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/04 05:55:36 UTC

[GitHub] szha closed pull request #12026: Fix reduce_kernel_M1

szha closed pull request #12026: Fix reduce_kernel_M1
URL: https://github.com/apache/incubator-mxnet/pull/12026
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh
index be3d1f9223f..33bf72798fd 100644
--- a/src/operator/tensor/broadcast_reduce-inl.cuh
+++ b/src/operator/tensor/broadcast_reduce-inl.cuh
@@ -268,7 +268,11 @@ __global__ void reduce_kernel_M1(const int N, const bool addto,
   for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
     Shape<ndim> coord = unravel(idx, sshape);
     int j = ravel(coord, bshape);
-    assign(&small[idx], addto, OP::Map(big[j]));
+    DType val, residual;
+    Reducer::SetInitValue(val, residual);
+    Reducer::Reduce(val, OP::Map(big[j]), residual);
+    Reducer::Finalize(val, residual);
+    assign(&small[idx], addto, val);
   }
 }
 
@@ -287,7 +291,10 @@ __global__ void reduce_kernel_M1(const int N, const bool addto,
     int idx_big = ravel(coord, big_shape);
     int idx_lhs = ravel(coord, lhs_shape);
     int idx_rhs = ravel(coord, rhs_shape);
-    DType val = OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs]));
+    DType val, residual;
+    Reducer::SetInitValue(val, residual);
+    Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
+    Reducer::Finalize(val, residual);
     assign(&small[idx], addto, val);
   }
 }
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index e55fa1af90e..ac6ee1561c4 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1308,25 +1308,31 @@ def test_norm(ctx=default_context()):
 
     def l1norm(input_data, axis=0, keepdims=False):
         return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
-    def l2norm(input_data, axis=0, keepdims=False): 
+    def l2norm(input_data, axis=0, keepdims=False):
         return sp_norm(input_data, axis=axis, keepdims=keepdims)
 
     in_data_dim = random_sample([4,5,6], 1)[0]
-    in_data_shape = rand_shape_nd(in_data_dim)
-    np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
-    mx_arr = mx.nd.array(np_arr, ctx=ctx)
-    for ord in [1,2]:
-        for keep_dims in [True, False]:
-            for i in range(4):
-                npy_out = l1norm(np_arr, i, keep_dims) if ord==1 else l2norm(np_arr, i, keep_dims)
-                mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims)
-                assert npy_out.shape == mx_out.shape
-                mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
-                if (i < 3):
-                    npy_out = l1norm(np_arr, (i, i+1), keep_dims) if ord==1 else l2norm(np_arr, (i, i+1), keep_dims)
-                    mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i+1), keepdims=keep_dims)
+    for force_reduce_dim1 in [True, False]:
+        in_data_shape = rand_shape_nd(in_data_dim)
+        if force_reduce_dim1:
+            in_data_shape = in_data_shape[:3] + (1, ) + in_data_shape[4:]
+        np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
+        mx_arr = mx.nd.array(np_arr, ctx=ctx)
+        for ord in [1, 2]:
+            for keep_dims in [True, False]:
+                for i in range(4):
+                    npy_out = l1norm(np_arr, i, keep_dims) if ord == 1 else l2norm(
+                        np_arr, i, keep_dims)
+                    mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims)
                     assert npy_out.shape == mx_out.shape
                     mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
+                    if (i < 3):
+                        npy_out = l1norm(np_arr, (i, i + 1), keep_dims) if ord == 1 else l2norm(
+                            np_arr, (i, i + 1), keep_dims)
+                        mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i + 1), keepdims=keep_dims)
+                        assert npy_out.shape == mx_out.shape
+                        mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
+
 
 @with_seed()
 def test_ndarray_cpu_shared_ctx():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services