Posted to commits@tvm.apache.org by co...@apache.org on 2022/01/31 17:54:25 UTC
[tvm] branch main updated: [Bugfix][Op] Fix shape inference of adv_index (#9717)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dad8f62 [Bugfix][Op] Fix shape inference of adv_index (#9717)
dad8f62 is described below
commit dad8f62fc10282691227f303be3a7bd306e511c8
Author: Huang, Guangtai <gu...@amazon.com>
AuthorDate: Tue Feb 1 01:53:45 2022 +0800
[Bugfix][Op] Fix shape inference of adv_index (#9717)
* init
* test
* lint
---
include/tvm/topi/transform.h | 36 ++++-----------
python/tvm/relay/op/_transform.py | 58 ++++++++++---------------
src/relay/op/tensor/transform.cc | 39 +++--------------
tests/python/relay/test_any.py | 13 +++---
tests/python/relay/test_op_level3.py | 1 +
tests/python/topi/python/test_topi_transform.py | 2 +-
6 files changed, 48 insertions(+), 101 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 59e6d41..acff301 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1902,43 +1902,23 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
const std::string name = "advanced_index",
const std::string tag = kInjective) {
+ ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
Array<PrimExpr> oshape;
Array<PrimExpr> broadcast_shape;
Array<Tensor> bindices;
- std::vector<int64_t> flatten_shape_lens;
- int64_t num_picked_elems = 1;
- bool has_dyn_shape = false;
+ broadcast_shape = indices[0]->shape;
+ for (size_t i = 1; i < indices.size(); ++i) {
+ auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
+ broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
+ }
if (indices.size() == 1) {
- broadcast_shape = indices[0]->shape;
+ // quick path
bindices = indices;
} else {
- for (const auto& index : indices) {
- int64_t flatten_len = 1;
- for (const auto& dim : index->shape) {
- const IntImmNode* axis_len = dim.as<IntImmNode>();
- if (!axis_len) {
- broadcast_shape = index->shape;
- has_dyn_shape = true;
- break;
- }
- flatten_len *= axis_len->value;
- }
- if (has_dyn_shape) break;
- flatten_shape_lens.push_back(flatten_len);
- if (flatten_len > num_picked_elems) {
- num_picked_elems = flatten_len;
- broadcast_shape = index->shape;
- }
- }
-
// Do broadcast for indices
for (size_t i = 0; i < indices.size(); ++i) {
- if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
- bindices.push_back(broadcast_to(indices[i], broadcast_shape));
- } else {
- bindices.push_back(indices[i]);
- }
+ bindices.push_back(broadcast_to(indices[i], broadcast_shape));
}
}
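The new TOPI lowering above replaces the old "pick the index with the largest flattened length" heuristic with a genuine pairwise broadcast over all index shapes, which matches NumPy advanced-indexing semantics. A minimal sketch of those semantics in plain NumPy (illustrative only, not TVM source):

    import numpy as np

    # Indices of shape (1, 4) and (3, 1) broadcast to a common (3, 4);
    # both data dimensions are consumed, so the result is also (3, 4).
    data = np.arange(50, dtype="float32").reshape(10, 5)
    index0 = np.array([[0, 1, 2, 3]])   # shape (1, 4)
    index1 = np.array([[0], [1], [2]])  # shape (3, 1)

    # Mirrors what broadcast_to does per index in the new code path.
    b0, b1 = np.broadcast_arrays(index0, index1)
    out = data[b0, b1]
    assert out.shape == (3, 4)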
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index cc71ea1..b67579a 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -976,40 +976,6 @@ def split_shape_func(attrs, inputs, _):
@script
-def _adv_index_shape_func(inputs):
- index_rank = inputs[1].shape[0]
- data_rank = inputs[0].shape[0]
- out = output_tensor((data_rank + index_rank - len(inputs) + 1,), "int64")
-
- max_flatten_len = int64(1)
- for i in const_range(index_rank):
- max_flatten_len *= inputs[1][i]
- out[i] = inputs[1][i]
- for i in const_range(len(inputs) - 2):
- flatten_len = int64(1)
- for j in const_range(index_rank):
- flatten_len *= inputs[i + 2][j]
- if flatten_len > max_flatten_len:
- max_flatten_len = flatten_len
- for k in const_range(index_rank):
- out[k] = inputs[i + 2][k]
-
- for i in const_range(data_rank - len(inputs) + 1):
- out[i + index_rank] = inputs[0][i + len(inputs) - 1]
-
- return out
-
-
-@_reg.register_shape_func("adv_index", False)
-def adv_index_shape_func(attrs, inputs, _):
- """
- Shape func for adv_index.
- Only allow single index tensor.
- """
- return [_adv_index_shape_func(inputs)]
-
-
-@script
def _repeat_shape_func(data_shape, repeats, axis):
out = output_tensor((data_shape.shape[0],), "int64")
@@ -1116,6 +1082,30 @@ def where_shape_func(attrs, inputs, _):
@script
+def _adv_index_post_process(data_shape, bcast_shape, num_indices):
+ data_rank = data_shape.shape[0]
+ bcast_rank = bcast_shape.shape[0]
+ out = output_tensor((data_rank + bcast_rank - num_indices,), "int64")
+
+ for i in const_range(bcast_rank):
+ out[i] = bcast_shape[i]
+ for i in const_range(data_rank - num_indices):
+ out[i + bcast_rank] = data_shape[i + num_indices]
+ return out
+
+
+@_reg.register_shape_func("adv_index", False)
+def adv_index_shape_func(attrs, inputs, _):
+ """
+ Shape func for adv_index.
+ """
+ bcast_shape = inputs[1]
+ for i in inputs[2:]:
+ bcast_shape = _broadcast_shape_tensors(bcast_shape, i)
+ return [_adv_index_post_process(inputs[0], bcast_shape, convert(len(inputs) - 1))]
+
+
+@script
def _unique_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
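The replacement shape func mirrors that logic at the shape level: it folds the index shapes pairwise with _broadcast_shape_tensors, then appends the data dimensions not consumed by the indices. A plain-Python sketch of the same computation (adv_index_out_shape is a hypothetical helper for illustration; the real code runs as a hybrid-script tensor program):

    import numpy as np

    def adv_index_out_shape(data_shape, index_shapes):
        # Fold all index shapes into one broadcast shape, then append
        # the trailing data dims that no index consumes.
        bcast = np.broadcast_shapes(*index_shapes)
        return tuple(bcast) + tuple(data_shape[len(index_shapes):])

    # Matches the dynamic-shape test below: (5, 5, 10) indexed by
    # (1, 4) and (4, 1) yields (4, 4, 10).
    assert adv_index_out_shape((5, 5, 10), [(1, 4), (4, 1)]) == (4, 4, 10)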
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 645c7dc..19f6cdf 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3886,43 +3886,16 @@ bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (inputs == nullptr || data == nullptr) {
return false;
}
+ ICHECK_LE(inputs->fields.size() - 1, data->shape.size()) << "too many indices for data!";
Array<IndexExpr> oshape;
- Array<IndexExpr> broadcast_shape;
- int64_t num_picked_elems = 1;
-
- if (inputs->fields.size() == 2) {
- broadcast_shape = inputs->fields[1].as<TensorTypeNode>()->shape;
- } else {
- for (size_t i = 1; i < inputs->fields.size(); ++i) {
- auto index_type = inputs->fields[i].as<TensorTypeNode>();
- if (index_type == nullptr) {
- return false;
- }
- ICHECK(index_type->dtype.is_int() || index_type->dtype.is_uint())
- << "indices must be tensor of integers";
-
- int64_t flatten_len = 1;
- bool has_dyn_shape = false;
- for (const auto& dim : index_type->shape) {
- const IntImmNode* axis_len = dim.as<IntImmNode>();
- if (!axis_len) {
- // If dynamic shape appears, just use the first shape
- broadcast_shape = index_type->shape;
- has_dyn_shape = true;
- break;
- }
- flatten_len *= axis_len->value;
- }
- if (has_dyn_shape) break;
- if (flatten_len > num_picked_elems) {
- num_picked_elems = flatten_len;
- broadcast_shape = index_type->shape;
- }
- }
+ TensorType broadcast_type = Downcast<TensorType>(inputs->fields[1]);
+ for (size_t i = 2; i < inputs->fields.size(); ++i) {
+ broadcast_type =
+ ConcreteBroadcast(broadcast_type, Downcast<TensorType>(inputs->fields[i]), data->dtype);
}
- for (const auto& dim : broadcast_shape) {
+ for (const auto& dim : broadcast_type->shape) {
oshape.push_back(dim);
}
for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
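The Relay type relation now delegates to ConcreteBroadcast in the same pairwise fashion, so type inference agrees with the runtime shape func. A quick way to observe the inferred type from Python (a sketch; the shapes mirror the level-3 test case added below):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(10, 5), dtype="float32")
    i0 = relay.var("i0", shape=(1, 4), dtype="int64")
    i1 = relay.var("i1", shape=(3, 1), dtype="int64")
    func = relay.Function([data, i0, i1], relay.adv_index([data, i0, i1]))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    # (1, 4) and (3, 1) broadcast to (3, 4); both data dims are indexed,
    # so this prints Tensor[(3, 4), float32].
    print(mod["main"].ret_type)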
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index da36bba..97770f5 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1711,16 +1711,19 @@ def test_reshape_concat():
def test_any_adv_index():
data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype="float32")
index0 = relay.var("index0", shape=(1, relay.Any()), dtype="int64")
- index1 = relay.var("index1", shape=(1, relay.Any()), dtype="int64")
+ index1 = relay.var("index1", shape=(relay.Any(), 1), dtype="int64")
out = relay.adv_index([data, index0, index1])
mod = tvm.IRModule()
mod["main"] = relay.Function([data, index0, index1], out)
np_data_shape = (5, 5, 10)
- np_index_shape = (1, 4)
+ np_index0_shape = (1, 4)
+ np_index1_shape = (4, 1)
np_data = np.random.uniform(size=np_data_shape).astype("float32")
- np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype("int64")
- ref_res = np_data[tuple([np_index, np_index])]
- check_result([np_data, np_index, np_index], mod, ref_res)
+ np_index0 = np.random.uniform(0, np_data_shape[0], size=np_index0_shape).astype("int64")
+ np_index1 = np.random.uniform(0, np_data_shape[0], size=np_index1_shape).astype("int64")
+ ref_res = np_data[tuple([np_index0, np_index1])]
+ print(ref_res.shape)
+ check_result([np_data, np_index0, np_index1], mod, ref_res)
def verify_any_repeat(data_shape, np_dshape, repeats, axis):
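The updated dynamic-shape test is the interesting regression: index shapes (1, 4) and (4, 1) have the same flattened length, so the old heuristic could not pick a correct winner (it would keep one of the two shapes rather than their broadcast). A NumPy check of the reference result the test compares against (illustrative):

    import numpy as np

    data = np.random.uniform(size=(5, 5, 10)).astype("float32")
    i0 = np.random.randint(0, 5, size=(1, 4)).astype("int64")
    i1 = np.random.randint(0, 5, size=(4, 1)).astype("int64")
    # (1, 4) and (4, 1) broadcast to (4, 4); the unindexed last data
    # dim (10) is appended, giving (4, 4, 10).
    assert data[i0, i1].shape == (4, 4, 10)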
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index a6eeaa6..34f3324 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1825,6 +1825,7 @@ def test_adv_index(target, dev, executor_kind):
tvm.testing.assert_allclose(op_res.numpy(), np_out, rtol=1e-5)
verify_adv_index((10, 5), [(3, 4), (3, 1)])
+ verify_adv_index((10, 5), [(1, 4), (3, 1)])
verify_adv_index(
(10, 5),
[
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index d500b66..ddec14b 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -1259,7 +1259,7 @@ def test_matrix_set_diag():
def test_adv_index():
for indice_dtype in ["int32", "int64", "uint8", "uint16", "uint32"]:
verify_adv_index((3, 4, 5), [(2,), (2,), (1,)], indice_dtype=indice_dtype)
- verify_adv_index((10, 15, 5), [(1, 1), (2, 7)], indice_dtype=indice_dtype)
+ verify_adv_index((10, 15, 5), [(4, 1), (1, 7)], indice_dtype=indice_dtype)
verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)], indice_dtype=indice_dtype)
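The TOPI test swap makes the same point: the old pair (1, 1) and (2, 7) happened to work because (2, 7) dominates by flattened length, whereas (4, 1) and (1, 7) require genuine two-sided broadcasting to reach (4, 7). A NumPy check of the expected shape (illustrative):

    import numpy as np

    data = np.zeros((10, 15, 5), dtype="float32")
    i0 = np.random.randint(0, 10, size=(4, 1))
    i1 = np.random.randint(0, 15, size=(1, 7))
    # (4, 1) and (1, 7) broadcast to (4, 7); the trailing data dim 5
    # is kept, giving (4, 7, 5).
    assert data[i0, i1].shape == (4, 7, 5)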