Posted to commits@tvm.apache.org by "Thrsu (via GitHub)" <gi...@apache.org> on 2023/10/15 10:35:27 UTC

[PR] [Unity] [Bugfix] Fix bug in interpolate operator's default mode parameter in PyTorch frontend [tvm]

Thrsu opened a new pull request, #15933:
URL: https://github.com/apache/tvm/pull/15933

   This PR fixes a bug in the `interpolate` operator of the PyTorch frontend in TVM. When retrieving the default value of the `mode` parameter, the converter looked up the `method` keyword instead of the `mode` keyword, so the mode requested at the call site was silently ignored and `interpolate` produced incorrect results.
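   
   For context, the fix amounts to using the correct keyword in the converter's kwarg lookup. A minimal before/after sketch, assuming the converter reads defaults off the FX call node's kwargs (`node` is hypothetical shorthand here; the surrounding code in the frontend may differ):
   ```python
   # Before (buggy): torch.nn.functional.interpolate has no 'method' kwarg,
   # so this lookup never finds the caller's choice and always falls back
   # to the default 'nearest'.
   mode = node.kwargs.get("method", "nearest")
   
   # After (fixed): 'mode' is the actual PyTorch keyword, so an explicit
   # mode='bilinear' is honored.
   mode = node.kwargs.get("mode", "nearest")
   ```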
   
   The bug can be reproduced with the script below:
   ```python
   import torch
   from torch import fx
   from torch.nn import Module
   import tvm
   import tvm.testing
   from tvm import relax
   from tvm.relax.frontend.torch import from_fx
   
   class interpolate(Module):
       def forward(self, input):
           return torch.nn.functional.interpolate(
               input, size=None, scale_factor=2.0, mode='bilinear', align_corners=False
           )
   
   model = interpolate().float()
   input_data = [torch.randn([1, 2, 4, 4], dtype=torch.float32)]
   input_names = [f"input{idx}" for idx, _ in enumerate(input_data)]
   input_info = [(list(inp.shape), str(inp.dtype)) for inp in input_data]
   
   # Trace the model with torch.fx and import the graph into Relax.
   fx_model: torch.fx.GraphModule = fx.symbolic_trace(model)
   with torch.no_grad():
       mod = from_fx(fx_model, input_info)
   
   # Reference results from PyTorch (on GPU if available, otherwise CPU).
   torch_input = input_data
   if torch.cuda.is_available():
       model = model.cuda()
       torch_input = [inp.cuda() for inp in input_data]
   
   with torch.no_grad():
       torch_outputs = model(*[inp.clone() for inp in torch_input])
   torch_outputs = (torch_outputs.cpu().numpy(),)
   
   # Build the imported module once and run it with the Relax VM on CPU.
   tvm_input = {
       name: tvm.nd.array(inp.clone().cpu().numpy())
       for name, inp in zip(input_names, input_data)
   }
   target = tvm.target.Target("llvm", host="llvm")
   mod = relax.transform.LegalizeOps()(mod)
   ex = relax.build(mod, target)
   vm = relax.VirtualMachine(ex, tvm.cpu())
   
   tvm_outputs = [vm["main"](*[tvm_input[name] for name in input_names])]
   
   for i, torch_output in enumerate(torch_outputs):
       output = tvm_outputs[i].numpy()
       tvm.testing.assert_allclose(torch_output, output, rtol=1e-5, atol=1e-5)
   ```
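   
   For reference, a minimal PyTorch-only snippet (not part of the PR) showing that nearest and bilinear upsampling genuinely diverge on the same input, which is what the mismatch in the traceback below reflects:
   ```python
   import torch
   import torch.nn.functional as F
   
   x = torch.arange(4, dtype=torch.float32).reshape(1, 1, 2, 2)
   # 'nearest' repeats each input pixel in a 2x2 block, while 'bilinear'
   # blends neighboring pixels, so most output elements differ.
   nearest = F.interpolate(x, scale_factor=2.0, mode="nearest")
   bilinear = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)
   print(torch.allclose(nearest, bilinear))  # False
   ```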
   
   Here is the resulting traceback. Note that in the mismatch report, `x` (the PyTorch result) is smoothly interpolated, while `y` (the TVM result) repeats each source value in 2x2 blocks, the signature of a nearest-neighbor fallback:
   ```
   Traceback (most recent call last):
   ...  
       tvm.testing.assert_allclose(torch_output, output, rtol=1e-5, atol=1e-5)
     File "/workplace/software/tvm/tvm/python/tvm/testing/utils.py", line 120, in assert_allclose
       np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
     File "/workplace/software/miniconda3/envs/tflite/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 1527, in assert_allclose
       assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
     File "/workplace/software/miniconda3/envs/tflite/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 844, in assert_array_compare
       raise AssertionError(msg)
   AssertionError: 
   Not equal to tolerance rtol=1e-05, atol=1e-05
   
   Mismatched elements: 120 / 128 (93.8%)
   Max absolute difference: 1.2100129
   Max relative difference: 39.08069
    x: array([[[[-0.407674, -0.087851,  0.551796,  0.303579, -0.832501,
             -0.630166,  0.910585,  1.680961],
            [-0.152558, -0.054653,  0.141159, -0.0371  , -0.589427,...
    y: array([[[[-0.407674, -0.407674,  0.871619,  0.871619, -1.400542,
             -1.400542,  1.680961,  1.680961],
            [-0.407674, -0.407674,  0.871619,  0.871619, -1.400542,...
   ```



Re: [PR] [Unity] [Bugfix] Fix bug in interpolate operator's default mode parameter in PyTorch frontend [tvm]

Posted by "Hzfengsy (via GitHub)" <gi...@apache.org>.
Hzfengsy merged PR #15933:
URL: https://github.com/apache/tvm/pull/15933

