You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/04/21 06:36:28 UTC

[incubator-mxnet] branch v1.6.x updated: Fix for handling negative indices in the fusion of slice (#17937) (#18118)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new 0a5e9cc  Fix for handling negative indices in the fusion of slice (#17937) (#18118)
0a5e9cc is described below

commit 0a5e9cc9645511d88f97101825c89e66fd5a5150
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Mon Apr 20 23:35:36 2020 -0700

    Fix for handling negative indices in the fusion of slice (#17937) (#18118)
    
    * Fix for handling of negative axis, begin and end in fusion of slice ops
    
    * Added test
---
 src/operator/fusion/fused_op-inl.h |  8 ++++----
 src/operator/fusion/fused_op.cu    |  7 +++++++
 tests/python/gpu/test_fusion.py    | 16 +++++++++++++++-
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index 7373cd0..f1d7364 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -391,8 +391,8 @@ __device__ inline VectorType<DType, nvec> load_slice(const DType * input, const
   strides[ndim-1] = 1;
   #pragma unroll
   for (int dim = ndim-1; dim >=0; dim--) {
-    if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim];
-    if (end[dim] < 0) end[dim] = shape[dim] - end[dim];
+    if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
+    if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
     if (end[dim] == INT_MAX) end[dim] = shape[dim];
     if (dim > 0) {
       ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
@@ -434,8 +434,8 @@ __device__ inline VectorType<DType, nvec> fast_load_slice(const DType * input,
   strides[ndim-1] = 1;
   #pragma unroll
   for (int dim = ndim-1; dim >=0; dim--) {
-    if (begin[dim] < 0) begin[dim] = shape[dim] - begin[dim];
-    if (end[dim] < 0) end[dim] = shape[dim] - end[dim];
+    if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
+    if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
     if (end[dim] == INT_MAX) end[dim] = shape[dim];
     if (dim > 0) {
       ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index c8a8883..4a8a7b4 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -270,6 +270,13 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
             return out;
           };
           auto build_tuple = [ndim](int axis, const std::string str, const std::string def) {
+            if (axis < 0 &&
+                axis >= -ndim) {
+              axis += ndim;
+            }
+            if (axis < 0 || axis >= ndim) {
+              LOG(FATAL) << "Axis " << axis << " is out of bounds for array of dimension " << ndim;
+            }
             std::string tuple = "{";
             for (int i = 0; i < axis; i++) {
                 tuple = tuple + def + ",";
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 9a37c0e..1ec4524 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -188,7 +188,10 @@ def check_other_ops():
     b = mx.sym.Variable('b')
     c = mx.sym.Variable('c')
     shape = rand_shape_2d()
-    shape = (5,) + shape
+    shape = list((5,) + shape)
+    # Make sure there are at least 2 elements for the test with negative indices
+    shape[1] += 1
+    shape[2] += 1
     arr1 = mx.random.uniform(shape=shape)
     arr2 = mx.random.uniform(shape=shape)
     arr3 = mx.random.uniform(shape=shape)
@@ -197,6 +200,9 @@ def check_other_ops():
 
     check_fused_symbol(mx.sym.slice_axis(a, axis=0, begin=1, end=4), a=arr1)
 
+    # Testing handling of negative axis
+    check_fused_symbol(mx.sym.slice_axis(a, axis=-3, begin=1, end=4), a=arr1)
+
     begin = (random.randint(0, shape[0]-1),
              random.randint(0, shape[1]-1),
              random.randint(0, shape[2]-1))
@@ -205,6 +211,14 @@ def check_other_ops():
            random.randint(begin[2]+1, shape[2]))
     check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)
 
+    begin = (random.randint(-shape[0], -2),
+             random.randint(-shape[1], -2),
+             random.randint(-shape[2], -2))
+    end = (random.randint(begin[0]+1, -1),
+           random.randint(begin[1]+1, -1),
+           random.randint(begin[2]+1, -1))
+    check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)
+
     arr1 = mx.random.uniform(shape=(2,3,4,5))
     arr2 = mx.random.uniform(shape=(1,2,3))
     check_fused_symbol(mx.sym.slice_like(a,b, axes=[-2, 0]), a=arr1, b=arr2)