You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/04/18 02:19:42 UTC

[incubator-mxnet] branch v1.x updated: Cherry-pick of #17995 and #17937 to 1.x branch (#18041)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 814530d  Cherry-pick of #17995 and #17937 to 1.x branch (#18041)
814530d is described below

commit 814530d04650a97f7a995c91309fedfb79fb8473
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Fri Apr 17 19:18:50 2020 -0700

    Cherry-pick of #17995 and #17937 to 1.x branch (#18041)
    
    * Fix ElemwiseSum for more than 4 inputs (#17995)
    
    * Fix ElemwiseSum for more than 4 inputs
    
    * Added test
    
    * Fix for handling negative indices in the fusion of slice (#17937)
    
    * Fix for handling of negative axis, begin and end in fusion of slice ops
    
    * Added test
---
 src/operator/fusion/fused_op-inl.h     |  8 ++++----
 src/operator/fusion/fused_op.cu        |  7 +++++++
 src/operator/tensor/elemwise_sum.h     |  2 +-
 tests/python/gpu/test_fusion.py        | 16 +++++++++++++++-
 tests/python/unittest/test_operator.py | 19 +++++++++++++++++++
 5 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index 005ea4d..c838d85 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -399,8 +399,8 @@ __device__ inline VectorType<DType, nvec> load_slice(const DType * input, const
   strides[ndim-1] = 1;
   #pragma unroll
   for (int dim = ndim-1; dim >=0; dim--) {
-    if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim];
-    if (end[dim] < 0) end[dim] = shape[dim] - end[dim];
+    if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
+    if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
     if (end[dim] == INT_MAX) end[dim] = shape[dim];
     if (dim > 0) {
       ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
@@ -442,8 +442,8 @@ __device__ inline VectorType<DType, nvec> fast_load_slice(const DType * input,
   strides[ndim-1] = 1;
   #pragma unroll
   for (int dim = ndim-1; dim >=0; dim--) {
-    if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim];
-    if (end[dim] < 0) end[dim] = shape[dim] - end[dim];
+    if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
+    if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
     if (end[dim] == INT_MAX) end[dim] = shape[dim];
     if (dim > 0) {
       ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index f883c5c..cb13dbf 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -270,6 +270,13 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
             return out;
           };
           auto build_tuple = [ndim](int axis, const std::string str, const std::string def) {
+            if (axis < 0 &&
+                axis >= -ndim) {
+              axis += ndim;
+            }
+            if (axis < 0 || axis >= ndim) {
+              LOG(FATAL) << "Axis " << axis << " is out of bounds for array of dimension " << ndim;
+            }
             std::string tuple = "{";
             for (int i = 0; i < axis; i++) {
                 tuple = tuple + def + ",";
diff --git a/src/operator/tensor/elemwise_sum.h b/src/operator/tensor/elemwise_sum.h
index e89e9d7..259c80d 100644
--- a/src/operator/tensor/elemwise_sum.h
+++ b/src/operator/tensor/elemwise_sum.h
@@ -94,7 +94,7 @@ void ElementWiseSumCompute_(const nnvm::NodeAttrs& attrs,
       Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], in_0_dptr);
       for (size_t i = 1; i < size; ++i) {
         DType* in_dptr = in_data[i].dptr<DType>();
-        Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, req[0], out_dptr, in_dptr);
+        Kernel<Sum, xpu>::Launch(s, out_size, out_dptr, kWriteTo, out_dptr, in_dptr);
       }
       break;
     }
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index f69d50c..61fba10 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -188,7 +188,10 @@ def check_other_ops():
     b = mx.sym.Variable('b')
     c = mx.sym.Variable('c')
     shape = rand_shape_2d()
-    shape = (5,) + shape
+    shape = list((5,) + shape)
+    # Make sure there are at least 2 elements in dims 1 and 2 for the test with negative indices
+    shape[1] += 1
+    shape[2] += 1
     arr1 = mx.random.uniform(shape=shape)
     arr2 = mx.random.uniform(shape=shape)
     arr3 = mx.random.uniform(shape=shape)
@@ -197,6 +200,9 @@ def check_other_ops():
 
     check_fused_symbol(mx.sym.slice_axis(a, axis=0, begin=1, end=4), a=arr1)
 
+    # Testing handling of negative axis
+    check_fused_symbol(mx.sym.slice_axis(a, axis=-3, begin=1, end=4), a=arr1)
+
     begin = (random.randint(0, shape[0]-1),
              random.randint(0, shape[1]-1),
              random.randint(0, shape[2]-1))
@@ -205,6 +211,14 @@ def check_other_ops():
            random.randint(begin[2]+1, shape[2]))
     check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)
 
+    begin = (random.randint(-shape[0], -2),
+             random.randint(-shape[1], -2),
+             random.randint(-shape[2], -2))
+    end = (random.randint(begin[0]+1, -1),
+           random.randint(begin[1]+1, -1),
+           random.randint(begin[2]+1, -1))
+    check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)
+
     arr1 = mx.random.uniform(shape=(2,3,4,5))
     arr2 = mx.random.uniform(shape=(1,2,3))
     check_fused_symbol(mx.sym.slice_like(a,b, axes=[-2, 0]), a=arr1, b=arr2)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 6cbbc5d..df4a77f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9875,6 +9875,25 @@ def test_im2col_col2im():
         pad         = 1
     )
 
+def test_elemwise_sum_for_gradient_accumulation():
+    for nrepeat in range(1, 10):
+        stored_grad = dict()
+        for grad_req in ['write', 'add']:
+            a = mx.nd.array([1])
+            b = mx.nd.array([2])
+            if grad_req == 'write':
+                a.attach_grad(grad_req='write')
+            elif grad_req == 'add':
+                a.attach_grad(grad_req='add')
+            a.grad[:] = 0
+            with mx.autograd.record():
+                for _ in range(nrepeat):
+                    b = b * a
+                b.backward()
+            stored_grad[grad_req] = a.grad.asscalar()
+        assert stored_grad['write'] == stored_grad['add']
+        assert stored_grad['write'] == 2 * nrepeat
+
 
 if __name__ == '__main__':
     import nose