Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/23 01:41:01 UTC

[incubator-tvm] branch master updated: [Relay, Topi] [TF, MXNet] Unravel Index operator (#5082)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new fdc8b0d  [Relay, Topi] [TF, MXNet] Unravel Index operator (#5082)
fdc8b0d is described below

commit fdc8b0dd1763aece4ce457a7baf522c2989ac6c4
Author: Mahesh Ambule <15...@users.noreply.github.com>
AuthorDate: Mon Mar 23 07:10:54 2020 +0530

    [Relay, Topi] [TF, MXNet] Unravel Index operator (#5082)
    
    * first cut unravel_index
    
    * merge fixes
    
    * change rates to dilations
    
    * unravel_index op relay, topi, mxnet, tf
    
    * doc changes
    
    * small changes
    
    * remove empty unravel and argwhere attrs
    
    * remove empty unravel and argwhere attrs
---
 docs/api/python/topi.rst                         |  2 +
 docs/frontend/tensorflow.rst                     |  1 +
 docs/langref/relay_op.rst                        |  3 +-
 include/tvm/relay/attrs/transform.h              |  6 --
 python/tvm/relay/frontend/mxnet.py               |  8 +++
 python/tvm/relay/frontend/tensorflow.py          | 10 +++-
 python/tvm/relay/op/_transform.py                |  1 +
 python/tvm/relay/op/transform.py                 | 23 ++++++++
 src/relay/op/tensor/transform.cc                 | 72 +++++++++++++++++++++++-
 tests/python/frontend/mxnet/test_forward.py      | 29 +++++++++-
 tests/python/frontend/tensorflow/test_forward.py | 52 +++++++++++++++++
 tests/python/relay/test_op_level3.py             | 39 +++++++++++++
 topi/include/topi/transform.h                    | 48 ++++++++++++++++
 topi/python/topi/transform.py                    | 25 +++++++-
 topi/src/topi.cc                                 |  5 ++
 topi/tests/python/test_topi_transform.py         | 44 +++++++++++++++
 16 files changed, 353 insertions(+), 15 deletions(-)
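
For context, the new operator follows NumPy's np.unravel_index semantics.
A minimal NumPy reference for the example used in the docstrings below
(flat indices into a (7, 6) array):

    import numpy as np

    # 22 = 3 * 6 + 4, 41 = 6 * 6 + 5, 37 = 6 * 6 + 1
    rows, cols = np.unravel_index([22, 41, 37], (7, 6))
    print(rows)  # [3 6 6]
    print(cols)  # [4 5 1]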

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 49fd94d..676dde9 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -47,6 +47,7 @@ List of operators
    topi.strided_slice
    topi.expand_dims
    topi.reshape
+   topi.unravel_index
    topi.squeeze
    topi.concatenate
    topi.split
@@ -147,6 +148,7 @@ topi
 .. autofunction:: topi.strided_slice
 .. autofunction:: topi.expand_dims
 .. autofunction:: topi.reshape
+.. autofunction:: topi.unravel_index
 .. autofunction:: topi.squeeze
 .. autofunction:: topi.concatenate
 .. autofunction:: topi.split
diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst
index e06794d..80230c6 100644
--- a/docs/frontend/tensorflow.rst
+++ b/docs/frontend/tensorflow.rst
@@ -242,5 +242,6 @@ Supported Ops
 - Transpose
 - TruncateMod
 - Unpack
+- UnravelIndex
 - Where
 - ZerosLike
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 35f9eeb..ac636f8 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -124,6 +124,7 @@ This level enables additional math and transform operators.
    tvm.relay.repeat
    tvm.relay.tile
    tvm.relay.reverse
+   tvm.relay.unravel_index
 
 
 **Level 4: Broadcast and Reductions**
@@ -217,4 +218,4 @@ This level supports dialect operators.
    :nosignatures:
 
    tvm.relay.qnn.op.requantize
-   tvm.relay.qnn.op.conv2d
+   tvm.relay.qnn.op.conv2d
\ No newline at end of file
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 11c7886..ae2ac11 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -315,12 +315,6 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
   }
 };  // struct OneHotAttrs
 
-/*! \brief Attributes for ArgWhere operator */
-struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
-  TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
-  }
-};  // struct ArgWhereAttrs
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 17be368..b3feded 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -120,6 +120,13 @@ def _mx_compare(new_op, wrapper):
     return impl
 
 
+def _mx_unravel_index(inputs, attrs):
+    assert len(inputs) == 1
+    shape = attrs.get_int_tuple("shape")
+    shape_expr = _expr.const(list(shape))
+    return _op.unravel_index(inputs[0], shape_expr)
+
+
 def _mx_zeros(inputs, attrs):
     assert len(inputs) == 0
     shape = attrs.get_int_tuple("shape")
@@ -1826,6 +1833,7 @@ _convert_map = {
     "Embedding"     : _mx_embedding,
     "argsort"       : _mx_argsort,
     "topk"          : _mx_topk,
+    "_unravel_index": _mx_unravel_index,
     "SequenceMask"  : _mx_sequence_mask,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index ff69ccc..4221cac 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -627,6 +627,11 @@ def _decode_image():
         return inputs[0]
     return _impl
 
+def _unravel_index():
+    def _impl(inputs, attr, params):
+        return _op.unravel_index(inputs[0], inputs[1])
+    return _impl
+
 def _crop_and_resize():
     def _impl(inputs, attr, params):
         # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
@@ -1744,6 +1749,7 @@ _convert_map = {
     'Transpose'                         : _transpose(),
     'TruncateMod'                       : _elemwise('mod'),
     'Unpack'                            : _unpack(),
+    'UnravelIndex'                      : _unravel_index(),
     'Where'                             : _where(),
     'ZerosLike'                         : AttrCvt('zeros_like'),
 
@@ -2517,9 +2523,7 @@ class GraphProto(object):
 
             array_ndim = len(np_array.shape)
             if array_ndim == 0:
-                new_array = np.empty([1], dtype=np_array.dtype)
-                new_array[0] = np_array
-                self._nodes[name] = [tvm.relay.const(new_array)]
+                self._nodes[name] = [tvm.relay.const(np_array)]
             else:
                 self._params[name] = tvm.nd.array(np_array)
                 self._nodes[name] = [_expr.var(name,
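
On the TensorFlow side the dims input is itself a tensor, so the converter
simply forwards both inputs to relay.unravel_index. A hedged TF 1.x sketch
of a graph the frontend can now import (it mirrors the TF test added below):

    import tensorflow as tf

    with tf.Graph().as_default():
        indices = tf.placeholder(shape=(3,), dtype=tf.int32, name="indices")
        dims = tf.placeholder(shape=(2,), dtype=tf.int32, name="dims")
        # Handled by the new 'UnravelIndex' entry in _convert_map.
        out = tf.unravel_index(indices, dims)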
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 4b35009..1f85e31 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -54,6 +54,7 @@ _reg.register_injective_schedule("gather_nd")
 _reg.register_injective_schedule("sequence_mask")
 _reg.register_injective_schedule("one_hot")
 _reg.register_reduce_schedule("collapse_sum_like")
+_reg.register_injective_schedule("unravel_index")
 
 # concatenate
 _reg.register_schedule("concatenate", strategy.schedule_concatenate)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 6a30eb2..d7a7b4f 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -861,3 +861,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
              [0, 0, 1]]
     """
     return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
+
+
+def unravel_index(indices, shape):
+    """Convert a flat index or array of flat indices into a tuple of coordinate arrays.
+
+    Example::
+    -   unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]
+
+    Parameters
+    ----------
+    indices : relay.Expr
+        An integer array containing indices.
+
+    shape : relay.Expr
+        The shape of the array.
+
+    Returns
+    -------
+    result : relay.Expr
+        A tensor of coordinate arrays.
+    """
+
+    return _make.unravel_index(indices, shape)
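
At the Relay level the shape is an ordinary tensor argument rather than an
attribute. Under the type relation added in transform.cc below, 1-D indices
of length N with a shape tensor of length M infer to (M, N), and scalar
indices infer to (M,). A minimal sketch:

    import numpy as np
    from tvm import relay

    x = relay.var("x", relay.TensorType((3,), "int64"))
    y = relay.const(np.array([7, 6], dtype="int64"))
    z = relay.unravel_index(x, y)
    # Inferred type: Tensor[(2, 3), int64] -- one row of coordinates
    # per dimension of the target shape.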
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 32df221..942ba7e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -806,15 +806,13 @@ bool ArgWhereRel(const Array<Type>& types,
 TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
 .set_body_typed([](Expr data) {
   static const Op& op = Op::Get("argwhere");
-  auto attrs = make_object<ArgWhereAttrs>();
-  return CallNode::make(op, {data}, Attrs(attrs), {});
+  return CallNode::make(op, {data}, Attrs(), {});
 });
 
 RELAY_REGISTER_OP("argwhere")
 .describe(R"doc(Find the indices of elements of a tensor that are
 non-zero)doc" TVM_ADD_FILELINE)
 .set_num_inputs(1)
-.set_attrs_type<ArgWhereAttrs>()
 .add_argument("condition", "Tensor", "The input condition tensor.")
 .add_type_rel("ArgWhere", ArgWhereRel)
 .set_attr<TOpIsStateful>("TOpIsStateful", false)
@@ -2662,5 +2660,73 @@ RELAY_REGISTER_OP("one_hot")
 .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
 
+/* relay.unravel_index */
+bool UnRavelIndexRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+
+  const auto* indices = types[0].as<TensorTypeNode>();
+  if (indices == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "unravel_index: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  CHECK(indices->dtype.is_int())
+      << "indices of unravel_index must be a tensor of integers";
+
+  const auto* shape = types[1].as<TensorTypeNode>();
+  if (shape == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "unravel_index: expect input type to be TensorType but get "
+        << types[1];
+    return false;
+  }
+  CHECK(shape->dtype.is_int())
+      << "shape of unravel_index must be a tensor of integers";
+
+  Array<IndexExpr> indices_shape;
+  Array<IndexExpr> shape_shape;
+  indices_shape = indices->shape;
+  shape_shape = shape->shape;
+
+  Array<IndexExpr> oshape;
+  oshape.push_back(shape_shape[0]);
+  if (indices_shape.size() != 0) {
+    oshape.push_back(indices_shape[0]);
+  }
+  reporter->Assign(types[2], TensorType(oshape, indices->dtype));
+  return true;
+}
+
+Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
+                                      const Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
+  return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])};
+}
+
+Expr MakeUnRavelIndex(Expr data,
+                      Expr shape) {
+  static const Op& op = Op::Get("unravel_index");
+  return CallNode::make(op, {data, shape}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.unravel_index")
+.set_body_typed(MakeUnRavelIndex);
+
+RELAY_REGISTER_OP("unravel_index")
+.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
+
+Example::
+  -  unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.set_support_level(3)
+.add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
+.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index b81fbab..102905a 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -949,6 +949,32 @@ def test_forward_cond():
     verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
 
 
+def test_forward_unravel_index():
+    def verify(x, shape, dtype):
+        a_np = np.array(x).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, 'unravel_index', [mx.sym.var('a'), shape])
+        ref_res = _mx_symbol(mx.nd, 'unravel_index', [mx.nd.array(a_np), shape])
+        shapes = {'a': a_np.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+    for dtype in ["int32", "int64"]:
+        verify([0, 1, 2, 3], [2, 2], dtype)
+        verify([144, 13, 45], [6, 7, 10, 2], dtype)
+        verify([456], [6, 7, 10, 2], dtype)
+
+    # In the example below, 5 is out of bounds for an array of size 4.
+    # The MXNet implementation returns a different result than TVM;
+    # the TVM implementation is in line with TensorFlow.
+    # Ideally an error should be thrown, as NumPy does.
+    # verify([0, 1, 2, 5], [2, 2], dtype)
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -1003,4 +1029,5 @@ if __name__ == '__main__':
     test_forward_convolution()
     test_forward_deconvolution()
     test_forward_cond()
-    test_forward_make_loss()
\ No newline at end of file
+    test_forward_make_loss()
+    test_forward_unravel_index()
\ No newline at end of file
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 2342606..3c51977 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3057,6 +3057,57 @@ def test_forward_add_n():
     _test_forward_add_n(in5)
 
 
+#######################################################################
+# Unravel Index
+# ----------------------
+def _test_forward_unravel_index(inputs):
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        temp = []
+        for each in inputs:
+            temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
+        output = tf.unravel_index(temp[0], temp[1])
+        compare_tf_with_tvm(list(inputs),
+                            [each.name for each in temp], output.name)
+
+
+def _test_forward_unravel_index_scalar(x, y, dtype="int32"):
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        indices_1 = constant_op.constant(x, dtype=dtype)
+        dims_1 = constant_op.constant(y, dtype=dtype)
+        out_1 = array_ops.unravel_index(indices_1, dims_1)
+        compare_tf_with_tvm([], [], out_1.name)
+
+
+def test_forward_unravel_index():
+    x = np.array([0, 1, 2, 3])
+    y = np.array([2, 2])
+    _test_forward_unravel_index([x, y])
+
+    x = np.array([0, 1, 2, 5])
+    y = np.array([2, 2])
+    _test_forward_unravel_index([x, y])
+
+    x = np.array([0, 1, 2, 5])
+    y = np.array([2])
+    _test_forward_unravel_index([x, y])
+
+    x = np.array([102, 300, 16])
+    y = np.array([10, 10, 9, 6])
+    _test_forward_unravel_index([x, y])
+
+    x = np.array([100])
+    y = np.array([10, 10, 9, 6])
+    _test_forward_unravel_index([x, y])
+
+    # Test scalar input
+    _test_forward_unravel_index_scalar(13, [1, 4, 5, 2])
+
+
+#######################################################################
+# Dilation2d
+# ----------------------
 def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
                      strides, dilations, padding):
     """ One iteration of dilation2d with given shapes and attributes """
@@ -3173,6 +3224,7 @@ if __name__ == '__main__':
     test_forward_squared_difference()
     test_forward_add_n()
     test_forward_floormod()
+    test_forward_unravel_index()
 
     # Reductions
     test_forward_argminmax()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 7e5314d..fffb1de 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -683,6 +683,44 @@ def test_gather_nd():
     verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
     verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
 
+
+def test_unravel_index():
+    def verify_unravel_index(indices, shape, dtype):
+        x_data = np.array(indices).astype(dtype)
+        y_data = np.array(shape).astype(dtype)
+        x = relay.var("x", relay.TensorType(x_data.shape, dtype))
+        y = relay.var("y", relay.TensorType(y_data.shape, dtype))
+
+        z = relay.unravel_index(x, y)
+        zz = run_infer_type(z)
+
+        if len(x_data.shape) == 1:
+            out_shape = [y_data.shape[0], x_data.shape[0]]
+        else:
+            out_shape = [y_data.shape[0]]
+        assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)
+
+        func = relay.Function([x, y], z)
+        ref_res = np.unravel_index(x_data, y_data)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data, y_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    for dtype in ["int64", "int32"]:
+        verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
+        verify_unravel_index([144], [5, 5, 5, 2], dtype)
+        verify_unravel_index(144, [5, 5, 5, 2], dtype)
+        verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
+
+        # In the example below, 5 is out of bounds for an array of size 4.
+        # NumPy throws an error for it, whereas the TVM implementation
+        # does not throw; instead it produces output that is in line
+        # with TensorFlow.
+        # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)
+
+
 if __name__ == "__main__":
     test_arange()
     test_cast()
@@ -713,3 +751,4 @@ if __name__ == "__main__":
     test_tile()
     test_repeat()
     test_gather_nd()
+    test_unravel_index()
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index efbffad..40bdcc6 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -233,6 +233,54 @@ inline Tensor reshape(const Tensor& x,
 }
 
 /*!
+ * \brief Converts a flat index or array of flat indices into a tuple of coordinate arrays
+ *
+ * \param x The input tensor of flat indices.
+ * \param shape The shape tensor.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor of coordinate arrays.
+ */
+
+inline Tensor unravel_index(const Tensor& x,
+                            const Tensor& shape,
+                            std::string name = "T_unravel",
+                            std::string tag = kInjective) {
+  auto x_shape = x->shape;
+  auto shape_shape = shape->shape;
+
+  Array<PrimExpr> oshape;
+  oshape.push_back(shape_shape[0]);
+  if (x_shape.size() != 0) {
+    oshape.push_back(x_shape[0]);
+  }
+
+  auto func = [&](const Array<Var>& indices) {
+    auto i = indices[0];
+    std::vector<PrimExpr> indices_divs;
+    PrimExpr ret = 0;
+    PrimExpr cur_val = 0;
+    PrimExpr index_val = 0;
+
+    if (x_shape.size() != 0) {
+      index_val = x[indices[1]];
+    } else {
+      index_val = x();
+    }
+    indices_divs.push_back(index_val);
+    for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
+      ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
+      cur_val = indexdiv(indices_divs.back(), shape[v]);
+      indices_divs.push_back(cur_val);
+    }
+    return ret;
+  };
+
+  return compute(oshape, func, name, tag);
+}
+
+/*!
 * \brief Remove size 1 dimensions from the shape of a tensor.
 * The removed dimensions must have a constant size of 1.
 *
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index 036191b..ef54560 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,consider-using-enumerate
+# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
 """Injective transformation operators"""
 from __future__ import absolute_import as _abs
 import tvm
@@ -653,3 +653,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
              [0, 0, 1]]
     """
     return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype)
+
+
+def unravel_index(indices, shape):
+    """Convert a flat index or array of flat indices into a tuple of coordinate arrays.
+
+       Example::
+       -   unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]
+
+       Parameters
+       ----------
+       indices : tvm.te.Tensor
+           An integer array containing indices.
+
+       shape : tvm.te.Tensor
+           The shape of the array.
+
+       Returns
+       -------
+       result : tvm.te.Tensor
+           A tensor of coordinate arrays.
+    """
+
+    return cpp.unravel_index(indices, shape)
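
A minimal TE-level sketch of building the new topi op for the LLVM backend,
assuming the 0.7-dev layout where topi is a standalone package (it mirrors
the topi test added below):

    import tvm
    from tvm import te
    import topi

    X = te.placeholder((3,), dtype="int64", name="X")
    Y = te.placeholder((2,), dtype="int64", name="Y")
    Z = topi.unravel_index(X, Y)
    s = te.create_schedule(Z.op)
    f = tvm.build(s, [X, Y, Z], "llvm", name="unravel_index")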
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 5581f2b..3a3175c 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -435,6 +435,11 @@ TVM_REGISTER_GLOBAL("topi.gather_nd")
   *rv = gather_nd(args[0], args[1]);
 });
 
+TVM_REGISTER_GLOBAL("topi.unravel_index")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = unravel_index(args[0], args[1]);
+});
+
 TVM_REGISTER_GLOBAL("topi.matmul")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   switch ( args.size() ) {
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 097c87d..b98ce09 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -562,6 +562,40 @@ def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
     for device in get_all_backend():
         check_device(device)
 
+
+def verify_unravel_index(indices, shape, dtype):
+    x_data = np.array(indices).astype(dtype)
+    y_data = np.array(shape).astype(dtype)
+    if len(x_data.shape) == 1:
+        dst_shape = [y_data.shape[0], x_data.shape[0]]
+    else:
+        dst_shape = [y_data.shape[0]]
+
+    X = te.placeholder(shape=x_data.shape, dtype=dtype, name="X")
+    Y = te.placeholder(shape=y_data.shape, dtype=dtype, name="Y")
+    Z = topi.unravel_index(X, Y)
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.testing.get_injective_schedule(device)(Z)
+        foo = tvm.build(s, [X, Y, Z], device, name="unravel_index")
+
+        out_npy = np.unravel_index(x_data, y_data)
+        datax_nd = tvm.nd.array(x_data, ctx)
+        datay_nd = tvm.nd.array(y_data, ctx)
+        out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=Z.dtype)
+        foo(datax_nd, datay_nd, out_nd)
+        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+    for device in get_all_backend():
+        check_device(device)
+
+
 def test_strided_slice():
     verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
     verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
@@ -882,6 +916,15 @@ def test_one_hot():
     verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
+
+def test_unravel_index():
+    for dtype in ["int32", "int64"]:
+        verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
+        verify_unravel_index([144], [5, 5, 5, 2], dtype)
+        verify_unravel_index(144, [5, 5, 5, 2], dtype)
+        verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
+
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -905,3 +948,4 @@ if __name__ == "__main__":
     test_ndarray_size()
     test_where_fusion()
     test_one_hot()
+    test_unravel_index()