Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/19 02:41:28 UTC

[GitHub] [tvm] eleflea opened a new issue #7878: PyTorch quantized group conv2d raise an error when converting to tvm

eleflea opened a new issue #7878:
URL: https://github.com/apache/tvm/issues/7878


   Hi,
   I found a bug in Relay. I used PyTorch to quantize a model containing a grouped convolution, and `relay.frontend.from_pytorch` reports an error. Details below.
   ```
   import tvm
   from tvm import relay
   import torch
   from torch import nn
   from torch import quantization
   
   GROUPS = 4
   
   class Net(nn.Module):
   
       def __init__(self):
           super().__init__()
           self.quant = quantization.QuantStub()
           self.dequant = quantization.DeQuantStub()
           self.gconv = nn.Conv2d(12, 24, 3, groups=GROUPS, bias=False)
   
       def forward(self, x):
           x = self.quant(x)
           return self.dequant(self.gconv(x))
   
   net = Net()
   
   net.eval()
   net.qconfig = torch.quantization.get_default_qconfig('fbgemm')
   net = torch.quantization.prepare(net, inplace=False)
   net = torch.quantization.convert(net, inplace=False)
   
   inp = torch.randn(1, 12, 32, 32)
   script_module = torch.jit.trace(net, inp).eval()
   
   input_name = "input"  # the input name can be arbitrary for the PyTorch frontend.
   input_shapes = [(input_name, (1, 12, 32, 32))]
   mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
   
   target = tvm.target.cuda()
   with tvm.transform.PassContext(opt_level=3):
       lib = relay.build_module.build(mod, target=target, params=params)
   print('finish')
   ```
   
   It raises:
   
   ```
   The Relay type checker is unable to show the following types match.
   In particular dimension 0 conflicts: 72 does not match 24.
   The Relay type checker is unable to show the following types match.
   In particular `Tensor[(24), float32]` does not match `Tensor[(72), float32]`
   Traceback (most recent call last):
     File "bug.py", line 33, in <module>
       mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
     File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 3238, in from_pytorch
       ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
     File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 2662, in convert_operators
       self.record_output_type(relay_out)
     File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 222, in record_output_type
       self.infer_type_with_prelude(output)
     File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 170, in infer_type_with_prelude
       body = self.infer_type(val, self.prelude.mod)
     File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 163, in infer_type
       new_mod = transform.InferType()(new_mod)
     File "/home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/ir/transform.py", line 127, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
     File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
   tvm.error.DiagnosticError: Traceback (most recent call last):
     [bt] (6) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(TVMFuncCall+0x5b) [0x7fd44602622b]
     [bt] (5) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(+0x80b06a) [0x7fd4454b306a]
     [bt] (4) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(tvm::transform::Pass::operator()(tvm::IRModule) const+0xcd) [0x7fd4454b24ad]
     [bt] (3) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1b7) [0x7fd4454b1c27]
     [bt] (2) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(+0x1195018) [0x7fd445e3d018]
     [bt] (1) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(tvm::DiagnosticContext::Render()+0x199) [0x7fd44545ecc9]
     [bt] (0) /home/eleflea/.local/lib/python3.7/site-packages/tvm-0.8.dev704+g3a0e3a5bb-py3.7-linux-x86_64.egg/tvm/libtvm.so(+0x7b5d22) [0x7fd44545dd22]
     File "/home/eleflea/code/tvm/src/ir/diagnostic.cc", line 105
   DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
   ```
   
   With `GROUPS = 1` it works, so I think the problem is related to quantized grouped conv2d.
   Looking forward to your reply! Thank you.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi edited a comment on issue #7878: PyTorch quantized group conv2d raise an error when converting to tvm

masahi edited a comment on issue #7878:
URL: https://github.com/apache/tvm/issues/7878#issuecomment-822278347


   @anijain2305 Is group conv supported by QNN? The line below, which multiplies two weight dimensions, looks incorrect for the grouped convolution case. Here the weight shape is (24, 3, 3, 3), so the product 24 * 3 = 72 does not match the weight scale shape (24,), which produces the error message above.
   https://github.com/apache/tvm/blob/813136401a11a49d6c15e6013c34dd822a5c4ff6/src/relay/qnn/op/convolution.cc#L81
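A small pure-Python sketch of the shape arithmetic described in this comment. `conv2d_weight_shape` and the two `scale_len_*` helpers are hypothetical names written for illustration; they model the O * I product the comment describes, not TVM's actual source. The depthwise case is included because there I equals 1, so O * I happens to equal the number of output channels, whereas for the grouped convolution in the issue it does not.

```python
def conv2d_weight_shape(in_channels, out_channels, kernel_size, groups=1):
    """OIHW weight shape that PyTorch uses for nn.Conv2d (square kernel)."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("channels must be divisible by groups")
    return (out_channels, in_channels // groups, kernel_size, kernel_size)


def scale_len_o_times_i(weight_shape):
    """Hypothetical model of the check described in the comment: O * I."""
    out_dim, in_dim = weight_shape[0], weight_shape[1]
    return out_dim * in_dim


def scale_len_per_output_channel(weight_shape):
    """Length of a per-channel weight scale: one entry per output channel."""
    return weight_shape[0]


for name, args in [
    ("grouped (issue)", (12, 24, 3, 4)),  # nn.Conv2d(12, 24, 3, groups=4)
    ("depthwise", (12, 12, 3, 12)),       # nn.Conv2d(12, 12, 3, groups=12)
]:
    shape = conv2d_weight_shape(*args)
    print(name, shape, scale_len_o_times_i(shape), scale_len_per_output_channel(shape))
# grouped (issue) (24, 3, 3, 3) 72 24
# depthwise (12, 1, 3, 3) 12 12
```

The 72 versus 24 in the first row is the same mismatch reported by the Relay type checker in the issue.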
   
   @eleflea For now you can work around this problem with per-tensor quantization: the error happens when you use per-channel weight quantization, which `get_default_qconfig('fbgemm')` enables. You can force per-tensor quantization by using, for example, `get_default_qconfig('qnnpack')` instead.
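To see why per-tensor quantization sidesteps the mismatch, here is a minimal pure-Python sketch of the two schemes (not PyTorch's observer code). It assumes a simple symmetric int8 formula, scale = max|w| / 127, chosen only for illustration, and the helper names are hypothetical. A per-tensor scheme produces a single scalar scale, so there is no scale vector whose length has to match a weight dimension; a per-channel scheme produces one scale per output channel (24 for the convolution in the issue).

```python
import random


def symmetric_scale(values, qmax=127):
    """Illustrative symmetric int8 scale: the largest magnitude maps to qmax."""
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs > 0 else 1.0


def per_tensor_scale(weight):
    """One scalar scale shared by the whole weight (weight: list of output channels)."""
    return symmetric_scale([v for channel in weight for v in channel])


def per_channel_scales(weight):
    """One scale per output channel, i.e. a vector of length out_channels."""
    return [symmetric_scale(channel) for channel in weight]


# Hypothetical weight for nn.Conv2d(12, 24, 3, groups=4): 24 output channels,
# each flattened to (12 // 4) * 3 * 3 = 27 values.
rng = random.Random(0)
weight = [[rng.uniform(-1.0, 1.0) for _ in range(27)] for _ in range(24)]

print(type(per_tensor_scale(weight)).__name__)  # float
print(len(per_channel_scales(weight)))          # 24
```

In the reporter's script, the corresponding change is replacing `'fbgemm'` with `'qnnpack'` in the `get_default_qconfig` call.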





[GitHub] [tvm] tqchen commented on issue #7878: PyTorch quantized group conv2d raise an error when converting to tvm

tqchen commented on issue #7878:
URL: https://github.com/apache/tvm/issues/7878#issuecomment-826801482


   Thanks for asking the question. The community uses the forum for troubleshooting and discussions, so please open a new discussion topic at https://discuss.tvm.apache.org/, where more people will be able to see and answer the question.





[GitHub] [tvm] tqchen closed issue #7878: PyTorch quantized group conv2d raise an error when converting to tvm

tqchen closed issue #7878:
URL: https://github.com/apache/tvm/issues/7878


   





[GitHub] [tvm] masahi commented on issue #7878: PyTorch quantized group conv2d raise an error when converting to tvm

masahi commented on issue #7878:
URL: https://github.com/apache/tvm/issues/7878#issuecomment-822278347


   @anijain2305 Is group conv supported by QNN? The following line, which does the multiplication, looks incorrect for the grouped convolution case:
   https://github.com/apache/tvm/blob/813136401a11a49d6c15e6013c34dd822a5c4ff6/src/relay/qnn/op/convolution.cc#L81
   
   @eleflea For now you can work around this problem with per-tensor quantization: the error happens when you use per-channel weight quantization, which `get_default_qconfig('fbgemm')` enables. You can force per-tensor quantization by using, for example, `get_default_qconfig('qnnpack')` instead.

