Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/30 01:48:24 UTC

[GitHub] [incubator-tvm] masahi opened a new issue #6594: [Torch, CI] Support PyTorch 1.6

masahi opened a new issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594


   Let's summarize the steps needed to update our PyTorch support to the latest version, 1.6 (CI is currently at 1.4).
   
   - [ ] Quantization support: It is completely broken due to a bug introduced in 1.6 (see https://github.com/pytorch/pytorch/issues/42497). I (finally!) found a simple workaround for that problem, and the fix is WIP. In addition, the representation of quantized weights and the APIs of quantized ops have changed. To support both 1.6 and older versions, we need to handle both cases (with a version check) for now.
   
   - [ ] torchvision detection model support: We need to add converters for `aten::tensor`, `aten::empty`, `aten::numel` and `aten::dim`. I found that the other missing ops I mentioned in https://github.com/apache/incubator-tvm/pull/6449#issuecomment-693023530 can be eliminated using the `torch._C._jit_pass_onnx_function_substitution(...)` pass.
   
   - [ ] Upgrade CI docker image 
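
   A minimal sketch of the version check mentioned in the quantization item, assuming the installed version is available as a string in the format of PyTorch's `torch.__version__` (e.g. "1.6.0+cu101"). The helper names `parse_version`, `is_at_least` and `select_quant_path` are hypothetical, not part of TVM, and the two returned labels only stand in for the 1.6 and pre-1.6 handling of quantized weights.

```python
def parse_version(ver):
    """Turn a version string like '1.6.0+cu101' into (1, 6, 0)."""
    base = ver.split("+")[0]  # drop the local tag, e.g. '+cu101'
    parts = []
    for piece in base.split("."):
        # Keep only the leading digits, so '0a0' (pre-release) becomes 0.
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def is_at_least(ver, target):
    """True if version string `ver` is >= `target`, compared numerically."""
    a, b = parse_version(ver), parse_version(target)
    # Pad with zeros so that '1.6' and '1.6.0' compare equal.
    n = max(len(a), len(b))
    a += (0,) * (n - len(a))
    b += (0,) * (n - len(b))
    return a >= b


def select_quant_path(torch_version):
    # Hypothetical dispatch between the two quantized-weight representations.
    if is_at_least(torch_version, "1.6"):
        return "1.6+"
    return "pre-1.6"


print(select_quant_path("1.6.0+cu101"), select_quant_path("1.4.0"))
```

   Comparing integer tuples rather than raw strings matters here: as strings, "1.10" sorts before "1.6".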
   
   cc @t-vi @siju-samuel @kevinthesun @yongwww @tqchen  


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] masahi edited a comment on issue #6594: [Torch, CI] Support PyTorch 1.6

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-706472689


   @kevinthesun @jroesch @lixiaoquan @MarisaKirisame I found that this change was introduced in https://github.com/apache/incubator-tvm/pull/5795
   
   If I make the change I proposed earlier (returning `urhs` in `type_solver.cc`), it effectively undoes that PR and breaks the test case introduced there. Basically, what we want from type checking an `If` involving both static and dynamic shapes is the complete opposite of the motivation behind #5795.
   
   What should we do about this? I think #5795 should be reverted, since type inference is supposed to pick the most general types.


[GitHub] [incubator-tvm] masahi commented on issue #6594: [Torch, CI] Support PyTorch 1.6

masahi commented on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-701936601


   @yongwww It seems the torchvision nms op https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py#L35 returns int64 indices, while nms in Relay returns int32 indices. We need to cast the output indices to int64; that should resolve one of the typing problems (the two branches of `If` having different dtypes, int32 and int64).
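
   To illustrate the dtype mismatch, here is a NumPy sketch of plain greedy NMS that returns int32 indices (standing in for the Relay side), followed by a cast to int64 so the result matches torchvision's indices and the int64 empty branch. This is not Relay's nms kernel, and `greedy_nms` is a hypothetical helper; in the Relay frontend the corresponding step would presumably be a `relay.cast(indices, "int64")` on the nms output.

```python
import numpy as np


def greedy_nms(boxes, scores, iou_threshold):
    """Greedy NMS over (N, 4) boxes given as (x1, y1, x2, y2).

    Returns kept indices as int32, mimicking the dtype of Relay's nms
    output described above. Illustrative only, not Relay's kernel.
    """
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = np.argsort(-scores, kind="stable")  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        # Intersection of the current best box with all remaining boxes.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        iou = inter / (areas[i] + areas[rest] - inter)
        # Drop boxes that overlap the kept box by more than the threshold.
        order = rest[iou <= iou_threshold]
    return np.array(keep, dtype=np.int32)


boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

keep = greedy_nms(boxes, scores, iou_threshold=0.5)  # int32 indices
keep64 = keep.astype(np.int64)  # cast to match torchvision's int64 indices

# The other If branch: an empty int64 index tensor.
empty_branch = np.empty((0,), dtype=np.int64)
print(keep.dtype, keep64.dtype, keep64.dtype == empty_branch.dtype)
```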


[GitHub] [incubator-tvm] masahi edited a comment on issue #6594: [Torch, CI] Support PyTorch 1.6

masahi edited a comment on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-701877144


   @kevinthesun @yongwww I've taken a look at supporting detection models in 1.6. It turned out we need to convert the function `batched_nms`, which was traced in 1.4 but is now scripted. This causes a very nasty typing problem.
   
   Looking at this code,
   https://github.com/pytorch/vision/blob/6e10e3f88158f12b7a304d3c2f803d2bbdde0823/torchvision/ops/boxes.py#L75-L86
   
   In the converted model, one `if` branch has the type of an empty int64 tensor of shape (0,) (https://github.com/pytorch/vision/blob/6e10e3f88158f12b7a304d3c2f803d2bbdde0823/torchvision/ops/boxes.py#L76),
   while the other branch is a dynamic 1D tensor of dtype int32 (https://github.com/pytorch/vision/blob/6e10e3f88158f12b7a304d3c2f803d2bbdde0823/torchvision/ops/boxes.py#L85-L86).
   
   We cannot type check the converted Relay model with `If`, because the dtypes of the two branches are different. Even if I force the dtypes to be the same, when one branch has an `Any` shape and the other a static shape, Relay type inference picks the static shape for the type of the `If`. So the return type of the function above becomes a static tensor of shape (0,).
   
   Is there a way to turn a static-shape tensor into a dynamic one when one `If` branch is dynamic and the other is static?
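
   The typing rule being asked for can be written down as a toy Python model (not Relay's type solver): `ANY` stands in for `relay.Any()`, and `join_if_branches` is a hypothetical helper that returns the most general type covering both branches. Since a static dim is one possible value of `Any` and not the other way round, mixing the two has to yield `Any`. A dtype mismatch is still rejected, which is why the int32/int64 issue needs a separate cast; rejecting differing static dims is a simplification assumed for this sketch.

```python
ANY = None  # stand-in for relay.Any(): a dimension only known at runtime


class TensorType:
    """Toy tensor type: a shape tuple (ints or ANY) plus a dtype string."""

    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = dtype

    def __repr__(self):
        dims = ", ".join("?" if d is ANY else str(d) for d in self.shape)
        return f"Tensor[({dims}), {self.dtype}]"


def join_dim(a, b):
    """Most general dim covering both a and b."""
    if a is ANY or b is ANY:
        # A static dim is just one case of Any, so the join must stay dynamic.
        return ANY
    if a != b:
        # Simplification for this sketch: differing static dims are an error.
        raise TypeError(f"incompatible static dims {a} and {b}")
    return a


def join_if_branches(t_true, t_false):
    """Type of If(cond, t_true, t_false) under the 'most general' rule."""
    if t_true.dtype != t_false.dtype:
        raise TypeError(f"branch dtypes differ: {t_true.dtype} vs {t_false.dtype}")
    if len(t_true.shape) != len(t_false.shape):
        raise TypeError("branch ranks differ")
    shape = [join_dim(a, b) for a, b in zip(t_true.shape, t_false.shape)]
    return TensorType(shape, t_true.dtype)


# Mirrors the batched_nms case: dynamic nms indices vs. a static empty tensor.
nms_branch = TensorType([ANY], "int64")
empty_branch = TensorType([0], "int64")
print(join_if_branches(nms_branch, empty_branch))
```

   Under this rule the result prints as `Tensor[(?), int64]`, whereas the behavior described above collapses it to the static (0,) shape.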


[GitHub] [incubator-tvm] kevinthesun edited a comment on issue #6594: [Torch, CI] Support PyTorch 1.6

kevinthesun edited a comment on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-706474453


   Hmmm, I think that PR is an optimization for some cases but is actually not correct, since a static dim is just a special case of ```Any```, not vice versa. The case provided in that PR only holds when ```False``` is fed an (Any, 1) tensor at runtime.
   
   In the TF frontend there are lots of cases where we want to make expressions as static as possible. My guess is that ```False``` actually has shape (Any, 1) but somehow was not fully optimized to be less dynamic. Usually this kind of optimization is done in the frontend.
   
   @lixiaoquan Should we revert it?


[GitHub] [incubator-tvm] masahi edited a comment on issue #6594: [Torch, CI] Support PyTorch 1.6

masahi edited a comment on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-706484241


   I've sent the revert in https://github.com/apache/incubator-tvm/pull/6658
   
   With this and other minor fixes, I can finally run torchvision detection models from v1.6. See https://github.com/apache/incubator-tvm/pull/6659. 


[GitHub] [incubator-tvm] masahi commented on issue #6594: [Torch, CI] Support PyTorch 1.6

masahi commented on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-706471804


   Finding the suspicious code was easy: here, when lhs is an `IntImmNode` and rhs is an `AnyNode`, `ulhs` is returned.
   
   https://github.com/apache/incubator-tvm/blob/98c2096f4944bdbdbbb2b7b20ccd35c6c11dfbf6/src/relay/analysis/type_solver.cc#L195-L198
   
   Returning `urhs` there seems to fix this issue. I'm not sure if this is intended or a bug. 


[GitHub] [incubator-tvm] tqchen closed issue #6594: [Torch, CI] Support PyTorch 1.7

tqchen closed issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594


   


[GitHub] [incubator-tvm] kevinthesun commented on issue #6594: [Torch, CI] Support PyTorch 1.6

kevinthesun commented on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-706461452


   @masahi That shape problem coming from the ```If``` branches looks like a Relay type inference issue to me. Type inference should produce a dynamic shape in this case.


[GitHub] [incubator-tvm] kevinthesun edited a comment on issue #6594: [Torch, CI] Support PyTorch 1.6

kevinthesun edited a comment on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-706467574


   @masahi 
   
   ```
   import tvm
   from tvm import relay

   dtype = "float32"
   # Branch a has a dynamic (Any) dimension; branch b is a static empty tensor.
   branch_a = relay.var("a", shape=(relay.Any(),), dtype=dtype)
   branch_b = relay.var("b", shape=(0,), dtype=dtype)
   cond = relay.var("cond", shape=(), dtype="bool")

   out = relay.If(cond, branch_a, branch_b)

   mod = tvm.IRModule()
   mod["main"] = relay.Function([cond, branch_a, branch_b], out)
   # Run type inference and print the inferred signature of main.
   mod = relay.transform.InferType()(mod)
   print(mod["main"])
   ```
   This generates
   ```
   fn (%cond: bool, %a: Tensor[(0), float32], %b: Tensor[(0), float32]) -> Tensor[(0), float32] {
     if (%cond) {
       %a
     } else {
       %b
     }
   }
   ```
   even though input ```a``` was declared with shape (?,), so type inference gets this expression wrong.


[GitHub] [incubator-tvm] masahi commented on issue #6594: [Torch, CI] Support PyTorch 1.6

masahi commented on issue #6594:
URL: https://github.com/apache/incubator-tvm/issues/6594#issuecomment-706469099


   @kevinthesun Thanks for pointing this out and providing a minimal test case. I'm going to take a look at what's going on inside the type checker and hopefully fix it. Once this is fixed, I can send a PR to support PyTorch 1.6 Mask R-CNN and Faster R-CNN.

