Posted to commits@tvm.apache.org by ke...@apache.org on 2020/09/11 06:10:05 UTC

[incubator-tvm] branch master updated: [Relay][Topi][Op]Advanced indexing (#6388)

This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 1228111  [Relay][Topi][Op]Advanced indexing (#6388)
1228111 is described below

commit 122811137a6fe3e787d5174a36d99eebefb479d0
Author: Yao Wang <ke...@gmail.com>
AuthorDate: Thu Sep 10 23:09:45 2020 -0700

    [Relay][Topi][Op]Advanced indexing (#6388)
    
    * Add Relay adv_index op
    
    * Support single index tensor dynamic shape
    
    * Support more dynamic index
    
    * Fix lint
    
    * Minor fix for comment
    
    * Fix lint
    
    * Fix lint
    
    * Fix test
    
    * Fix
---
 include/tvm/topi/transform.h                    | 80 ++++++++++++++++++++++++
 python/tvm/relay/frontend/pytorch.py            | 40 +-----------
 python/tvm/relay/op/_transform.py               | 33 ++++++++++
 python/tvm/relay/op/transform.py                | 18 ++++++
 python/tvm/topi/transform.py                    | 18 ++++++
 src/relay/op/tensor/transform.cc                | 83 +++++++++++++++++++++++++
 src/topi/transform.cc                           |  4 ++
 tests/python/relay/test_any.py                  | 15 +++++
 tests/python/relay/test_op_level3.py            | 26 ++++++++
 tests/python/topi/python/test_topi_transform.py | 42 +++++++++++++
 10 files changed, 321 insertions(+), 38 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index af59928..2c0d102 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -26,6 +26,7 @@
 
 #include <tvm/te/operation.h>
 #include <tvm/tir/data_layout.h>
+#include <tvm/topi/broadcast.h>
 #include <tvm/topi/detail/constant_utils.h>
 #include <tvm/topi/detail/ravel_unravel.h>
 #include <tvm/topi/detail/tensor_utils.h>
@@ -1551,6 +1552,85 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal,
       name, tag);
 }
 
+/*!
+ * \brief Numpy style advanced indexing with tensor.
+ * \param data is input data.
+ * \param indices is list of indexing tensors.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return Output tensor.
+ */
+inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
+                        const std::string name = "advanced_index",
+                        const std::string tag = kInjective) {
+  Array<PrimExpr> oshape;
+  Array<PrimExpr> broadcast_shape;
+  Array<Tensor> bindices;
+  std::vector<int64_t> flatten_shape_lens;
+  int64_t num_picked_elems = 1;
+  bool has_dyn_shape = false;
+
+  if (indices.size() == 1) {
+    broadcast_shape = indices[0]->shape;
+    bindices = indices;
+  } else {
+    for (const auto& index : indices) {
+      int64_t flatten_len = 1;
+      for (const auto& dim : index->shape) {
+        const IntImmNode* axis_len = dim.as<IntImmNode>();
+        if (!axis_len) {
+          broadcast_shape = index->shape;
+          has_dyn_shape = true;
+          break;
+        }
+        flatten_len *= axis_len->value;
+      }
+      if (has_dyn_shape) break;
+      flatten_shape_lens.push_back(flatten_len);
+      if (flatten_len > num_picked_elems) {
+        num_picked_elems = flatten_len;
+        broadcast_shape = index->shape;
+      }
+    }
+
+    // Do broadcast for indices
+    for (size_t i = 0; i < indices.size(); ++i) {
+      if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
+        bindices.push_back(broadcast_to(indices[i], broadcast_shape));
+      } else {
+        bindices.push_back(indices[i]);
+      }
+    }
+  }
+
+  for (const auto& dim : broadcast_shape) {
+    oshape.push_back(dim);
+  }
+  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
+    oshape.push_back(data->shape[i]);
+  }
+
+  return compute(
+      oshape,
+      [&](const Array<Var>& iter_var) {
+        Array<PrimExpr> tensor_indices;
+        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
+          tensor_indices.push_back(iter_var[i]);
+        }
+
+        Array<PrimExpr> real_indices;
+        for (size_t i = 0; i < bindices.size(); ++i) {
+          real_indices.push_back(bindices[i](tensor_indices));
+        }
+        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
+          real_indices.push_back(iter_var[i]);
+        }
+
+        return data(real_indices);
+      },
+      name, tag);
+}
+
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_TRANSFORM_H_
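
The new topi::adv_index kernel above follows NumPy's advanced-indexing shape rule: the index tensors are broadcast against one another (using the flattened element count as a proxy when shapes are static), and the output shape is that broadcast shape followed by the data dimensions that were not indexed. A minimal NumPy sketch of the rule this compute implements (array names and shapes are illustrative only):

    import numpy as np

    data = np.random.rand(10, 5, 15).astype("float32")   # data shape (10, 5, 15)
    idx0 = np.random.randint(0, 10, size=(1, 2, 1))      # indexes axis 0
    idx1 = np.random.randint(0, 5, size=(1, 2, 7))       # indexes axis 1

    # The index tensors broadcast to (1, 2, 7); the remaining data axis (15)
    # is appended, so the result has shape (1, 2, 7, 15).
    out = data[idx0, idx1]
    assert out.shape == (1, 2, 7, 15)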
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7203150..19cbf75 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1816,44 +1816,8 @@ def _one_hot():
 def _index():
     def _impl(inputs, input_types):
         data = inputs[0]
-        indices = []
-        raw_indices = []
-        max_indices_len = -1
-        for index in inputs[1]:
-            if not isinstance(index, _expr.Constant):
-                try:
-                    index = _expr.const(_infer_value(index, {}))
-                except Exception:
-                    raise RuntimeError("Only supports constant indices for "
-                                       "pytorch advanced indexing ")
-            raw_indices.append(index)
-            cindex_len = index.data.shape[0]
-            if cindex_len > max_indices_len:
-                max_indices_len = cindex_len
-
-        for index in raw_indices:
-            cnp = index.data.asnumpy()
-            cindex_len = cnp.shape[0]
-            if cindex_len < max_indices_len:
-                cnp = np.tile(cnp, max_indices_len // cindex_len)
-            indices.append(cnp)
-
-        ret = []
-        slice_map = {}
-        for i in range(indices[0].shape[0]):
-            tmp = data
-            current_indices = []
-            for index in indices:
-                current_indices.append(index[i])
-                index_key = tuple(current_indices)
-                if index_key in slice_map:
-                    tmp = slice_map[index_key]
-                else:
-                    tmp = _op.take(tmp, _expr.const(index[i]), axis=0)
-                    slice_map[index_key] = tmp
-            ret.append(_op.expand_dims(tmp, axis=0))
-
-        return _op.concatenate(ret, axis=0)
+        indices = inputs[1]
+        return _op.adv_index([data] + indices)
     return _impl
 
 
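The rewritten _index converter above maps PyTorch tensor indexing (aten::index) straight to the new adv_index op, so index tensors no longer need to be constants and the per-row take/concatenate unrolling is gone. For reference, this is the kind of PyTorch expression that takes this path (a pure-PyTorch sketch; names and shapes are illustrative):

    import torch

    x = torch.rand(5, 5, 10)
    i0 = torch.randint(0, 5, (1, 4))
    i1 = torch.randint(0, 5, (1, 4))

    # Tensor-valued indices trace to aten::index; the frontend now converts
    # that node to a single relay adv_index([x, i0, i1]) call.
    y = x[i0, i1]   # shape (1, 4, 10)
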
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 9d7c389..98ff0b3 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -62,6 +62,7 @@ _reg.register_reduce_schedule("collapse_sum_to")
 _reg.register_injective_schedule("unravel_index")
 _reg.register_injective_schedule("sparse_to_dense")
 _reg.register_injective_schedule("matrix_set_diag")
+_reg.register_injective_schedule("adv_index")
 
 # concatenate
 _reg.register_schedule("concatenate", strategy.schedule_concatenate)
@@ -661,3 +662,35 @@ def split_shape_func(attrs, inputs, _):
                               convert(i),
                               convert(indices_or_sections),
                               convert(axis)) for i in range(num_out)]
+
+@script
+def _adv_index_shape_func(inputs):
+    index_rank = inputs[1].shape[0]
+    data_rank = inputs[0].shape[0]
+    out = output_tensor((data_rank + index_rank - len(inputs) + 1,), "int64")
+
+    max_flatten_len = int64(1)
+    for i in const_range(index_rank):
+        max_flatten_len *= inputs[1][i]
+        out[i] = inputs[1][i]
+    for i in const_range(len(inputs) - 2):
+        flatten_len = int64(1)
+        for j in const_range(index_rank):
+            flatten_len *= inputs[i + 2][j]
+        if flatten_len > max_flatten_len:
+            max_flatten_len = flatten_len
+            for k in const_range(index_rank):
+                out[k] = inputs[i + 2][k]
+
+    for i in const_range(data_rank - len(inputs) + 1):
+        out[i + index_rank] = inputs[0][i + len(inputs) - 1]
+
+    return out
+
+@_reg.register_shape_func("adv_index", False)
+def adv_index_shape_func(attrs, inputs, _):
+    """
+    Shape func for adv_index.
+    Only allow single index tensor.
+    """
+    return [_adv_index_shape_func(inputs)]
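
The hybrid-script shape function above mirrors the type relation: among the index shapes it keeps the one with the largest flattened length as the broadcast shape, then appends the data dimensions that are not indexed. A plain-Python sketch of the same computation, for reasoning only (not part of the registered shape function):

    import numpy as np

    def adv_index_out_shape(data_shape, index_shapes):
        # Pick the index shape with the most elements as the broadcast shape,
        # matching the heuristic used by _adv_index_shape_func above.
        broadcast = max(index_shapes, key=lambda s: int(np.prod(s)))
        return tuple(broadcast) + tuple(data_shape[len(index_shapes):])

    # e.g. data (10, 5, 15) indexed by shapes (1, 2, 1) and (1, 2, 7)
    # -> (1, 2, 7) + (15,) = (1, 2, 7, 15)
    assert adv_index_out_shape((10, 5, 15), [(1, 2, 1), (1, 2, 7)]) == (1, 2, 7, 15)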
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 01466f7..0ce59ad 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1213,3 +1213,21 @@ def matrix_set_diag(data, diagonal):
               [7, 7, 6, 7]]]
     """
     return _make.matrix_set_diag(data, diagonal)
+
+
+def adv_index(inputs):
+    """
+    Numpy style advanced indexing. Index with a list of tensors.
+
+    Parameters
+    ----------
+    inputs : Union(List[relay.Expr], Tuple[relay.Expr])
+        Input tensor and indices.
+        The first tensor is input data and rests are indices.
+
+    Returns
+    -------
+    result: relay.Expr
+        Output tensor.
+    """
+    return _make.adv_index(Tuple(inputs))
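
A short usage sketch of the new relay.adv_index API with static shapes (shapes and names are illustrative; the dynamic-shape case is exercised in test_any.py further down):

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", shape=(10, 5), dtype="float32")
    idx0 = relay.var("idx0", shape=(3, 4), dtype="int64")
    idx1 = relay.var("idx1", shape=(3, 1), dtype="int64")
    func = relay.Function([data, idx0, idx1],
                          relay.adv_index([data, idx0, idx1]))

    np_data = np.random.rand(10, 5).astype("float32")
    np_i0 = np.random.randint(0, 10, (3, 4)).astype("int64")
    np_i1 = np.random.randint(0, 5, (3, 1)).astype("int64")

    ctx = tvm.cpu(0)
    intrp = relay.create_executor("graph", ctx=ctx, target="llvm")
    res = intrp.evaluate(func)(np_data, np_i0, np_i1)
    # Matches NumPy advanced indexing on the same arrays.
    np.testing.assert_allclose(res.asnumpy(), np_data[np_i0, np_i1])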
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index f3e5a6a..1681d87 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -838,3 +838,21 @@ def matrix_set_diag(data, diagonal):
               [7, 7, 6, 7]]]
     """
     return cpp.matrix_set_diag(data, diagonal)
+
+def adv_index(data, indices):
+    """Numpy style indexing with tensors.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input data.
+
+    indices : A list of tvm.te.Tensor
+        Tensor index.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        Output tensor
+    """
+    return cpp.adv_index(data, indices)
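
At the TOPI level the kernel can also be driven directly with te placeholders; a minimal single-index sketch, assuming an LLVM target (names are illustrative):

    import numpy as np
    import tvm
    from tvm import te, topi

    data = te.placeholder((10, 5), name="data", dtype="float32")
    index = te.placeholder((2,), name="index", dtype="int64")
    out = topi.adv_index(data, [index])          # output shape (2, 5)

    s = te.create_schedule(out.op)
    f = tvm.build(s, [data, index, out], "llvm")

    ctx = tvm.cpu(0)
    np_data = np.random.rand(10, 5).astype("float32")
    np_idx = np.array([3, 7], dtype="int64")
    nd_out = tvm.nd.empty((2, 5), dtype="float32", ctx=ctx)
    f(tvm.nd.array(np_data, ctx), tvm.nd.array(np_idx, ctx), nd_out)
    np.testing.assert_allclose(nd_out.asnumpy(), np_data[np_idx])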
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 88179b7..e3d0950 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3163,5 +3163,88 @@ RELAY_REGISTER_OP("matrix_set_diag")
     .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// adv_index
+bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto inputs = types[0].as<TupleTypeNode>();
+  auto data = inputs->fields[0].as<TensorTypeNode>();
+
+  if (inputs == nullptr || data == nullptr) {
+    return false;
+  }
+
+  Array<IndexExpr> oshape;
+  Array<IndexExpr> broadcast_shape;
+  int64_t num_picked_elems = 1;
+
+  if (inputs->fields.size() == 2) {
+    broadcast_shape = inputs->fields[1].as<TensorTypeNode>()->shape;
+  } else {
+    for (size_t i = 1; i < inputs->fields.size(); ++i) {
+      auto index_type = inputs->fields[i].as<TensorTypeNode>();
+      if (index_type == nullptr) {
+        return false;
+      }
+      CHECK(index_type->dtype.is_int()) << "indices must be tensor of integers";
+
+      int64_t flatten_len = 1;
+      bool has_dyn_shape = false;
+      for (const auto& dim : index_type->shape) {
+        const IntImmNode* axis_len = dim.as<IntImmNode>();
+        if (!axis_len) {
+          // If dynamic shape appears, just use the first shape
+          broadcast_shape = index_type->shape;
+          has_dyn_shape = true;
+          break;
+        }
+        flatten_len *= axis_len->value;
+      }
+      if (has_dyn_shape) break;
+      if (flatten_len > num_picked_elems) {
+        num_picked_elems = flatten_len;
+        broadcast_shape = index_type->shape;
+      }
+    }
+  }
+
+  for (const auto& dim : broadcast_shape) {
+    oshape.push_back(dim);
+  }
+  for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
+    oshape.push_back(data->shape[i]);
+  }
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Array<te::Tensor> AdvIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
+  Array<te::Tensor> indices;
+  for (size_t i = 1; i < inputs.size(); ++i) {
+    indices.push_back(inputs[i]);
+  }
+  return {topi::adv_index(inputs[0], indices)};
+}
+
+Expr MakeAdvIndex(Expr inputs) {
+  static const Op& op = Op::Get("adv_index");
+  return Call(op, {inputs}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);
+
+RELAY_REGISTER_OP("adv_index")
+    .describe(R"code(Numpy style advanced indexing. Index with a list of tensors.
+    )code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .set_support_level(3)
+    .add_argument("inputs", "Tuple of Tensors", "Input tensor and indices.")
+    .add_type_rel("AdvIndex", AdvIndexRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);
+
 }  // namespace relay
 }  // namespace tvm
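
The AdvIndexRel type relation above applies the same broadcast-shape rule at type-checking time (falling back to the first dynamic index shape it encounters). A quick way to inspect the inferred output type (a sketch, assuming static index shapes):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(10, 5, 15), dtype="float32")
    i0 = relay.var("i0", shape=(1, 2, 1), dtype="int64")
    i1 = relay.var("i1", shape=(1, 2, 7), dtype="int64")
    mod = tvm.IRModule.from_expr(
        relay.Function([data, i0, i1], relay.adv_index([data, i0, i1])))
    mod = relay.transform.InferType()(mod)
    # Expected return type: Tensor[(1, 2, 7, 15), float32] -- the broadcast
    # index shape (1, 2, 7) followed by the unindexed data dimension 15.
    print(mod["main"].ret_type)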
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 154933f..bf7e1e6 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -180,5 +180,9 @@ TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValu
   *rv = matrix_set_diag(args[0], args[1]);
 });
 
+TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = adv_index(args[0], args[1]);
+});
+
 }  // namespace topi
 }  // namespace tvm
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 6bb34d3..3a46fdd 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -892,5 +892,20 @@ def test_reshape_concat():
                               np.reshape(np_data1, np_shape_like1.shape)], axis=0)
     check_result([np_data0, np_data1, np_shape_like0, np_shape_like1], mod, ref_res)
 
+def test_any_adv_index():
+    data = relay.var("data", shape=(5, relay.Any(), relay.Any()), dtype='float32')
+    index0 = relay.var("index0", shape=(1, relay.Any()), dtype='int64')
+    index1 = relay.var("index1", shape=(1, relay.Any()), dtype='int64')
+    out = relay.adv_index([data, index0, index1])
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([data, index0, index1], out)
+    np_data_shape = (5, 5, 10)
+    np_index_shape = (1, 4)
+    np_data = np.random.uniform(size=np_data_shape).astype('float32')
+    np_index = np.random.uniform(0, np_data_shape[0], size=np_index_shape).astype('int64')
+    ref_res = np_data[tuple([np_index, np_index])]
+    check_result([np_data, np_index, np_index], mod, ref_res)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f709aa2..98ef38d 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1091,6 +1091,31 @@ def test_sparse_to_dense():
     #sparse_indices should not be > 2d tensor
     #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
 
+def test_adv_index():
+    def verify_adv_index(data_shape, index_shapes):
+        dtype = "float32"
+        inputs = [relay.var("data", relay.TensorType(data_shape, dtype))]
+        np_data = np.random.uniform(size=data_shape).astype(dtype)
+        np_indices = []
+        for i, index_shape in enumerate(index_shapes):
+            limit = data_shape[i]
+            np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
+            inputs.append(relay.var("index_{}".format(i), relay.TensorType(index_shape, "int64")))
+        np_out = np_data[tuple(np_indices)]
+        np_args = [np_data] + np_indices
+        out = relay.op.adv_index(inputs)
+
+        func = relay.Function(inputs, out)
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(*np_args)
+                tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-5)
+
+    verify_adv_index((10, 5), [(3, 4), (3, 1)])
+    verify_adv_index((10, 5), [(2,),])
+    verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+
 if __name__ == "__main__":
     test_cast()
     test_zeros_ones()
@@ -1127,3 +1152,4 @@ if __name__ == "__main__":
     test_unravel_index()
     test_sparse_to_dense()
     test_fixed_point_multiply()
+    test_adv_index()
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index a061ba9..fc6f19f 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -678,6 +678,7 @@ def verify_matrix_set_diag(input_shape, dtype):
     input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
     diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
     matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal)
+
     def check_device(device, ctx):
         ctx = tvm.context(device, 0)
         print("Running on target: %s" % device)
@@ -697,6 +698,40 @@ def verify_matrix_set_diag(input_shape, dtype):
     for target, ctx in tvm.testing.enabled_targets():
         check_device(target, ctx)
 
+def verify_adv_index(data_shape, index_shapes):
+    dtype = "float32"
+    data = te.placeholder(shape=data_shape, name="data", dtype=dtype)
+    indices = []
+    np_data = np.random.uniform(size=data_shape).astype(dtype)
+    np_indices = []
+    for i, index_shape in enumerate(index_shapes):
+        limit = data_shape[i]
+        np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
+        indices.append(te.placeholder(shape=index_shape, name="index_{}".format(i), dtype="int64"))
+    np_out = np_data[tuple(np_indices)]
+    out = topi.adv_index(data, indices)
+
+    def check_device(device, ctx):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = tvm.topi.testing.get_injective_schedule(device)(out)
+
+        func = tvm.build(s, [data] + indices + [out], device, name="adv_index")
+
+        nd_list = [tvm.nd.array(np_data, ctx)]
+        for np_index in np_indices:
+            nd_list.append(tvm.nd.array(np_index, ctx))
+        nd_list.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=data.dtype))
+
+        func(*nd_list)
+        tvm.testing.assert_allclose(nd_list[-1].asnumpy(), np.array(np_out))
+
+    for target, ctx in tvm.testing.enabled_targets():
+        check_device(target, ctx)
 
 @tvm.testing.uses_gpu
 def test_strided_slice():
@@ -1071,6 +1106,12 @@ def test_matrix_set_diag():
         verify_matrix_set_diag((4, 3, 3), dtype)
         verify_matrix_set_diag((2, 3, 4), dtype)
 
+@tvm.testing.uses_gpu
+def test_adv_index():
+    verify_adv_index((3, 4, 5), [(2,), (2, ), (1,)])
+    verify_adv_index((10, 15, 5), [(1, 1), (2, 7)])
+    verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -1097,3 +1138,4 @@ if __name__ == "__main__":
     test_unravel_index()
     test_sparse_to_dense()
     test_matrix_set_diag()
+    test_adv_index()