You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original message) was not preserved in this copy.
Posted to commits@mxnet.apache.org by ap...@apache.org on 2019/03/18 18:38:27 UTC

[incubator-mxnet] branch master updated: begin=end not a valid input (#14403)

This is an automated email from the ASF dual-hosted git repository.

apeforest pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d671528  begin=end not a valid input (#14403)
d671528 is described below

commit d671528b6fa08eb36af73ca085371ed8045939d6
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Mon Mar 18 11:37:39 2019 -0700

    begin=end not a valid input (#14403)
    
    refactoring logic for indexing
---
 src/operator/tensor/matrix_op-inl.h    | 65 +++++++++++++++-------------------
 tests/python/unittest/test_operator.py |  9 +++++
 2 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 3a58c12..5eecda6 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -653,50 +653,43 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
   }
 
   for (index_t i = 0; i < param_begin.ndim(); ++i) {
-    index_t b = 0, e = dshape[i], s = 1;
-    const index_t len = dshape[i];
-    if (param_step.ndim() != 0U) {
-      const auto& opt_step_val = param_step[i];
-      if (opt_step_val.has_value()) {
-        s = opt_step_val.value();
-        CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
-      }
-    }
+    index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1;
+    CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
 
-    if (len) {
-      if (param_begin[i].has_value()) {
-        b = param_begin[i].value();
-        if (b < 0) {
-          b += len;
-          CHECK_GE(b, 0) << "slicing with begin[" << i << "]="
-                         << b - len << " exceeds limit of " << len;
-        }
-      } else if (s < 0) {
-        b = len - 1;
+    index_t b = 0, e = 0;
+    const index_t len = dshape[i];
+    if (len > 0) {
+      b = param_begin[i].has_value() ? param_begin[i].value() : (s < 0 ? len - 1 : 0);
+      e = param_end[i].has_value() ? param_end[i].value() : (s < 0 ? -1 : len);
+
+      // checking upper and lower bounds for begin
+      if (b < 0) {
+        b += len;
+        CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
+                       << " exceeds limit of input dimension[" << i << "]=" << len;
       }
-      CHECK_LT(b, len) << "slicing with begin[" << i << "]="
-                       << b << " exceends limit of " << len;
-
-      if (param_end[i].has_value()) {
-        e = param_end[i].value();
-        if (e < 0) {
-          e += len;
-          CHECK_GE(e, 0) << "slicing with end[" << i << "]="
-                         << e - len << " exceeds limit of " << len;
-        }
-      } else if (s < 0) {
-        e = -1;
+      CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
+                       << " exceeds limit of input dimension[" << i << "]=" << len;
+
+      // checking upper and lower bounds for end
+      if (e < 0 && param_end[i].has_value()) {
+        e += len;
+        CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
+                       << " exceeds limit of input dimension[" << i << "]=" << len;
       }
-      CHECK_LE(e, len) << "slicing with end[" << i << "]="
-                       << e << " exceeds limit of " << len;
-    } else {
-      b = 0;
-      e = 0;
+      CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
+                       << " exceeds limit of input dimension[" << i << "]=" << len;
+
+      // checking begin==end case which is not supported
+      CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
+                     << e << " results in an empty tensor and is not supported";
     }
+
     (*begin)[i] = b;
     (*end)[i] = e;
     (*step)[i] = s;
   }
+
   for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) {
     (*begin)[i] = 0;
     (*end)[i] = dshape[i];
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7169395..f4d2ef3 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6606,6 +6606,15 @@ def test_slice():
     for index in index_list:
         test_slice_forward_backward(arr, index)
 
+    def test_begin_equals_end(shape, begin, end, step):
+        in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape)
+        out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step)
+
+    assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,))
+    assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1))
+    assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2))
+    assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1))
+
     # check numeric gradient
     in_data = np.arange(36).reshape(2, 2, 3, 3)
     data = mx.sym.Variable('data')