Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/14 18:48:12 UTC

[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8971: [Relay][Topi][Op][Onnx] Support negative indices in gather, gathernd. Fixes NLL Loss tests

AndrewZhaoLuo commented on a change in pull request #8971:
URL: https://github.com/apache/tvm/pull/8971#discussion_r708540484



##########
File path: src/te/tensor.cc
##########
@@ -39,15 +39,26 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name)
 Var var(std::string name_hint, DataType t) { return Var(name_hint, t); }
 
 // Tensor
-PrimExpr Tensor::operator()(Array<Var> indices) const {
+PrimExpr Tensor::operator()(Array<Var> indices, bool support_negative_indices) const {
   Array<PrimExpr> arr(indices.begin(), indices.end());
-  return operator()(arr);
+  return operator()(arr, support_negative_indices);
 }
 
-PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
-  if (ndim() != 0) {
-    ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read "
-                                      << "ndim = " << ndim() << ", indices.size=" << indices.size();
+PrimExpr Tensor::operator()(Array<PrimExpr> indices, bool support_negative_indices) const {
+  Array<PrimExpr> shape = (*this)->shape;
+
+  if (shape.size() != 0) {
+    ICHECK_EQ(shape.size(), indices.size())
+        << "Tensor dimension mismatch in read "
+        << "ndim = " << ndim() << ", indices.size=" << indices.size();
+  }
+
+  if (support_negative_indices) {
+    for (size_t i = 0; i < shape.size(); i++) {
+      PrimExpr new_index = if_then_else(indices[i] < make_const(indices[i]->dtype, 0),
+                                        indices[i] + shape[i], indices[i]);
+      indices.Set(i, new_index);

Review comment:
       Hmm, this is a good point. Personally, I think pushing this down the stack is the right choice: I would expect the most basic indexing op to work with negative indices, and since all of the other operations are built on these basic indexing ops, they get the behavior for free. In this PR, we add a flag to the basic indexing operation that turns the feature on.
   
   Otherwise we'll end up with many copies of the same code everywhere.
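
To make the argument concrete, here is a minimal standalone sketch of the idea, written in plain C++ rather than against TVM's expression types. `NormalizeIndex` and `ReadAt` are hypothetical names used only for illustration; they are not part of TVM. The sketch assumes the same rule as the diff (add the extent when the index is negative) and adds a bounds assertion that the diff itself does not perform.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a TVM API): wrap a possibly negative index.
// Mirrors the rule in the diff: if index < 0, use index + extent.
inline int64_t NormalizeIndex(int64_t index, int64_t extent) {
  int64_t normalized = index < 0 ? index + extent : index;
  // Extra safety check for this sketch only; the diff does not bounds-check.
  assert(normalized >= 0 && normalized < extent);
  return normalized;
}

// Hypothetical basic read with an opt-in flag, analogous in spirit to the
// support_negative_indices parameter added to Tensor::operator().
inline int ReadAt(const std::vector<int>& data, int64_t index,
                  bool support_negative_indices = false) {
  const int64_t extent = static_cast<int64_t>(data.size());
  if (support_negative_indices) {
    index = NormalizeIndex(index, extent);
  }
  return data[static_cast<std::size_t>(index)];
}
```

Any higher-level operation that reads through `ReadAt` with the flag enabled handles negative indices without repeating the wrap-around logic, which is the "for free" property described above.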



