You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/23 01:41:01 UTC
[incubator-tvm] branch master updated: [Relay, Topi] [TF, MXNet] Unravel Index operator (#5082)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new fdc8b0d [Relay, Topi] [TF, MXNet] Unravel Index operator (#5082)
fdc8b0d is described below
commit fdc8b0dd1763aece4ce457a7baf522c2989ac6c4
Author: Mahesh Ambule <15...@users.noreply.github.com>
AuthorDate: Mon Mar 23 07:10:54 2020 +0530
[Relay, Topi] [TF, MXNet] Unravel Index operator (#5082)
* first cut unravel_index
* merge fixes
* change rates to dilations
* unravel_index op relay, topi, mxnet, tf
* doc changes
* small changes
* remove empty unravel and argwhere attrs
* remove empty unravel and argwhere attrs
---
docs/api/python/topi.rst | 2 +
docs/frontend/tensorflow.rst | 1 +
docs/langref/relay_op.rst | 3 +-
include/tvm/relay/attrs/transform.h | 6 --
python/tvm/relay/frontend/mxnet.py | 8 +++
python/tvm/relay/frontend/tensorflow.py | 10 +++-
python/tvm/relay/op/_transform.py | 1 +
python/tvm/relay/op/transform.py | 23 ++++++++
src/relay/op/tensor/transform.cc | 72 +++++++++++++++++++++++-
tests/python/frontend/mxnet/test_forward.py | 29 +++++++++-
tests/python/frontend/tensorflow/test_forward.py | 52 +++++++++++++++++
tests/python/relay/test_op_level3.py | 39 +++++++++++++
topi/include/topi/transform.h | 48 ++++++++++++++++
topi/python/topi/transform.py | 25 +++++++-
topi/src/topi.cc | 5 ++
topi/tests/python/test_topi_transform.py | 44 +++++++++++++++
16 files changed, 353 insertions(+), 15 deletions(-)
diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 49fd94d..676dde9 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -47,6 +47,7 @@ List of operators
topi.strided_slice
topi.expand_dims
topi.reshape
+ topi.unravel_index
topi.squeeze
topi.concatenate
topi.split
@@ -147,6 +148,7 @@ topi
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
+.. autofunction:: topi.unravel_index
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst
index e06794d..80230c6 100644
--- a/docs/frontend/tensorflow.rst
+++ b/docs/frontend/tensorflow.rst
@@ -242,5 +242,6 @@ Supported Ops
- Transpose
- TruncateMod
- Unpack
+- UnravelIndex
- Where
- ZerosLike
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 35f9eeb..ac636f8 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -124,6 +124,7 @@ This level enables additional math and transform operators.
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse
+ tvm.relay.unravel_index
**Level 4: Broadcast and Reductions**
@@ -217,4 +218,4 @@ This level supports dialect operators.
:nosignatures:
tvm.relay.qnn.op.requantize
- tvm.relay.qnn.op.conv2d
+ tvm.relay.qnn.op.conv2d
\ No newline at end of file
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 11c7886..ae2ac11 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -315,12 +315,6 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
}
}; // struct OneHotAttrs
-/*! \brief Attributes for ArgWhere operator */
-struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
- TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
- }
-}; // struct ArgWhereAttrs
-
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 17be368..b3feded 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -120,6 +120,13 @@ def _mx_compare(new_op, wrapper):
return impl
+def _mx_unravel_index(inputs, attrs):
+    """Convert MXNet ``_unravel_index`` to ``relay.unravel_index``.
+
+    MXNet carries the target shape as the ``shape`` attribute, so it is
+    materialized here as a constant tensor and passed as the relay op's
+    second operand.
+    """
+    assert len(inputs) == 1
+    shape = attrs.get_int_tuple("shape")
+    shape_expr = _expr.const(list(shape))
+    return _op.unravel_index(inputs[0], shape_expr)
+
+
def _mx_zeros(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
@@ -1826,6 +1833,7 @@ _convert_map = {
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
+ "_unravel_index": _mx_unravel_index,
"SequenceMask" : _mx_sequence_mask,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index ff69ccc..4221cac 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -627,6 +627,11 @@ def _decode_image():
return inputs[0]
return _impl
+def _unravel_index():
+    """Build the converter for TF ``UnravelIndex`` (operands: indices, dims)."""
+    def _impl(inputs, attr, params):
+        # Both operands are forwarded as tensors; dims is not folded to a constant here.
+        return _op.unravel_index(inputs[0], inputs[1])
+    return _impl
+
def _crop_and_resize():
def _impl(inputs, attr, params):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
@@ -1744,6 +1749,7 @@ _convert_map = {
'Transpose' : _transpose(),
'TruncateMod' : _elemwise('mod'),
'Unpack' : _unpack(),
+ 'UnravelIndex' : _unravel_index(),
'Where' : _where(),
'ZerosLike' : AttrCvt('zeros_like'),
@@ -2517,9 +2523,7 @@ class GraphProto(object):
array_ndim = len(np_array.shape)
if array_ndim == 0:
- new_array = np.empty([1], dtype=np_array.dtype)
- new_array[0] = np_array
- self._nodes[name] = [tvm.relay.const(new_array)]
+ self._nodes[name] = [tvm.relay.const(np_array)]
else:
self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = [_expr.var(name,
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 4b35009..1f85e31 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -54,6 +54,7 @@ _reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
+_reg.register_injective_schedule("unravel_index")
# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 6a30eb2..d7a7b4f 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -861,3 +861,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 0, 1]]
"""
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
+
+
+def unravel_index(indices, shape):
+    """Convert a flat index or array of flat indices into coordinate arrays.
+
+    Example::
+        - unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6],[4, 5, 1]]
+
+    Parameters
+    ----------
+    indices : relay.Expr
+        An integer scalar, or a 1-D integer tensor, of flat indices.
+
+    shape : relay.Expr
+        A 1-D integer tensor giving the shape of the array to unravel into.
+
+    Returns
+    -------
+    result : relay.Expr
+        A single tensor (not a Python tuple) whose row ``i`` holds the
+        coordinates along dimension ``i``. Its shape is ``(len(shape),)``
+        for a scalar index and ``(len(shape), len(indices))`` for a 1-D
+        array of indices; its dtype is the dtype of ``indices``.
+    """
+
+    return _make.unravel_index(indices, shape)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 32df221..942ba7e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -806,15 +806,13 @@ bool ArgWhereRel(const Array<Type>& types,
TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
.set_body_typed([](Expr data) {
static const Op& op = Op::Get("argwhere");
- auto attrs = make_object<ArgWhereAttrs>();
- return CallNode::make(op, {data}, Attrs(attrs), {});
+ return CallNode::make(op, {data}, Attrs(), {});
});
RELAY_REGISTER_OP("argwhere")
.describe(R"doc(Find the indices of elements of a tensor that are
non-zero)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
-.set_attrs_type<ArgWhereAttrs>()
.add_argument("condition", "Tensor", "The input condition tensor.")
.add_type_rel("ArgWhere", ArgWhereRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
@@ -2662,5 +2660,73 @@ RELAY_REGISTER_OP("one_hot")
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
+/* relay.unravel_index */
+/*!
+ * \brief Type relation for unravel_index.
+ *
+ * types = [indices, shape, result]. indices must be an integer scalar or 1-D
+ * tensor, shape must be a 1-D integer tensor. The result has shape
+ * (len(shape),) for a scalar index, else (len(shape), len(indices)), and the
+ * dtype of indices.
+ */
+bool UnRavelIndexRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+
+  const auto* indices = types[0].as<TensorTypeNode>();
+  if (indices == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "unravel_index: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  CHECK(indices->dtype.is_int())
+      << "indices of unravel_index must be tensor of integer";
+  // The compute reads at most one index axis, so higher ranks would be mis-typed.
+  CHECK_LE(indices->shape.size(), 1)
+      << "unravel_index: indices must be a scalar or a 1-D tensor";
+
+  const auto* shape = types[1].as<TensorTypeNode>();
+  if (shape == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "unravel_index: expect input type to be TensorType but get "
+        << types[1];
+    return false;
+  }
+  // Validate the shape operand's own dtype (not the indices dtype).
+  CHECK(shape->dtype.is_int())
+      << "shape of unravel_index must be tensor of integer";
+  // shape_shape[0] below is only valid for a 1-D shape tensor.
+  CHECK_EQ(shape->shape.size(), 1)
+      << "unravel_index: shape must be a 1-D tensor";
+
+  Array<IndexExpr> indices_shape = indices->shape;
+  Array<IndexExpr> shape_shape = shape->shape;
+
+  Array<IndexExpr> oshape;
+  oshape.push_back(shape_shape[0]);
+  if (indices_shape.size() != 0) {
+    oshape.push_back(indices_shape[0]);
+  }
+  reporter->Assign(types[2], TensorType(oshape, indices->dtype));
+  return true;
+}
+
+// FTVMCompute for unravel_index: inputs are (indices, shape) and are lowered
+// to topi::unravel_index; attrs and out_type are unused.
+Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
+                                      const Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
+  return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])};
+}
+
+// Build a call to the attribute-free unravel_index op with operands (data, shape),
+// where data holds the flat indices.
+Expr MakeUnRavelIndex(Expr data,
+                      Expr shape) {
+  static const Op& op = Op::Get("unravel_index");
+  return CallNode::make(op, {data, shape}, Attrs(), {});
+}
+
+// Python entry point: relay.op._make.unravel_index(indices, shape).
+TVM_REGISTER_GLOBAL("relay.op._make.unravel_index")
+.set_body_typed(MakeUnRavelIndex);
+
+// Operator registration; the argument metadata mirrors other ops such as argwhere.
+RELAY_REGISTER_OP("unravel_index")
+.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
+
+Example::
+  - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("indices", "Tensor", "Scalar or 1-D tensor of flat integer indices.")
+.add_argument("shape", "Tensor", "1-D integer tensor giving the shape to unravel into.")
+.set_support_level(3)
+.add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
+.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+
} // namespace relay
} // namespace tvm
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index b81fbab..102905a 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -949,6 +949,32 @@ def test_forward_cond():
verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
+def test_forward_unravel_index():
+    def verify(x, shape, dtype):
+        # Compare the relay graph converted from MXNet against MXNet's own result.
+        a_np = np.array(x).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, 'unravel_index', [mx.sym.var('a'), shape])
+        ref_res = _mx_symbol(mx.nd, 'unravel_index', [mx.nd.array(a_np), shape])
+        shapes = {'a': a_np.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+    for dtype in ["int32", "int64"]:
+        verify([0, 1, 2, 3], [2, 2], dtype)
+        verify([144, 13, 45], [6, 7, 10, 2], dtype)
+        verify([456], [6, 7, 10, 2], dtype)
+
+        # In the example below, 5 is out of bounds for an array of size 4.
+        # MXNet produces a different result from TVM here; TVM's output is
+        # in line with TensorFlow. Ideally an error would be raised, as NumPy does.
+        # verify([0, 1, 2, 5], [2, 2], dtype)
+
+
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
@@ -1003,4 +1029,5 @@ if __name__ == '__main__':
test_forward_convolution()
test_forward_deconvolution()
test_forward_cond()
- test_forward_make_loss()
\ No newline at end of file
+ test_forward_make_loss()
+ test_forward_unravel_index()
\ No newline at end of file
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 2342606..3c51977 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3057,6 +3057,57 @@ def test_forward_add_n():
_test_forward_add_n(in5)
+#######################################################################
+# Unravel Index
+# ----------------------
+def _test_forward_unravel_index(inputs):
+    """Feed ``[indices, dims]`` arrays through placeholders and compare TF with TVM."""
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        temp = []
+        for each in inputs:
+            temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
+        output = tf.unravel_index(temp[0], temp[1])
+        compare_tf_with_tvm([each for each in inputs], [
+            each.name for each in temp], output.name)
+
+
+def _test_forward_unravel_index_scalar(x, y, dtype="int32"):
+    """Same comparison with indices and dims baked in as graph constants.
+
+    Used for the scalar-index case, so no placeholders are fed.
+    """
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        indices_1 = constant_op.constant(x, dtype=dtype)
+        dims_1 = constant_op.constant(y, dtype=dtype)
+        out_1 = array_ops.unravel_index(indices_1, dims_1)
+        compare_tf_with_tvm([], [], out_1.name)
+
+
+def test_forward_unravel_index():
+    """TF UnravelIndex conversion: in-range, out-of-range and scalar indices."""
+    x = np.array([0, 1, 2, 3])
+    y = np.array([2, 2])
+    _test_forward_unravel_index([x, y])
+
+    # 5 >= prod(y) = 4: out-of-range index, compared against TF's output.
+    x = np.array([0, 1, 2, 5])
+    y = np.array([2, 2])
+    _test_forward_unravel_index([x, y])
+
+    # Single-dimension dims; 2 and 5 exceed prod(y) = 2.
+    x = np.array([0, 1, 2, 5])
+    y = np.array([2])
+    _test_forward_unravel_index([x, y])
+
+    x = np.array([102, 300, 16])
+    y = np.array([10, 10, 9, 6])
+    _test_forward_unravel_index([x, y])
+
+    x = np.array([100])
+    y = np.array([10, 10, 9, 6])
+    _test_forward_unravel_index([x, y])
+
+    # Test scalar input
+    _test_forward_unravel_index_scalar(13, [1, 4, 5, 2])
+
+
+#######################################################################
+# Dilation2d
+# ----------------------
def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
strides, dilations, padding):
""" One iteration of dilation2d with given shapes and attributes """
@@ -3173,6 +3224,7 @@ if __name__ == '__main__':
test_forward_squared_difference()
test_forward_add_n()
test_forward_floormod()
+ test_forward_unravel_index()
# Reductions
test_forward_argminmax()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 7e5314d..fffb1de 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -683,6 +683,44 @@ def test_gather_nd():
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
+
+def test_unravel_index():
+    def verify_unravel_index(indices, shape, dtype):
+        x_data = np.array(indices).astype(dtype)
+        y_data = np.array(shape).astype(dtype)
+        x = relay.var("x", relay.TensorType(x_data.shape, dtype))
+        y = relay.var("y", relay.TensorType(y_data.shape, dtype))
+
+        z = relay.unravel_index(x, y)
+        zz = run_infer_type(z)
+
+        # Expected type: (len(shape), len(indices)) for 1-D indices,
+        # (len(shape),) for a scalar index.
+        if len(x_data.shape) == 1:
+            out_shape = [y_data.shape[0], x_data.shape[0]]
+        else:
+            out_shape = [y_data.shape[0]]
+        assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)
+
+        func = relay.Function([x, y], z)
+        ref_res = np.unravel_index(x_data, y_data)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data, y_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    for dtype in ["int64", "int32"]:
+        verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
+        verify_unravel_index([144], [5, 5, 5, 2], dtype)
+        verify_unravel_index(144, [5, 5, 5, 2], dtype)
+        verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
+
+        # In the example below, 5 is out of bounds for an array of size 4.
+        # NumPy raises an error for it; TVM does not, and instead produces
+        # output in line with TensorFlow.
+        # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)
+
+
if __name__ == "__main__":
test_arange()
test_cast()
@@ -713,3 +751,4 @@ if __name__ == "__main__":
test_tile()
test_repeat()
test_gather_nd()
+ test_unravel_index()
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index efbffad..40bdcc6 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -233,6 +233,54 @@ inline Tensor reshape(const Tensor& x,
}
/*!
+ * \brief Unravel flat indices into per-dimension coordinates.
+ *
+ * Row i of the result holds the coordinate along dimension i. Coordinates are
+ * produced by repeated mod/div from the innermost dimension outward.
+ *
+ * \param x Scalar or 1-D tensor of flat integer indices.
+ * \param shape 1-D tensor holding the target shape; its length must be a
+ *        compile-time constant.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return Tensor of shape (len(shape),) for a scalar x, otherwise
+ *         (len(shape), len(x)).
+ */
+inline Tensor unravel_index(const Tensor& x,
+                            const Tensor& shape,
+                            std::string name = "T_unravel",
+                            std::string tag = kInjective) {
+  const auto& x_shape = x->shape;
+  const auto& shape_shape = shape->shape;
+  const bool is_scalar_index = x_shape.size() == 0;
+
+  Array<PrimExpr> oshape;
+  oshape.push_back(shape_shape[0]);
+  if (!is_scalar_index) {
+    oshape.push_back(x_shape[0]);
+  }
+
+  const int ndim = GetConstInt(shape_shape[0]);
+
+  auto fcompute = [&](const Array<Var>& out_idx) {
+    const Var& dim = out_idx[0];
+
+    // Running quotient: before step v it equals index / prod(shape[v+1:]).
+    PrimExpr quotient;
+    if (is_scalar_index) {
+      quotient = x();
+    } else {
+      quotient = x[out_idx[1]];
+    }
+
+    // Innermost dimension first; each step selects its coordinate when dim == v.
+    PrimExpr result = 0;
+    for (int v = ndim - 1; v >= 0; --v) {
+      result = tvm::if_then_else(dim == v, indexmod(quotient, shape[v]), result);
+      quotient = indexdiv(quotient, shape[v]);
+    }
+    return result;
+  };
+
+  return compute(oshape, fcompute, name, tag);
+}
+
+/*!
* \brief Remove size 1 dimensions from the shape of a tensor.
* The removed dimensions must have a constant size of 1.
*
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index 036191b..ef54560 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name,consider-using-enumerate
+# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
"""Injective transformation operators"""
from __future__ import absolute_import as _abs
import tvm
@@ -653,3 +653,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 0, 1]]
"""
return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype)
+
+
+def unravel_index(indices, shape):
+    """Convert a flat index or array of flat indices into coordinate arrays.
+
+    Example::
+        - unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]]
+
+    Parameters
+    ----------
+    indices : tvm.te.Tensor
+        An integer scalar, or a 1-D integer tensor, of flat indices.
+
+    shape : tvm.te.Tensor
+        A 1-D integer tensor giving the shape of the array to unravel into.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        A single tensor whose row ``i`` holds the coordinates along dimension
+        ``i``: shape ``(len(shape),)`` for a scalar index, otherwise
+        ``(len(shape), len(indices))``.
+    """
+
+    return cpp.unravel_index(indices, shape)
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 5581f2b..3a3175c 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -435,6 +435,11 @@ TVM_REGISTER_GLOBAL("topi.gather_nd")
*rv = gather_nd(args[0], args[1]);
});
+TVM_REGISTER_GLOBAL("topi.unravel_index")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  // Packed arguments: (indices tensor, shape tensor).
+  tvm::te::Tensor indices = args[0];
+  tvm::te::Tensor shape = args[1];
+  *rv = unravel_index(indices, shape);
+});
+
TVM_REGISTER_GLOBAL("topi.matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
switch ( args.size() ) {
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 097c87d..b98ce09 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -562,6 +562,40 @@ def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
for device in get_all_backend():
check_device(device)
+
+def verify_unravel_index(indices, shape, dtype):
+    """Build topi.unravel_index for the given inputs and compare with np.unravel_index."""
+    x_data = np.array(indices).astype(dtype)
+    y_data = np.array(shape).astype(dtype)
+    # A scalar index yields (len(shape),); a 1-D index array adds a trailing axis.
+    if len(x_data.shape) == 1:
+        dst_shape = [y_data.shape[0], x_data.shape[0]]
+    else:
+        dst_shape = [y_data.shape[0]]
+
+    X = te.placeholder(shape=x_data.shape, dtype=dtype, name="X")
+    Y = te.placeholder(shape=y_data.shape, dtype=dtype, name="Y")
+    Z = topi.unravel_index(X, Y)
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.testing.get_injective_schedule(device)(Z)
+        foo = tvm.build(s, [X, Y, Z], device, name="unravel_index")
+
+        out_npy = np.unravel_index(x_data, y_data)
+        datax_nd = tvm.nd.array(x_data, ctx)
+        datay_nd = tvm.nd.array(y_data, ctx)
+        out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=Z.dtype)
+        foo(datax_nd, datay_nd, out_nd)
+        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+    for device in get_all_backend():
+        check_device(device)
+
+
def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
@@ -882,6 +916,15 @@ def test_one_hot():
verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
+def test_unravel_index():
+    # Covers a 1-D index array, a one-element array, a scalar index and mixed values.
+    for dtype in ["int32", "int64"]:
+        verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
+        verify_unravel_index([144], [5, 5, 5, 2], dtype)
+        verify_unravel_index(144, [5, 5, 5, 2], dtype)
+        verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
+
+
if __name__ == "__main__":
test_strided_slice()
test_concatenate()
@@ -905,3 +948,4 @@ if __name__ == "__main__":
test_ndarray_size()
test_where_fusion()
test_one_hot()
+ test_unravel_index()