Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/17 18:54:07 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12090: [MXNET-791] Pick with negative indices

URL: https://github.com/apache/incubator-mxnet/pull/12090

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 6bc97bb71fc..8d8aeaca73e 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -177,3 +177,4 @@ List of Contributors
 * [Istvan Fehervari](https://github.com/ifeherva)
 * [Aaron Markham](https://github.com/aaronmarkham)
 * [Sam Skalicky](https://github.com/samskalicky)
+* [Per Goncalves da Silva](https://github.com/perdasilva)
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 41cc6e527a6..351315ab0c8 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -99,6 +99,8 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
   }
 };
 
+enum PickOpMode {kWrap, kClip};
+
 struct PickParam : public dmlc::Parameter<PickParam> {
   dmlc::optional<int> axis;
   int mode;
@@ -112,6 +114,14 @@ struct PickParam : public dmlc::Parameter<PickParam> {
     DMLC_DECLARE_FIELD(keepdims).set_default(false)
       .describe("If true, the axis where we pick the elements is left "
                 "in the result as dimension with size one.");
+    DMLC_DECLARE_FIELD(mode)
+    .add_enum("wrap", kWrap)
+    .add_enum("clip", kClip)
+    .set_default(kClip)
+    .describe("Specify how out-of-bound indices behave. Default is \"clip\"."
+              " \"clip\" means clip to the range: too-large indices are replaced"
+              " by the index of the last element along the axis, and negative"
+              " indices by 0. \"wrap\" means to wrap around.");
   }
 };
 
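
For context on the new parameter, here is a minimal plain-Python sketch of
the index normalization the two modes perform (the helper name
normalize_index is illustrative, not code from this PR):

  def normalize_index(j, M, mode="clip"):
      """Map a possibly out-of-range index j into [0, M)."""
      if mode == "clip":
          # Clamp to the nearest valid index, as the kernels below do.
          return min(max(j, 0), M - 1)
      elif mode == "wrap":
          # Wrap around, so -1 maps to M - 1 (Python-style indexing).
          return j % M
      raise ValueError("mode must be 'clip' or 'wrap'")

  assert normalize_index(5, 3, mode="clip") == 2
  assert normalize_index(-1, 3, mode="wrap") == 2
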
@@ -1108,7 +1118,7 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
                      const std::vector<NDArray>& outputs);
 
 /*! \brief index element from array along axes */
-template<int ndim>
+template<int ndim, bool clip = true>
 struct pick {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
@@ -1117,15 +1127,20 @@ struct pick {
                                   mshadow::Shape<ndim> sshape) {
     using namespace broadcast;
     int j = static_cast<int>(idx[i]);
-    if (j < 0) j = 0;
-    else if (j >= M) j = M-1;
+    if (clip) {
+      if (j <= 0) j = 0;
+      else if (j >= M) j = M - 1;
+    } else {
+      j = j % M;
+      j += (j < 0) ? M : 0;
+    }
     j = ravel(unravel(i, sshape), bshape) + j*stride;
     out[i] = a[j];
   }
 };
 
 /*! \brief index element from array along axes */
-template<int ndim>
+template<int ndim, bool clip = true>
 struct pick_grad {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd,
@@ -1134,8 +1149,13 @@ struct pick_grad {
                                   mshadow::Shape<ndim> sshape) {
     using namespace broadcast;
     int j = static_cast<int>(idx[i]);
-    if (j < 0) j = 0;
-    else if (j >= M) j = M-1;
+    if (clip) {
+      if (j <= 0) j = 0;
+      else if (j >= M) j = M - 1;
+    } else {
+      j = j % M;
+      j += (j < 0) ? M : 0;
+    }
     j = ravel(unravel(i, sshape), bshape) + j*stride;
     igrad[j] += ograd[i];
   }
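
A note on the wrap branch above: C++'s % operator keeps the sign of the
dividend, so a negative remainder has to be shifted back into [0, M); that
is the extra `j += M` step. A small sketch (using Python's math.fmod as a
stand-in for the C-style remainder) showing the two forms agree:

  import math

  def cpp_style_wrap(j, M):
      j = int(math.fmod(j, M))  # C-style remainder: sign follows j
      if j < 0:
          j += M                # shift negative remainders into [0, M)
      return j

  # Matches Python's modulo, which is already non-negative for M > 0:
  assert all(cpp_style_wrap(j, 5) == j % 5 for j in range(-12, 12))
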
@@ -1195,15 +1215,28 @@ void PickOpForward(const nnvm::NodeAttrs& attrs,
 
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output type
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index type
-      if (trailing == 1) {
-        Kernel<pick<2>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
-                                     inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
-                                     M, 1, Shape2(leading, M), Shape2(leading, 1));
+      if (param.mode == kWrap) {
+        if (trailing == 1) {
+            Kernel<pick<2, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
+                                    inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                    M, 1, Shape2(leading, M), Shape2(leading, 1));
+        } else {
+            Kernel<pick<3, false>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
+                                    inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                    M, trailing, Shape3(leading, M, trailing),
+                                    Shape3(leading, 1, trailing));
+        }
       } else {
-        Kernel<pick<3>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
-                                     inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
-                                     M, trailing, Shape3(leading, M, trailing),
-                                     Shape3(leading, 1, trailing));
+        if (trailing == 1) {
+            Kernel<pick<2, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
+                                   inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                   M, 1, Shape2(leading, M), Shape2(leading, 1));
+        } else {
+            Kernel<pick<3, true>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
+                                   inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                   M, trailing, Shape3(leading, M, trailing),
+                                   Shape3(leading, 1, trailing));
+        }
       }
     });
   });
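
To make the dispatch above concrete, a NumPy reference model of what the
forward kernels compute (my own sketch for the keepdims=False case, not
code from this PR):

  import numpy as np

  def pick_reference(data, index, axis, mode="clip"):
      M = data.shape[axis]
      idx = index.astype(np.int64)
      # Normalize indices exactly as the kernels do.
      idx = idx % M if mode == "wrap" else np.clip(idx, 0, M - 1)
      # Gather one element per position along `axis`.
      picked = np.take_along_axis(data, np.expand_dims(idx, axis), axis)
      return picked.squeeze(axis)

  x = np.array([[1., 2.], [3., 4.], [5., 6.]])
  print(pick_reference(x, np.array([2, -1, -2]), axis=1, mode="wrap"))
  # -> [1. 4. 5.]
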
@@ -1230,15 +1263,28 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output type
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index type
       if (req[0] != kAddTo) outputs[0].FlatTo1D<xpu, DType>(s) = 0;
-      if (trailing == 1) {
-        Kernel<pick_grad<2>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
-                                     inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
-                                     M, 1, Shape2(leading, M), Shape2(leading, 1));
+      if (param.mode == kWrap) {
+        if (trailing == 1) {
+          Kernel<pick_grad<2, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
+                                      inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                      M, 1, Shape2(leading, M), Shape2(leading, 1));
+        } else {
+          Kernel<pick_grad<3, false>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
+                                      inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                      M, trailing, Shape3(leading, M, trailing),
+                                      Shape3(leading, 1, trailing));
+        }
       } else {
-        Kernel<pick_grad<3>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
-                                     inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
-                                     M, trailing, Shape3(leading, M, trailing),
-                                     Shape3(leading, 1, trailing));
+        if (trailing == 1) {
+          Kernel<pick_grad<2, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
+                                      inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                      M, 1, Shape2(leading, M), Shape2(leading, 1));
+        } else {
+          Kernel<pick_grad<3, true>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
+                                      inputs[0].dptr<DType>(), inputs[1].dptr<IType>(),
+                                      M, trailing, Shape3(leading, M, trailing),
+                                      Shape3(leading, 1, trailing));
+        }
       }
     });
   });
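
The backward kernels route each output gradient back to the element it was
picked from (igrad[j] += ograd[i]; the += also supports req=kAddTo). A
NumPy sketch of the plain write-to case, again my own reference model:

  import numpy as np

  def pick_grad_reference(ograd, index, data_shape, axis, mode="clip"):
      M = data_shape[axis]
      idx = index.astype(np.int64)
      idx = idx % M if mode == "wrap" else np.clip(idx, 0, M - 1)
      igrad = np.zeros(data_shape, dtype=ograd.dtype)
      # Each picked element receives its output gradient; the rest stay 0.
      np.put_along_axis(igrad, np.expand_dims(idx, axis),
                        np.expand_dims(ograd, axis), axis)
      return igrad

  g = pick_grad_reference(np.ones(3), np.array([2, -1, -2]), (3, 2), 1, "wrap")
  print(g)  # [[1. 0.] [0. 1.] [1. 0.]]
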
diff --git a/src/operator/tensor/broadcast_reduce_op_index.cc b/src/operator/tensor/broadcast_reduce_op_index.cc
index 6fd90df2107..969c23af974 100644
--- a/src/operator/tensor/broadcast_reduce_op_index.cc
+++ b/src/operator/tensor/broadcast_reduce_op_index.cc
@@ -133,6 +133,12 @@ Examples::
   // picks elements with specified indices along axis 1
   pick(x, y=[0,1,0], 1) = [ 1.,  4.,  5.]
 
+  y = [ 2., -1., -2.]
+
+  // picks elements with specified indices along axis 1, using 'wrap' mode
+  // so that negative and out-of-bound indices wrap around
+  pick(x, y=[2,-1,-2], 1, mode='wrap') = [ 1.,  4.,  5.]
+
   y = [[ 1.],
        [ 0.],
        [ 2.]]
@@ -165,7 +171,7 @@ Examples::
   })
 .add_argument("data", "NDArray-or-Symbol", "The input array")
 .add_argument("index", "NDArray-or-Symbol", "The index array")
-.add_arguments(ReduceAxisParam::__FIELDS__());
+.add_arguments(PickParam::__FIELDS__());
 
 
 NNVM_REGISTER_OP(_backward_pick)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index f1aec12ccc3..ce737d848f7 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4523,33 +4523,41 @@ def test_log_softmax():
 def test_pick():
     def test_pick_helper(index_type=np.int32):
         for _ in range(100):
-            ndim = np.random.randint(1, 5)
-            bshape = np.random.randint(1, 10, size=ndim)
-            axis = np.random.randint(0, ndim)
-            sshape = bshape.copy()
-            sshape[axis] = 1
-            data = np.random.uniform(-1, 1, size=bshape)
-            index = np.random.randint(0, bshape[axis], size=sshape)
-            exp = []
-            for i in range(ndim):
-                if i == axis:
-                    exp.append(index)
+            for mode in ['clip', 'wrap']:
+                ndim = np.random.randint(1, 5)
+                bshape = np.random.randint(1, 10, size=ndim)
+                axis = np.random.randint(0, ndim)
+                sshape = bshape.copy()
+                sshape[axis] = 1
+                data = np.random.uniform(-1, 1, size=bshape)
+
+                if mode == 'wrap':
+                    index = np.random.randint(-2*bshape[axis], 2*bshape[axis], size=sshape)
                 else:
-                    ishape = [1 for _ in range(ndim)]
-                    ishape[i] = bshape[i]
-                    exp.append(np.arange(bshape[i]).reshape(ishape))
-            expected = data[exp]
-            data = mx.nd.array(data, dtype='float32')
-            index = mx.nd.array(index, dtype=index_type)
-            out = mx.nd.pick(data, index, axis=axis, keepdims=True)
-            assert_almost_equal(out.asnumpy(), expected)
-
-            data_holder = data
-            index_holder = index
-            data = mx.sym.Variable('data')
-            index = mx.sym.Variable('index')
-            sym = mx.sym.pick(data, index, axis=axis, keepdims=True)
-            check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data'])
+                    index = np.random.randint(0, bshape[axis], size=sshape)
+                exp = []
+                for i in range(ndim):
+                    if i == axis:
+                        if mode == 'wrap':
+                            exp.append(index % bshape[axis])
+                        else:
+                            exp.append(index)
+                    else:
+                        ishape = [1 for _ in range(ndim)]
+                        ishape[i] = bshape[i]
+                        exp.append(np.arange(bshape[i]).reshape(ishape))
+                expected = data[exp]
+                data = mx.nd.array(data, dtype='float32')
+                index = mx.nd.array(index, dtype=index_type)
+                out = mx.nd.pick(data, index, axis=axis, keepdims=True, mode=mode)
+                assert_almost_equal(out.asnumpy(), expected)
+
+                data_holder = data
+                index_holder = index
+                data = mx.sym.Variable('data')
+                index = mx.sym.Variable('index')
+                sym = mx.sym.pick(data, index, axis=axis, keepdims=True, mode=mode)
+                check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data'])
 
     test_pick_helper(np.int32)
     test_pick_helper(np.float32)
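
For completeness, a short usage sketch of the new parameter (assuming an
MXNet build that includes this change):

  import mxnet as mx

  x = mx.nd.array([[1., 2.], [3., 4.], [5., 6.]])
  y = mx.nd.array([2, -1, -2])

  # 'wrap' maps -1 to the last element along the axis, like Python indexing.
  print(mx.nd.pick(x, y, axis=1, mode='wrap').asnumpy())  # [1. 4. 5.]
  # 'clip' (the default) clamps out-of-range indices to the valid range.
  print(mx.nd.pick(x, y, axis=1, mode='clip').asnumpy())  # [2. 3. 5.]
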


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services