Posted to commits@tvm.apache.org by co...@apache.org on 2022/01/31 17:54:25 UTC

[tvm] branch main updated: [Bugfix][Op] Fix shape inference of adv_index (#9717)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dad8f62  [Bugfix][Op] Fix shape inference of adv_index (#9717)
dad8f62 is described below

commit dad8f62fc10282691227f303be3a7bd306e511c8
Author: Huang, Guangtai <gu...@amazon.com>
AuthorDate: Tue Feb 1 01:53:45 2022 +0800

    [Bugfix][Op] Fix shape inference of adv_index (#9717)
    
    * init
    
    * test
    
    * lint
---
 include/tvm/topi/transform.h                    | 36 ++++-----------
 python/tvm/relay/op/_transform.py               | 58 ++++++++++---------------
 src/relay/op/tensor/transform.cc                | 39 +++--------------
 tests/python/relay/test_any.py                  | 13 +++---
 tests/python/relay/test_op_level3.py            |  1 +
 tests/python/topi/python/test_topi_transform.py |  2 +-
 6 files changed, 48 insertions(+), 101 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 59e6d41..acff301 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1902,43 +1902,23 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
 inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
                         const std::string name = "advanced_index",
                         const std::string tag = kInjective) {
+  ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
   Array<PrimExpr> oshape;
   Array<PrimExpr> broadcast_shape;
   Array<Tensor> bindices;
-  std::vector<int64_t> flatten_shape_lens;
-  int64_t num_picked_elems = 1;
-  bool has_dyn_shape = false;
 
+  broadcast_shape = indices[0]->shape;
+  for (size_t i = 1; i < indices.size(); ++i) {
+    auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
+    broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
+  }
   if (indices.size() == 1) {
-    broadcast_shape = indices[0]->shape;
+    // quick path
     bindices = indices;
   } else {
-    for (const auto& index : indices) {
-      int64_t flatten_len = 1;
-      for (const auto& dim : index->shape) {
-        const IntImmNode* axis_len = dim.as<IntImmNode>();
-        if (!axis_len) {
-          broadcast_shape = index->shape;
-          has_dyn_shape = true;
-          break;
-        }
-        flatten_len *= axis_len->value;
-      }
-      if (has_dyn_shape) break;
-      flatten_shape_lens.push_back(flatten_len);
-      if (flatten_len > num_picked_elems) {
-        num_picked_elems = flatten_len;
-        broadcast_shape = index->shape;
-      }
-    }
-
     // Do broadcast for indices
     for (size_t i = 0; i < indices.size(); ++i) {
-      if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
-        bindices.push_back(broadcast_to(indices[i], broadcast_shape));
-      } else {
-        bindices.push_back(indices[i]);
-      }
+      bindices.push_back(broadcast_to(indices[i], broadcast_shape));
     }
   }
 
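The hunk above replaces the old heuristic, which picked the index tensor with the largest flattened length and broadcast only the smaller indices to its shape, with a pairwise fold of detail::BroadcastShape over all index shapes; every index is then broadcast to the resulting common shape. A minimal NumPy sketch of the same fold (np.broadcast_shapes stands in for detail::BroadcastShape here; it is plain NumPy used for illustration, not TVM API):

import numpy as np

def fold_broadcast(shapes):
    # Fold a list of index shapes into their common broadcast shape,
    # exactly as the loop over detail::BroadcastShape above does.
    common = shapes[0]
    for s in shapes[1:]:
        common = np.broadcast_shapes(common, s)  # raises if incompatible
    return common

# Shapes like the updated tests below: neither input equals the result,
# which is the case the old "largest flattened length" pick got wrong.
print(fold_broadcast([(4, 1), (1, 7)]))  # (4, 7)
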
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index cc71ea1..b67579a 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -976,40 +976,6 @@ def split_shape_func(attrs, inputs, _):
 
 
 @script
-def _adv_index_shape_func(inputs):
-    index_rank = inputs[1].shape[0]
-    data_rank = inputs[0].shape[0]
-    out = output_tensor((data_rank + index_rank - len(inputs) + 1,), "int64")
-
-    max_flatten_len = int64(1)
-    for i in const_range(index_rank):
-        max_flatten_len *= inputs[1][i]
-        out[i] = inputs[1][i]
-    for i in const_range(len(inputs) - 2):
-        flatten_len = int64(1)
-        for j in const_range(index_rank):
-            flatten_len *= inputs[i + 2][j]
-        if flatten_len > max_flatten_len:
-            max_flatten_len = flatten_len
-            for k in const_range(index_rank):
-                out[k] = inputs[i + 2][k]
-
-    for i in const_range(data_rank - len(inputs) + 1):
-        out[i + index_rank] = inputs[0][i + len(inputs) - 1]
-
-    return out
-
-
-@_reg.register_shape_func("adv_index", False)
-def adv_index_shape_func(attrs, inputs, _):
-    """
-    Shape func for adv_index.
-    Only allow single index tensor.
-    """
-    return [_adv_index_shape_func(inputs)]
-
-
-@script
 def _repeat_shape_func(data_shape, repeats, axis):
     out = output_tensor((data_shape.shape[0],), "int64")
 
@@ -1116,6 +1082,30 @@ def where_shape_func(attrs, inputs, _):
 
 
 @script
+def _adv_index_post_process(data_shape, bcast_shape, num_indices):
+    data_rank = data_shape.shape[0]
+    bcast_rank = bcast_shape.shape[0]
+    out = output_tensor((data_rank + bcast_rank - num_indices,), "int64")
+
+    for i in const_range(bcast_rank):
+        out[i] = bcast_shape[i]
+    for i in const_range(data_rank - num_indices):
+        out[i + bcast_rank] = data_shape[i + num_indices]
+    return out
+
+
+@_reg.register_shape_func("adv_index", False)
+def adv_index_shape_func(attrs, inputs, _):
+    """
+    Shape func for adv_index.
+    """
+    bcast_shape = inputs[1]
+    for i in inputs[2:]:
+        bcast_shape = _broadcast_shape_tensors(bcast_shape, i)
+    return [_adv_index_post_process(inputs[0], bcast_shape, convert(len(inputs) - 1))]
+
+
+@script
 def _unique_shape(data_shape):
     unique_shape = output_tensor((1,), "int64")
     indices_shape = output_tensor((1,), "int64")
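The rewritten shape function mirrors the TOPI change: instead of the old hybrid-script heuristic that copied the shape of the index with the most elements (and assumed all indices had equal rank), it folds _broadcast_shape_tensors over the index shape tensors and lets _adv_index_post_process append the un-indexed trailing dimensions of the data, so the output rank is data_rank + bcast_rank - num_indices. A plain-Python sketch of the computed shape, with np.broadcast_shapes standing in for _broadcast_shape_tensors:

import numpy as np

def adv_index_out_shape(data_shape, index_shapes):
    # Broadcast all index shapes, then keep the data axes that were
    # not consumed by an index tensor.
    bcast = np.broadcast_shapes(*index_shapes)
    num_indices = len(index_shapes)
    return tuple(bcast) + tuple(data_shape[num_indices:])

# data (5, 5, 10) indexed on its first two axes, as in test_any_adv_index:
print(adv_index_out_shape((5, 5, 10), [(1, 4), (4, 1)]))  # (4, 4, 10)
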
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 645c7dc..19f6cdf 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3886,43 +3886,16 @@ bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (inputs == nullptr || data == nullptr) {
     return false;
   }
+  ICHECK_LE(inputs->fields.size() - 1, data->shape.size()) << "too many indices for data!";
 
   Array<IndexExpr> oshape;
-  Array<IndexExpr> broadcast_shape;
-  int64_t num_picked_elems = 1;
-
-  if (inputs->fields.size() == 2) {
-    broadcast_shape = inputs->fields[1].as<TensorTypeNode>()->shape;
-  } else {
-    for (size_t i = 1; i < inputs->fields.size(); ++i) {
-      auto index_type = inputs->fields[i].as<TensorTypeNode>();
-      if (index_type == nullptr) {
-        return false;
-      }
-      ICHECK(index_type->dtype.is_int() || index_type->dtype.is_uint())
-          << "indices must be tensor of integers";
-
-      int64_t flatten_len = 1;
-      bool has_dyn_shape = false;
-      for (const auto& dim : index_type->shape) {
-        const IntImmNode* axis_len = dim.as<IntImmNode>();
-        if (!axis_len) {
-          // If dynamic shape appears, just use the first shape
-          broadcast_shape = index_type->shape;
-          has_dyn_shape = true;
-          break;
-        }
-        flatten_len *= axis_len->value;
-      }
-      if (has_dyn_shape) break;
-      if (flatten_len > num_picked_elems) {
-        num_picked_elems = flatten_len;
-        broadcast_shape = index_type->shape;
-      }
-    }
+  TensorType broadcast_type = Downcast<TensorType>(inputs->fields[1]);
+  for (size_t i = 2; i < inputs->fields.size(); ++i) {
+    broadcast_type =
+        ConcreteBroadcast(broadcast_type, Downcast<TensorType>(inputs->fields[i]), data->dtype);
   }
 
-  for (const auto& dim : broadcast_shape) {
+  for (const auto& dim : broadcast_type->shape) {
     oshape.push_back(dim);
   }
   for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
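AdvIndexRel gets the same treatment at the type level: rather than guessing the broadcast shape from the index with the largest flattened length, it now folds ConcreteBroadcast over the index tensor types and then appends the remaining data dimensions. A minimal sketch of observing the inferred type from Python (standard Relay API, mirroring the tests below):

import tvm
from tvm import relay

data = relay.var("data", shape=(10, 5), dtype="float32")
i0 = relay.var("i0", shape=(1, 4), dtype="int64")
i1 = relay.var("i1", shape=(3, 1), dtype="int64")
fn = relay.Function([data, i0, i1], relay.adv_index([data, i0, i1]))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(fn))
# (1, 4) and (3, 1) broadcast to (3, 4); both data axes are indexed,
# so no trailing dimensions are appended.
print(mod["main"].ret_type)  # Tensor[(3, 4), float32]
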
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index da36bba..97770f5 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1711,16 +1711,19 @@ def test_reshape_concat():
 def test_any_adv_index():
     data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype="float32")
     index0 = relay.var("index0", shape=(1, relay.Any()), dtype="int64")
-    index1 = relay.var("index1", shape=(1, relay.Any()), dtype="int64")
+    index1 = relay.var("index1", shape=(relay.Any(), 1), dtype="int64")
     out = relay.adv_index([data, index0, index1])
     mod = tvm.IRModule()
     mod["main"] = relay.Function([data, index0, index1], out)
     np_data_shape = (5, 5, 10)
-    np_index_shape = (1, 4)
+    np_index0_shape = (1, 4)
+    np_index1_shape = (4, 1)
     np_data = np.random.uniform(size=np_data_shape).astype("float32")
-    np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype("int64")
-    ref_res = np_data[tuple([np_index, np_index])]
-    check_result([np_data, np_index, np_index], mod, ref_res)
+    np_index0 = np.random.uniform(0, np_data_shape[0], size=np_index0_shape).astype("int64")
+    np_index1 = np.random.uniform(0, np_data_shape[0], size=np_index1_shape).astype("int64")
+    ref_res = np_data[tuple([np_index0, np_index1])]
+    print(ref_res.shape)
+    check_result([np_data, np_index0, np_index1], mod, ref_res)
 
 
 def verify_any_repeat(data_shape, np_dshape, repeats, axis):
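The updated test now uses index shapes that genuinely require mutual broadcasting. A NumPy reference showing the expected result shape under the standard advanced-indexing rules:

import numpy as np

np_data = np.random.uniform(size=(5, 5, 10)).astype("float32")
np_index0 = np.random.randint(0, 5, size=(1, 4)).astype("int64")
np_index1 = np.random.randint(0, 5, size=(4, 1)).astype("int64")
# (1, 4) and (4, 1) broadcast to (4, 4); the unindexed last axis (10)
# is appended, giving (4, 4, 10).
print(np_data[np_index0, np_index1].shape)  # (4, 4, 10)
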
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index a6eeaa6..34f3324 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1825,6 +1825,7 @@ def test_adv_index(target, dev, executor_kind):
         tvm.testing.assert_allclose(op_res.numpy(), np_out, rtol=1e-5)
 
     verify_adv_index((10, 5), [(3, 4), (3, 1)])
+    verify_adv_index((10, 5), [(1, 4), (3, 1)])
     verify_adv_index(
         (10, 5),
         [
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index d500b66..ddec14b 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -1259,7 +1259,7 @@ def test_matrix_set_diag():
 def test_adv_index():
     for indice_dtype in ["int32", "int64", "uint8", "uint16", "uint32"]:
         verify_adv_index((3, 4, 5), [(2,), (2,), (1,)], indice_dtype=indice_dtype)
-        verify_adv_index((10, 15, 5), [(1, 1), (2, 7)], indice_dtype=indice_dtype)
+        verify_adv_index((10, 15, 5), [(4, 1), (1, 7)], indice_dtype=indice_dtype)
         verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)], indice_dtype=indice_dtype)
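
The shape swap in this last hunk is what makes the test meaningful: the broadcast of (1, 1) and (2, 7) is (2, 7) itself, so the old "pick the largest index" heuristic happened to give the right answer, whereas (4, 1) and (1, 7) broadcast to (4, 7), which matches neither input shape. A two-line NumPy check:

import numpy as np
print(np.broadcast_shapes((1, 1), (2, 7)))  # (2, 7) -- old heuristic got lucky
print(np.broadcast_shapes((4, 1), (1, 7)))  # (4, 7) -- needs true broadcasting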