Posted to commits@mxnet.apache.org by wk...@apache.org on 2020/04/11 13:44:07 UTC

[incubator-mxnet] branch master updated: [Bug Fix] support multiple-dim input for unravel_index (#17748)

This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6692d2c  [Bug Fix] support multiple-dim input for unravel_index (#17748)
6692d2c is described below

commit 6692d2cc76c4bb841d43abbe53f4d4aff059ba77
Author: JackieWu <wk...@live.cn>
AuthorDate: Sat Apr 11 21:43:12 2020 +0800

    [Bug Fix] support multiple-dim input for unravel_index (#17748)
    
    * support multiple-dim input for unravel_index
    
    * sanity
---
 src/operator/tensor/ravel.cc           | 12 ++++++++++--
 src/operator/tensor/ravel.h            | 19 ++++++++++++++-----
 tests/python/unittest/test_operator.py | 13 +++++++++++++
 3 files changed, 37 insertions(+), 7 deletions(-)
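
Before this commit, the shape inference for unravel_index only looked at the
first axis of the input (see the Shape2(...) line removed in ravel.h below),
so index arrays with more than one dimension were mishandled. The fix
generalizes the operator: indices of any rank are unravelled, with the
coordinates stacked along a new leading axis of length len(shape). A minimal
sketch of the fixed behavior, assuming an MXNet build that includes this
commit:

    import mxnet as mx

    # 2-D index input, which the operator previously could not handle.
    a = mx.nd.array([[22, 41, 37],
                     [10, 11, 15]])
    b = mx.nd.unravel_index(a, shape=(7, 6))
    print(b.shape)   # (2, 2, 3), i.e. (len(shape),) + a.shape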

diff --git a/src/operator/tensor/ravel.cc b/src/operator/tensor/ravel.cc
index e04628e..e7cd303 100644
--- a/src/operator/tensor/ravel.cc
+++ b/src/operator/tensor/ravel.cc
@@ -62,8 +62,16 @@ NNVM_REGISTER_OP(_unravel_index)
 Examples::
 
    A = [22,41,37]
-   unravel(A, shape=(7,6)) = [[3,6,6],[4,5,1]]
-   unravel(A, shape=(-1,6)) = [[3,6,6],[4,5,1]]
+   unravel_index(A, shape=(7,6)) = [[3,6,6],
+                                    [4,5,1]]
+   unravel_index(A, shape=(-1,6)) = [[3,6,6],
+                                     [4,5,1]]
+
+   B = [[22,41,37],[10,11,15]]
+   unravel_index(B, shape=(7,6)) = [[[3,6,6],[1,1,2]],
+                                    [[4,5,1],[4,5,3]]]
+   unravel_index(B, shape=(-1,6)) = [[[3,6,6],[1,1,2]],
+                                     [[4,5,1],[4,5,3]]]
 
 )code" ADD_FILELINE)
 .set_num_inputs(1)
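
The docstring examples above can be cross-checked against NumPy, whose
np.unravel_index returns one coordinate array per dimension of the target
shape; stacking them along a new leading axis reproduces MXNet's layout
(note that NumPy does not accept the -1 wildcard in shape):

    import numpy as np

    A = [22, 41, 37]
    print(np.stack(np.unravel_index(A, (7, 6))))
    # [[3 6 6]
    #  [4 5 1]]

    B = [[22, 41, 37], [10, 11, 15]]
    print(np.stack(np.unravel_index(B, (7, 6))))
    # [[[3 6 6]
    #   [1 1 2]]
    #  [[4 5 1]
    #   [4 5 3]]]
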
diff --git a/src/operator/tensor/ravel.h b/src/operator/tensor/ravel.h
index d96b9cf..abf9383 100644
--- a/src/operator/tensor/ravel.h
+++ b/src/operator/tensor/ravel.h
@@ -76,16 +76,24 @@ inline bool UnravelOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
   CHECK_GT(shape.ndim(), 0) << "Empty shape parameter for unravel operator.";
-  if ((*in_attrs)[0].ndim() > 0) {
-    SHAPE_ASSIGN_CHECK(*out_attrs, 0, Shape2(shape.ndim(), (*in_attrs)[0][0]));
+  const mxnet::TShape &in_shape = (*in_attrs)[0];
+  if (in_shape.ndim() > 0) {
+    mxnet::TShape out_shape(in_shape.ndim() + 1, -1);
+    out_shape[0] = shape.ndim();
+    for (int i = 0; i < in_shape.ndim(); ++i) {
+      out_shape[i+1] = in_shape[i];
+    }
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
     return true;
   }
   if ((*out_attrs)[0].ndim() > 0) {
+    const mxnet::TShape &out_shape = (*out_attrs)[0];
     CHECK_EQ((*out_attrs)[0].ndim(), 2)
       << "Output of unravel operator must be two-dimensional.";
     CHECK_EQ((*out_attrs)[0][0], shape.ndim())
       << "First dimension of output of ravel operator does not match shape parameter dimension.";
-    SHAPE_ASSIGN_CHECK(*in_attrs, 0, Shape1((*out_attrs)[0][1]));
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, mxnet::TShape(
+          out_shape.data() + 1, out_shape.data() + out_shape.ndim()));
     return true;
   }
   return false;
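
In words: when the input shape is known, the output shape is the input shape
with one extra leading axis of length shape.ndim(); when only the output
shape is known, the input shape is recovered by dropping that leading axis.
A pure-Python sketch of the same two inference directions (illustrative
names, not MXNet API):

    def unravel_op_shape(in_shape, out_shape, shape):
        # Forward: known input -> prepend one axis of length len(shape).
        if in_shape is not None:
            return in_shape, (len(shape),) + tuple(in_shape)
        # Backward: known output -> drop the leading axis.  The committed
        # code still requires a 2-D output here, mirrored by this assert.
        if out_shape is not None:
            assert len(out_shape) == 2 and out_shape[0] == len(shape)
            return tuple(out_shape[1:]), out_shape
        return None, None  # neither side known yet

    print(unravel_op_shape((2, 3), None, (7, 6)))  # ((2, 3), (2, 2, 3))
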
@@ -156,8 +164,9 @@ void UnravelForward(const nnvm::NodeAttrs& attrs,
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
     Tensor<xpu, 1, OType> in = inputs[0].FlatTo1D<xpu, OType>(s);
     Tensor<xpu, 1, OType> out = outputs[0].FlatTo1D<xpu, OType>(s);
-    mxnet_op::Kernel<unravel_index, xpu>::Launch(s, in.size(0), in.size(0), out.size(0)/in.size(0),
-                                                 work.dptr_, out.dptr_, in.dptr_);
+    mxnet_op::Kernel<unravel_index, xpu>::Launch(
+        s, in.shape_.Size(), in.shape_.Size(), shape.ndim(),
+        work.dptr_, out.dptr_, in.dptr_);
   });
 }
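
Since both tensors are flattened to 1-D before the launch, the old
out.size(0)/in.size(0) and the new shape.ndim() yield the same coordinate
count; the rewrite states the intent directly and takes the element count
from in.shape_.Size(). A Python sketch of the per-element computation the
unravel_index kernel performs, inferred from the operator's documented
behavior (coordinate j of element i lands at out[j*N + i], matching the
stacked layout above):

    def unravel_kernel(indices, shape):
        # Unravel each flat index into len(shape) coordinates.
        n = len(indices)
        out = [0] * (len(shape) * n)
        for i, idx in enumerate(indices):
            for j in reversed(range(len(shape))):
                out[j * n + i] = idx % shape[j]
                idx //= shape[j]
        return out

    print(unravel_kernel([22, 41, 37], (7, 6)))  # [3, 6, 6, 4, 5, 1]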
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0c795db..230073a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -8427,6 +8427,19 @@ def test_ravel():
       c = mx.sym.unravel_index(a, shape=shape2)
       check_symbolic_forward(c, location={'a': ravel_npy}, expected=[data])
 
+
+@with_seed()
+def test_unravel_index():
+    unravel_shape = (2, 10)
+    unravel_size = np.prod(unravel_shape)
+    for shape in [(10,), (2, 10), (3, 4, 5)]:
+        a = np.random.randint(0, unravel_size, size=shape)
+        b = np.stack(np.unravel_index(a, shape=unravel_shape), 0)
+        a_mx = mx.nd.array(a)
+        b_mx = mx.nd.unravel_index(a_mx, shape=unravel_shape)
+        assert_array_equal(b, b_mx.asnumpy())
+
+
 def test_context_num_gpus():
     try:
         # Note: the test is run both on GPU and CPU hosts, so that we can not assert
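
A quick interactive check of the new behavior across input ranks, mirroring
the test added above (assumes a build that includes this commit):

    import mxnet as mx

    for shp in [(10,), (2, 10), (3, 4, 5)]:
        a = mx.nd.zeros(shp)                      # all-zero indices are valid
        b = mx.nd.unravel_index(a, shape=(7, 6))
        print(shp, '->', b.shape)                 # e.g. (3, 4, 5) -> (2, 3, 4, 5)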