Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/27 02:25:40 UTC

[GitHub] [tvm] hgt312 edited a comment on pull request #9375: [PyTorch] [Frontend] Add support for 'aten::new_zeros' & 'aten::copy_'

hgt312 edited a comment on pull request #9375:
URL: https://github.com/apache/tvm/pull/9375#issuecomment-952479498


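   @comaniac @masahi I found that the output is still incorrect because of in-place assignments such as `a[...] = b`, as in the previous issues.
   
   In BART, the pattern comes from the function below. The function as a whole is not in-place: it only writes into a tensor that it allocates itself.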
   ```
   import torch


   def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
       """
       Shift input ids one token to the right.
       """
       shifted_input_ids = input_ids.new_zeros(input_ids.shape)
       shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
       shifted_input_ids[:, 0] = decoder_start_token_id
   
       assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
       # replace possible -100 values in labels by `pad_token_id`
       shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
   
       return shifted_input_ids
   ```
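
For comparison, here is a minimal sketch of the same shift written without any slice assignment or in-place op. The name `shift_tokens_right_functional` is hypothetical: it is not part of BART or of this PR, and it assumes a 2-D `input_ids` tensor as in the function above.

```python
import torch


def shift_tokens_right_functional(input_ids, pad_token_id, decoder_start_token_id):
    """Out-of-place variant (hypothetical): uses cat/masked_fill instead of `a[...] = b`."""
    # Column of start tokens with the same dtype and device as input_ids.
    start = input_ids.new_full((input_ids.shape[0], 1), decoder_start_token_id)
    # Drop the last token and prepend the start token.
    shifted = torch.cat([start, input_ids[:, :-1]], dim=1)
    # Replace possible -100 values by pad_token_id; masked_fill returns a new tensor.
    return shifted.masked_fill(shifted == -100, pad_token_id)


ids = torch.tensor([[5, -100, 6], [7, 8, 9]])
print(shift_tokens_right_functional(ids, pad_token_id=1, decoder_start_token_id=2))
```

A graph traced from this variant should contain no `aten::copy_` or `aten::masked_fill_` nodes, but it only helps when the model source can be changed, which is usually not the case for third-party models.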
   
   Also, I found that after https://github.com/pytorch/pytorch/pull/52063 (torch version >= 1.9), we can use `torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, None)` to remove all the `aten::copy_` nodes; the corresponding part then looks like:
   ```
     %69 : Tensor = onnx::Placeholder[name="index_put_"](%62) # <ipython-input-1-662caefe3c7e>:8:0
       block0():
         %70 : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu) = aten::slice(%shifted_input_ids, %42, %43, %44, %45) # <ipython-input-1-662caefe3c7e>:8:0
         %71 : Float(3, strides=[3], requires_grad=0, device=cpu) = aten::select(%70, %47, %48) # <ipython-input-1-662caefe3c7e>:8:0
         %72 : Float(3, strides=[3], requires_grad=0, device=cpu) = aten::index_put_(%71, %66, %67, %57) # <ipython-input-1-662caefe3c7e>:8:0
         -> (%72)
   ```
   and the subgraph can be converted to ONNX's `index_put`.
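
To see which in-place nodes a traced graph contains in the first place, something like the following sketch may help. It only inspects the graph and does not run the ONNX pass; `shift_right` is a reduced, hypothetical stand-in for the BART function, and the filter relies on the PyTorch convention that in-place ATen ops end with an underscore.

```python
import torch


def shift_right(input_ids):
    # Same slice-assignment pattern as in shift_tokens_right, reduced to one statement.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    return shifted


traced = torch.jit.trace(shift_right, torch.tensor([[1, 2, 3], [4, 5, 6]]))

# Collect the kinds of all top-level nodes whose name marks them as in-place.
inplace_kinds = sorted(
    {node.kind() for node in traced.graph.nodes() if node.kind().endswith("_")}
)
print(inplace_kinds)
```

Running the same check after the pass above would be one way to confirm that the `aten::copy_` nodes are gone.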
   
   Maybe the torch->onnx path will work for these models?

