Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/20 19:26:33 UTC

[GitHub] [tvm] masahi commented on pull request #8084: [Relay, ONNX] Support gather_nd batch_dims attribute for TF/ONNX

masahi commented on pull request #8084:
URL: https://github.com/apache/tvm/pull/8084#issuecomment-845414090


   > X_M, ..., X_{N-1} is the implicit batch dimension
   
   I don't understand what you mean by "implicit batch dimension". X_M, ..., X_{N-1} are the axes of the input that are not indexed and are thus simply copied. `batch_dims` specifies the axis at which indexing starts.
   
   Our current `gather_nd` is identical to the mxnet one at
   https://mxnet.apache.org/versions/1.6/api/r/docs/api/mx.symbol.gather_nd.html, which is the same as TF `gather_nd` and ONNX `GatherND` except that the M index tuples lie along the first axis of `indices` rather than the last. (There is an open request to add `batch_dims` support to the mxnet op: https://github.com/apache/incubator-mxnet/issues/9998)
   
   So right now the output is
   ```
   output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[indices[0, y_0, ..., y_{K-1}],
   ...,
   indices[M-1, y_0, ..., y_{K-1}],
   x_M, ..., x_{N-1}]
   ```  
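
For the `batch_dims = 0` case, this formula coincides with NumPy integer-array indexing, so it can be sanity-checked with a short sketch. The helper name `gather_nd_first_axis` is made up for illustration and is not a TVM API; the sample values follow the mxnet documentation example.

```python
import numpy as np

def gather_nd_first_axis(data, indices):
    # Hypothetical helper, not part of TVM.
    # indices has shape (M, Y_0, ..., Y_{K-1}); iterating over axis 0 yields
    # M index arrays, one for each of the first M axes of data.
    return data[tuple(indices)]

data = np.array([[0, 1], [2, 3]])
indices = np.array([[1, 1, 0], [0, 1, 0]])  # M = 2, Y_0 = 3
result = gather_nd_first_axis(data, indices)

# TF / ONNX keep the M-tuples on the last axis instead; moving the axis
# converts between the two layouts.
tf_style_indices = np.moveaxis(indices, 0, -1)  # shape (3, 2)
```

Here `result` picks `data[1, 0]`, `data[1, 1]` and `data[0, 0]`, one element per column of `indices`.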
   With `batch_dims` B, it becomes (I hope this is correct, but I haven't checked it carefully)
   ```
   output[y_0, ..., y_{K-1}, x_{M+B}, ..., x_{N-1}] = data[y_0, ..., y_{B-1}, indices[0, y_0, ..., y_{K-1}],
   ...,
   indices[M-1, y_0, ..., y_{K-1}],
   x_{M+B}, ..., x_{N-1}]
   ```
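
A loop-based NumPy sketch of the `batch_dims` formula may help when checking it against small inputs. This is only a reference under the shape conventions used in this comment (index tuples on the first axis of `indices`); the name `gather_nd_ref` is hypothetical and this is not the TVM implementation.

```python
import itertools
import numpy as np

def gather_nd_ref(data, indices, batch_dims=0):
    # Hypothetical reference, not the TVM kernel.
    M = indices.shape[0]          # length of each index tuple
    B = batch_dims
    y_shape = indices.shape[1:]   # (Y_0, ..., Y_{K-1})
    assert data.shape[:B] == y_shape[:B], "batch axes must match"
    assert M + B <= data.ndim
    out = np.empty(y_shape + data.shape[M + B:], dtype=data.dtype)
    for y in itertools.product(*(range(n) for n in y_shape)):
        # Indices for data axes B .. B+M-1, read from indices[:, y_0, ..., y_{K-1}].
        idx = tuple(int(indices[(m,) + y]) for m in range(M))
        # The first B axes are indexed directly by y_0 .. y_{B-1};
        # the trailing axes M+B .. N-1 are copied as a slice.
        out[y] = data[y[:B] + idx]
    return out

data = np.arange(24).reshape(2, 3, 4)
indices = np.array([[[0, 2], [1, 0]]])   # M = 1, (Y_0, Y_1) = (2, 2)
out = gather_nd_ref(data, indices, batch_dims=1)
```

With these inputs `out[y_0, y_1]` equals `data[y_0, indices[0, y_0, y_1]]`, a row of length 4.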
   
   I'm going to update the doc to the following if this makes sense @tkonolige 
   ```
   Optionally, batch_dims, the number of batch dimensions, can be given, whose
   default value is 0.
   
   Let B denote batch_dims, and let the shapes of data and indices be
   (X_0, X_1, ..., X_{N-1}) and (M, Y_0, ..., Y_{K-1}) respectively. When B > 0,
   indexing will start from the B-th axis, and it must be the case that
   (X_0, ..., X_{B-1}) == (Y_0, ..., Y_{B-1}).
   
   The output will have shape
   (Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. If M + B == N,
   output shape will simply be (Y_0, ..., Y_{K-1}).
   ```
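
The shape rule in the proposed doc text can be written as a small checker. This is a sketch under the conventions above; `gather_nd_output_shape` is a hypothetical name, not an existing TVM function.

```python
def gather_nd_output_shape(data_shape, indices_shape, batch_dims=0):
    # Hypothetical helper illustrating the documented shape rule.
    data_shape = tuple(data_shape)
    M, y_shape = indices_shape[0], tuple(indices_shape[1:])
    B = batch_dims
    if data_shape[:B] != y_shape[:B]:
        raise ValueError("the first batch_dims axes of data and indices must match")
    if M + B > len(data_shape):
        raise ValueError("M + batch_dims must not exceed the rank of data")
    # (Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}); the tail is empty when M + B == N.
    return y_shape + data_shape[M + B:]
```

For example, data of shape (2, 3, 4) with indices of shape (1, 2, 2) and `batch_dims=1` gives (2, 2, 4), while indices of shape (2, 2, 5) give (2, 5) because M + B == N.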
   
   

