Posted to commits@tvm.apache.org by ke...@apache.org on 2020/09/11 06:10:05 UTC
[incubator-tvm] branch master updated: [Relay][Topi][Op]Advanced indexing (#6388)
This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 1228111 [Relay][Topi][Op]Advanced indexing (#6388)
1228111 is described below
commit 122811137a6fe3e787d5174a36d99eebefb479d0
Author: Yao Wang <ke...@gmail.com>
AuthorDate: Thu Sep 10 23:09:45 2020 -0700
[Relay][Topi][Op]Advanced indexing (#6388)
* Add Relay adv_index op
* Support single index tensor dynamic shape
* Support more dynamic index
* Fix lint
* Minor fix for comment
* Fix lint
* Fix lint
* Fix test
* Fix
---
include/tvm/topi/transform.h | 80 ++++++++++++++++++++++++
python/tvm/relay/frontend/pytorch.py | 40 +-----------
python/tvm/relay/op/_transform.py | 33 ++++++++++
python/tvm/relay/op/transform.py | 18 ++++++
python/tvm/topi/transform.py | 18 ++++++
src/relay/op/tensor/transform.cc | 83 +++++++++++++++++++++++++
src/topi/transform.cc | 4 ++
tests/python/relay/test_any.py | 15 +++++
tests/python/relay/test_op_level3.py | 26 ++++++++
tests/python/topi/python/test_topi_transform.py | 42 +++++++++++++
10 files changed, 321 insertions(+), 38 deletions(-)
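For reference, the NumPy behavior this op mirrors — a minimal sketch, not part of the commit, with shapes borrowed from the new tests below:

    import numpy as np

    data = np.arange(50, dtype="float32").reshape(10, 5)
    idx0 = np.array([[0, 2, 4, 6], [1, 3, 5, 7], [0, 0, 0, 0]])  # shape (3, 4)
    idx1 = np.array([[0], [1], [2]])                             # shape (3, 1)

    # idx1 broadcasts against idx0 to the common shape (3, 4); the result
    # takes that broadcast shape plus any trailing dims of data left unindexed.
    out = data[idx0, idx1]
    print(out.shape)  # (3, 4)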
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index af59928..2c0d102 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -26,6 +26,7 @@
#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
+#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/detail/tensor_utils.h>
@@ -1551,6 +1552,85 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal,
name, tag);
}
+/*!
+ * \brief Numpy style advanced indexing with tensors.
+ * \param data The input data tensor.
+ * \param indices The list of index tensors.
+ * \param name The name of the output tensor.
+ * \param tag The tag of the output tensor.
+ * \return Output tensor.
+ */
+inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
+ const std::string name = "advanced_index",
+ const std::string tag = kInjective) {
+ Array<PrimExpr> oshape;
+ Array<PrimExpr> broadcast_shape;
+ Array<Tensor> bindices;
+ std::vector<int64_t> flatten_shape_lens;
+ int64_t num_picked_elems = 1;
+ bool has_dyn_shape = false;
+
+ if (indices.size() == 1) {
+ broadcast_shape = indices[0]->shape;
+ bindices = indices;
+ } else {
+ for (const auto& index : indices) {
+ int64_t flatten_len = 1;
+ for (const auto& dim : index->shape) {
+ const IntImmNode* axis_len = dim.as<IntImmNode>();
+ if (!axis_len) {
+ broadcast_shape = index->shape;
+ has_dyn_shape = true;
+ break;
+ }
+ flatten_len *= axis_len->value;
+ }
+ if (has_dyn_shape) break;
+ flatten_shape_lens.push_back(flatten_len);
+ if (flatten_len > num_picked_elems) {
+ num_picked_elems = flatten_len;
+ broadcast_shape = index->shape;
+ }
+ }
+
+ // Broadcast the smaller index tensors to the common broadcast shape
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
+ bindices.push_back(broadcast_to(indices[i], broadcast_shape));
+ } else {
+ bindices.push_back(indices[i]);
+ }
+ }
+ }
+
+ for (const auto& dim : broadcast_shape) {
+ oshape.push_back(dim);
+ }
+ for (size_t i = indices.size(); i < data->shape.size(); ++i) {
+ oshape.push_back(data->shape[i]);
+ }
+
+ return compute(
+ oshape,
+ [&](const Array<Var>& iter_var) {
+ Array<PrimExpr> tensor_indices;
+ for (size_t i = 0; i < broadcast_shape.size(); ++i) {
+ tensor_indices.push_back(iter_var[i]);
+ }
+
+ Array<PrimExpr> real_indices;
+ for (size_t i = 0; i < bindices.size(); ++i) {
+ real_indices.push_back(bindices[i](tensor_indices));
+ }
+ for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
+ real_indices.push_back(iter_var[i]);
+ }
+
+ return data(real_indices);
+ },
+ name, tag);
+}
+
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_TRANSFORM_H_
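The static-shape branch above picks the index tensor with the largest flattened element count as the broadcast target and broadcasts the smaller indices to it; any dynamic dimension short-circuits the scan. A rough Python paraphrase of that selection, for illustration only:

    def pick_broadcast_shape(index_shapes):
        # Mimics the static-shape branch of topi::adv_index above: the shape
        # with the most elements wins and the other indices broadcast to it.
        num_picked_elems = 1
        broadcast_shape = index_shapes[0]
        for shape in index_shapes:
            flatten_len = 1
            for dim in shape:
                flatten_len *= dim
            if flatten_len > num_picked_elems:
                num_picked_elems = flatten_len
                broadcast_shape = shape
        return broadcast_shape

    print(pick_broadcast_shape([(3, 1), (3, 4)]))  # (3, 4)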
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7203150..19cbf75 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1816,44 +1816,8 @@ def _one_hot():
def _index():
def _impl(inputs, input_types):
data = inputs[0]
- indices = []
- raw_indices = []
- max_indices_len = -1
- for index in inputs[1]:
- if not isinstance(index, _expr.Constant):
- try:
- index = _expr.const(_infer_value(index, {}))
- except Exception:
- raise RuntimeError("Only supports constant indices for "
- "pytorch advanced indexing ")
- raw_indices.append(index)
- cindex_len = index.data.shape[0]
- if cindex_len > max_indices_len:
- max_indices_len = cindex_len
-
- for index in raw_indices:
- cnp = index.data.asnumpy()
- cindex_len = cnp.shape[0]
- if cindex_len < max_indices_len:
- cnp = np.tile(cnp, max_indices_len // cindex_len)
- indices.append(cnp)
-
- ret = []
- slice_map = {}
- for i in range(indices[0].shape[0]):
- tmp = data
- current_indices = []
- for index in indices:
- current_indices.append(index[i])
- index_key = tuple(current_indices)
- if index_key in slice_map:
- tmp = slice_map[index_key]
- else:
- tmp = _op.take(tmp, _expr.const(index[i]), axis=0)
- slice_map[index_key] = tmp
- ret.append(_op.expand_dims(tmp, axis=0))
-
- return _op.concatenate(ret, axis=0)
+ indices = inputs[1]
+ return _op.adv_index([data] + indices)
return _impl
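The rewrite above drops the constant-only unrolling (one take/expand_dims per index element, then a concatenate) in favor of a single adv_index call, so non-constant index tensors now work as well. A hypothetical trace exercising this path — module and shapes are illustrative, not from the commit:

    import torch

    class Index(torch.nn.Module):
        def forward(self, data, index):
            return data[index]  # advanced indexing, traced as aten::index

    traced = torch.jit.trace(Index(), (torch.rand(10, 5), torch.tensor([1, 3])))
    # tvm.relay.frontend.from_pytorch(traced, [("data", (10, 5)), ("index", (2,))])
    # now emits adv_index instead of a chain of unrolled take ops.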
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 9d7c389..98ff0b3 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -62,6 +62,7 @@ _reg.register_reduce_schedule("collapse_sum_to")
_reg.register_injective_schedule("unravel_index")
_reg.register_injective_schedule("sparse_to_dense")
_reg.register_injective_schedule("matrix_set_diag")
+_reg.register_injective_schedule("adv_index")
# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
@@ -661,3 +662,35 @@ def split_shape_func(attrs, inputs, _):
convert(i),
convert(indices_or_sections),
convert(axis)) for i in range(num_out)]
+
+@script
+def _adv_index_shape_func(inputs):
+ index_rank = inputs[1].shape[0]
+ data_rank = inputs[0].shape[0]
+ out = output_tensor((data_rank + index_rank - len(inputs) + 1,), "int64")
+
+ max_flatten_len = int64(1)
+ for i in const_range(index_rank):
+ max_flatten_len *= inputs[1][i]
+ out[i] = inputs[1][i]
+ for i in const_range(len(inputs) - 2):
+ flatten_len = int64(1)
+ for j in const_range(index_rank):
+ flatten_len *= inputs[i + 2][j]
+ if flatten_len > max_flatten_len:
+ max_flatten_len = flatten_len
+ for k in const_range(index_rank):
+ out[k] = inputs[i + 2][k]
+
+ for i in const_range(data_rank - len(inputs) + 1):
+ out[i + index_rank] = inputs[0][i + len(inputs) - 1]
+
+ return out
+
+@_reg.register_shape_func("adv_index", False)
+def adv_index_shape_func(attrs, inputs, _):
+ """
+ Shape function for adv_index.
+ Only allows a single index tensor.
+ """
+ return [_adv_index_shape_func(inputs)]
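A plain-Python paraphrase of the shape function above, checked against shapes from the new tests (hybrid-script intrinsics replaced with ordinary loops):

    def adv_index_out_shape(data_shape, index_shapes):
        # Same arithmetic as _adv_index_shape_func: the index shape with the
        # most elements becomes the broadcast shape, followed by the data
        # dims not covered by an index tensor.
        def numel(shape):
            n = 1
            for dim in shape:
                n *= dim
            return n
        broadcast = max(index_shapes, key=numel)
        return tuple(broadcast) + tuple(data_shape[len(index_shapes):])

    print(adv_index_out_shape((10, 5, 15), [(1, 2, 1), (1, 2, 7)]))  # (1, 2, 7, 15)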
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 01466f7..0ce59ad 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1213,3 +1213,21 @@ def matrix_set_diag(data, diagonal):
[7, 7, 6, 7]]]
"""
return _make.matrix_set_diag(data, diagonal)
+
+
+def adv_index(inputs):
+ """
+ Numpy style advanced indexing. Index with a list of tensors.
+
+ Parameters
+ ----------
+ inputs : Union(List[relay.Expr], Tuple[relay.Expr])
+ Input tensor and indices.
+ The first tensor is the input data and the rest are indices.
+
+ Returns
+ -------
+ result: relay.Expr
+ Output tensor.
+ """
+ return _make.adv_index(Tuple(inputs))
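A usage sketch for the new Relay API, mirroring test_adv_index later in this commit:

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", relay.TensorType((10, 5), "float32"))
    index = relay.var("index", relay.TensorType((2,), "int64"))
    func = relay.Function([data, index], relay.adv_index([data, index]))

    np_data = np.random.uniform(size=(10, 5)).astype("float32")
    np_index = np.array([1, 3], dtype="int64")
    # Equivalent to np_data[np_index] in NumPy.
    res = relay.create_executor("graph").evaluate(func)(np_data, np_index)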
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index f3e5a6a..1681d87 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -838,3 +838,21 @@ def matrix_set_diag(data, diagonal):
[7, 7, 6, 7]]]
"""
return cpp.matrix_set_diag(data, diagonal)
+
+def adv_index(data, indices):
+ """Numpy style indexing with tensors.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ Input data.
+
+ indices : list of tvm.te.Tensor
+ The index tensors.
+
+ Returns
+ -------
+ result : tvm.te.Tensor
+ Output tensor.
+ """
+ return cpp.adv_index(data, indices)
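And the TOPI-level equivalent, with placeholder shapes borrowed from the TOPI tests below:

    from tvm import te, topi

    data = te.placeholder((10, 5), name="data", dtype="float32")
    index = te.placeholder((2,), name="index", dtype="int64")
    out = topi.adv_index(data, [index])
    print(out.shape)  # [2, 5]: one index over axis 0, trailing data dim kept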
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 88179b7..e3d0950 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3163,5 +3163,88 @@ RELAY_REGISTER_OP("matrix_set_diag")
.set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+// adv_index
+bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(num_inputs, 1);
+ auto inputs = types[0].as<TupleTypeNode>();
+ auto data = inputs->fields[0].as<TensorTypeNode>();
+
+ if (inputs == nullptr || data == nullptr) {
+ return false;
+ }
+
+ Array<IndexExpr> oshape;
+ Array<IndexExpr> broadcast_shape;
+ int64_t num_picked_elems = 1;
+
+ if (inputs->fields.size() == 2) {
+ broadcast_shape = inputs->fields[1].as<TensorTypeNode>()->shape;
+ } else {
+ for (size_t i = 1; i < inputs->fields.size(); ++i) {
+ auto index_type = inputs->fields[i].as<TensorTypeNode>();
+ if (index_type == nullptr) {
+ return false;
+ }
+ CHECK(index_type->dtype.is_int()) << "indices must be tensor of integers";
+
+ int64_t flatten_len = 1;
+ bool has_dyn_shape = false;
+ for (const auto& dim : index_type->shape) {
+ const IntImmNode* axis_len = dim.as<IntImmNode>();
+ if (!axis_len) {
+ // If a dynamic shape appears, just use the first such index shape
+ broadcast_shape = index_type->shape;
+ has_dyn_shape = true;
+ break;
+ }
+ flatten_len *= axis_len->value;
+ }
+ if (has_dyn_shape) break;
+ if (flatten_len > num_picked_elems) {
+ num_picked_elems = flatten_len;
+ broadcast_shape = index_type->shape;
+ }
+ }
+ }
+
+ for (const auto& dim : broadcast_shape) {
+ oshape.push_back(dim);
+ }
+ for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
+ oshape.push_back(data->shape[i]);
+ }
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
+ return true;
+}
+
+Array<te::Tensor> AdvIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ Array<te::Tensor> indices;
+ for (size_t i = 1; i < inputs.size(); ++i) {
+ indices.push_back(inputs[i]);
+ }
+ return {topi::adv_index(inputs[0], indices)};
+}
+
+Expr MakeAdvIndex(Expr inputs) {
+ static const Op& op = Op::Get("adv_index");
+ return Call(op, {inputs}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);
+
+RELAY_REGISTER_OP("adv_index")
+ .describe(R"code(Numpy style advanced indexing. Index with a list of tensors.
+ )code" TVM_ADD_FILELINE)
+ .set_num_inputs(1)
+ .set_support_level(3)
+ .add_argument("inputs", "Tuple of Tensors", "Input tensor and indices.")
+ .add_type_rel("AdvIndex", AdvIndexRel)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<TOpPattern>("TOpPattern", kInjective)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);
+
} // namespace relay
} // namespace tvm
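A quick way to watch the AdvIndex type relation fire (illustrative; shapes taken from test_adv_index):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(10, 5, 15), dtype="float32")
    i0 = relay.var("i0", shape=(1, 2, 1), dtype="int64")
    i1 = relay.var("i1", shape=(1, 2, 7), dtype="int64")
    func = relay.Function([data, i0, i1], relay.adv_index([data, i0, i1]))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    print(mod["main"].ret_type)  # Tensor[(1, 2, 7, 15), float32]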
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 154933f..bf7e1e6 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -180,5 +180,9 @@ TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValu
*rv = matrix_set_diag(args[0], args[1]);
});
+TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = adv_index(args[0], args[1]);
+});
+
} // namespace topi
} // namespace tvm
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 6bb34d3..3a46fdd 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -892,5 +892,20 @@ def test_reshape_concat():
np.reshape(np_data1, np_shape_like1.shape)], axis=0)
check_result([np_data0, np_data1, np_shape_like0, np_shape_like1], mod, ref_res)
+def test_any_adv_index():
+ data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype='float32')
+ index0 = relay.var("index0", shape=(1, relay.Any()), dtype='int64')
+ index1 = relay.var("index1", shape=(1, relay.Any()), dtype='int64')
+ out = relay.adv_index([data, index0, index1])
+ mod = tvm.IRModule()
+ mod['main'] = relay.Function([data, index0, index1], out)
+ np_data_shape = (5, 5, 10)
+ np_index_shape = (1, 4)
+ np_data = np.random.uniform(size=np_data_shape).astype('float32')
+ np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype('int64')
+ ref_res = np_data[tuple([np_index, np_index])]
+ check_result([np_data, np_index, np_index], mod, ref_res)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f709aa2..98ef38d 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1091,6 +1091,31 @@ def test_sparse_to_dense():
#sparse_indices should not be > 2d tensor
#verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+def test_adv_index():
+ def verify_adv_index(data_shape, index_shapes):
+ dtype = "float32"
+ inputs = [relay.var("data", relay.TensorType(data_shape, dtype))]
+ np_data = np.random.uniform(size=data_shape).astype(dtype)
+ np_indices = []
+ for i, index_shape in enumerate(index_shapes):
+ limit = data_shape[i]
+ np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
+ inputs.append(relay.var("index_{}".format(i), relay.TensorType(index_shape, "int64")))
+ np_out = np_data[tuple(np_indices)]
+ np_args = [np_data] + np_indices
+ out = relay.op.adv_index(inputs)
+
+ func = relay.Function(inputs, out)
+ for target, ctx in tvm.testing.enabled_targets():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(func)(*np_args)
+ tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-5)
+
+ verify_adv_index((10, 5), [(3, 4), (3, 1)])
+ verify_adv_index((10, 5), [(2,),])
+ verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+
if __name__ == "__main__":
test_cast()
test_zeros_ones()
@@ -1127,3 +1152,4 @@ if __name__ == "__main__":
test_unravel_index()
test_sparse_to_dense()
test_fixed_point_multiply()
+ test_adv_index()
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index a061ba9..fc6f19f 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -678,6 +678,7 @@ def verify_matrix_set_diag(input_shape, dtype):
input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal)
+
def check_device(device, ctx):
ctx = tvm.context(device, 0)
print("Running on target: %s" % device)
@@ -697,6 +698,40 @@ def verify_matrix_set_diag(input_shape, dtype):
for target, ctx in tvm.testing.enabled_targets():
check_device(target, ctx)
+def verify_adv_index(data_shape, index_shapes):
+ dtype = "float32"
+ data = te.placeholder(shape=data_shape, name="data", dtype=dtype)
+ indices = []
+ np_data = np.random.uniform(size=data_shape).astype(dtype)
+ np_indices = []
+ for i, index_shape in enumerate(index_shapes):
+ limit = data_shape[i]
+ np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
+ indices.append(te.placeholder(shape=index_shape, name="index_{}".format(i), dtype="int64"))
+ np_out = np_data[tuple(np_indices)]
+ out = topi.adv_index(data, indices)
+
+ def check_device(device, ctx):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = tvm.topi.testing.get_injective_schedule(device)(out)
+
+ func = tvm.build(s, [data] + indices + [out], device, name="adv_index")
+
+ nd_list = [tvm.nd.array(np_data, ctx)]
+ for np_index in np_indices:
+ nd_list.append(tvm.nd.array(np_index, ctx))
+ nd_list.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=data.dtype))
+
+ func(*nd_list)
+ tvm.testing.assert_allclose(nd_list[-1].asnumpy(), np.array(np_out))
+
+ for target, ctx in tvm.testing.enabled_targets():
+ check_device(target, ctx)
@tvm.testing.uses_gpu
def test_strided_slice():
@@ -1071,6 +1106,12 @@ def test_matrix_set_diag():
verify_matrix_set_diag((4, 3, 3), dtype)
verify_matrix_set_diag((2, 3, 4), dtype)
+@tvm.testing.uses_gpu
+def test_adv_index():
+ verify_adv_index((3, 4, 5), [(2,), (2, ), (1,)])
+ verify_adv_index((10, 15, 5), [(1, 1), (2, 7)])
+ verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+
if __name__ == "__main__":
test_strided_slice()
test_concatenate()
@@ -1097,3 +1138,4 @@ if __name__ == "__main__":
test_unravel_index()
test_sparse_to_dense()
test_matrix_set_diag()
+ test_adv_index()