Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/08/22 06:03:55 UTC

[GitHub] [tvm] crawlingcub opened a new issue, #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub opened a new issue, #12524:
URL: https://github.com/apache/tvm/issues/12524

   Hi,
   
   I used a variant of a pre-trained AlexNet model with a subset of the CIFAR-100 dataset, and I am getting lower accuracy with TVM than with PyTorch. I made two modifications to the original model: I changed a `ReLU` layer to `LogSigmoid` and added some noise to a `Conv2d` layer. I have more examples where using `LogSigmoid` sometimes leads to different accuracy. Is this a bug?
   
   Converting from an ONNX model also leads to the same result.
   
   Please find the code/model/dataset below. Let me know if you need more info. Thanks!
   
   ### Expected behavior
   
   Accuracy should be the same.
   
   ### Actual behavior
   
   Accuracy with TVM is lower (Acc1 and Acc5 numbers):
   ```
   Running validation...
   
   Pytorch: 2.5 3.5
   TVM:     0.1 1.9
   ```
   
   ### Environment
   
   torch 1.8.0+cu111
   torchvision 0.9.0+cu111
   TVM version: latest main (commit ecbe4ca0edadeca8fee4d0c2c9f7a9093043b5ee)
   Python 3.7.12
   
   
   ### Steps to reproduce
   Download the model and data files [here](https://drive.google.com/drive/folders/1_AOA-9hA1I92vBUeQMbZ6F-2r69_aIrF?usp=sharing)
   
   Python script to reproduce:
   ```python
   import sys

   import numpy as np
   import torch
   from torch.utils.data import DataLoader
   
   def accuracy(output, target, topk=(1,)):
       """Computes the accuracy over the k top predictions for the specified values of k"""
       with torch.no_grad():
           maxk = max(topk)
           batch_size = target.size(0)
   
           _, pred = output.topk(maxk, 1, True, True)
           pred = pred.t()
           correct = pred.eq(target.view(1, -1).expand_as(pred))
   
           res = []
           for k in topk:
               correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
               res.append(correct_k.mul_(100.0 / batch_size))
           return res
   
   def eval_model_tvm(model, dataset, device, batch_size):
       import tvm
       from tvm import relay
       from tvm.contrib import graph_executor
       import logging

       logger = logging.getLogger('compile_engine')
       logger.setLevel(logging.ERROR)
       
       validation_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
       if "cpu" in device.lower():
           target = tvm.target.Target("llvm", host="llvm")
       else:
           target = tvm.target.cuda()
       print("target", target)
       dev = tvm.device(str(target))
       model = model.to("cpu")
       model.eval()
       mod = None
       lib = None
       acc1s = []
       acc5s = []
       for i, (images, targets) in enumerate(validation_dataloader):
           input_name = "input0"
           if mod is None:
               # Trace the model on the first batch and compile it once.
               # Note: the compiled module is specialized to this batch shape,
               # so the dataset size should be a multiple of batch_size.
               scripted_model = torch.jit.trace(model, images).eval()
               print("scripted")
               input_data = images.numpy().astype("float32")
               shape_list = [(input_name, input_data.shape)]
               mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

               with tvm.transform.PassContext(opt_level=3):
                   lib = relay.build(mod, target=target, params=params)

           m = graph_executor.GraphModule(lib["default"](dev))
           # Convert the torch tensor to numpy explicitly before handing it to TVM.
           m.set_input(input_name, tvm.nd.array(images.numpy()))
           m.run()
           output = torch.tensor(m.get_output(0).asnumpy())
           acc1, acc5 = accuracy(output, targets, topk=(1, 5))
   
           acc1s.append(acc1.item())
           acc5s.append(acc5.item())
           
       
       return {'acc1': np.mean(acc1s), 'acc5': np.mean(acc5s)}
   
   
   def eval_model_vision(model, dataset, device, criterion, compute_metrics_fn, batch_size):
       print("Running validation...")
       from tqdm import tqdm
           
       if not isinstance(model, torch.nn.DataParallel):
           model = torch.nn.DataParallel(model)
       if not isinstance(dataset, DataLoader):
           validation_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
       else:
           validation_dataloader = dataset
       acc1s = []
       acc5s = []
       losses = []
       model.to(device)
       model.eval()

       with torch.no_grad():
           for i, (images, target) in tqdm(enumerate(validation_dataloader), total=len(validation_dataloader)):
               images = images.to(device)
               target = target.to(device)

               # compute output
               output = model(images)
               loss = criterion(output, target)

               # measure accuracy and record loss
               acc1, acc5 = compute_metrics_fn(output, target, topk=(1, 5))
               acc1s.append(acc1.item())
               acc5s.append(acc5.item())
               losses.append(loss.item())

       return {'acc1': np.mean(acc1s), 'acc5': np.mean(acc5s)}


   if __name__ == '__main__':
       DEVICE = 'cuda'
       model = torch.load(sys.argv[1])
       data = torch.load(sys.argv[2])
       criterion = torch.nn.CrossEntropyLoss()
       batch_size = 10

       results_torch = eval_model_vision(model,
                                         data,
                                         device=DEVICE,
                                         criterion=criterion,
                                         compute_metrics_fn=accuracy,
                                         batch_size=batch_size)

       print(results_torch["acc1"], results_torch["acc5"])

       res2 = eval_model_tvm(model, data, DEVICE, batch_size=batch_size)
       print(res2["acc1"], res2["acc5"])
   
   
   
   ```
   
   




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1221933724

   If you believe that the results of TVM and PyTorch are different, please provide a minimal script that demonstrates the difference, rather than one that runs a full evaluation loop.




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1223554340

   Interesting. Is there a known numerical issue if we translate PT `log_sigmoid` simply as `log(sigmoid(x))`? Otherwise I have no idea where the difference is coming from.
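
   For context, a direct `log(sigmoid(x))` translation does have a well-known failure mode: for large-magnitude negative inputs, `sigmoid(x)` underflows to 0 in float32 and the logarithm returns `-inf`, while a fused `log_sigmoid` stays finite. A minimal sketch (assuming default float32 tensors):

   ```python
   import torch

   x = torch.tensor([-300.0, 0.0, 300.0])     # float32 by default
   print(torch.log(torch.sigmoid(x)))         # [-inf, -0.6931, 0.]: underflow at -300
   print(torch.nn.functional.logsigmoid(x))   # [-300., -0.6931, 0.]: finite everywhere
   ```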




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1221975095

   This is an example of a "minimal script":
   
   ```python
   import torch
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   import numpy as np
   
   
   DEVICE='cpu'
   model=torch.load("model.pt", map_location=torch.device('cpu'))
   data=torch.load("data2.pt", map_location=torch.device('cpu'))
   
   pt_inp = torch.unsqueeze(data[0][0], 0)
   
   with torch.no_grad():
       pt_out = model(pt_inp).numpy()
   
   input_name = "input0"
   target = tvm.target.Target("llvm", host="llvm")
   dev = tvm.cpu(0)
   scripted_model = torch.jit.trace(model, pt_inp).eval()
   input_data = pt_inp.numpy()
   shape_list = [(input_name, input_data.shape)]
   mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
   
   # print(mod)
   
   with tvm.transform.PassContext(opt_level=3):
       lib = relay.build(mod, target=target, params=params)
   
   m = graph_executor.GraphModule(lib["default"](dev))
   m.set_input(input_name, tvm.nd.array(pt_inp.numpy()))
   m.run()
   output = torch.tensor(m.get_output(0).asnumpy())
   
   tvm_out = m.get_output(0).asnumpy()
   
   # print(np.max(np.abs(tvm_out - pt_out)), np.mean(np.abs(tvm_out - pt_out)))
   print(pt_out)
   ```
   
   Running this script, I see PT output like this:
   ```
   [[-4.39299689e+36  3.77933626e+35 -4.14227049e+36 -4.03090327e+36
     -4.95762177e+36 -1.81598428e+36 -1.35176413e+36  1.13814440e+36
     -4.90326143e+35  1.11240001e+36 -3.27343576e+36 -4.31897940e+36
     -3.87017024e+36 -1.74157477e+36  5.59471333e+35 -2.40675192e+36
     -1.40862270e+36 -4.29128504e+36  4.14507723e+35 -5.33183350e+36
     -5.69302083e+35 -4.26776790e+36  2.10849216e+35 -3.88321310e+36
     -9.52957448e+35  9.93559742e+35  4.13343307e+35 -2.29166969e+36
     -1.85837942e+36 -1.26610193e+36 -2.77660958e+35 -8.23837505e+36
      1.42649498e+36 -1.72732131e+36 -4.55971259e+36 -4.42550959e+36
     -3.68109445e+36 -4.73053168e+36 -1.91429644e+36  4.43432768e+36
     -6.18386068e+35 -3.07289027e+36 -3.72914158e+36 -9.81286935e+36
     -1.02773259e+36  1.03946589e+36 -2.20454802e+36 -2.38695376e+36
   ```
   
   Is this expected?




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1223556813

   There was this fix in PyTorch for a numerically stable log-sigmoid: https://github.com/pytorch/pytorch/pull/2211
   Maybe something similar needs to be done here?
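
   For reference, the identity behind that fix rewrites the naive form so that `exp` is only ever applied to non-positive values: `log(sigmoid(x)) = min(0, x) - log1p(exp(-|x|))`. Since `exp(-|x|)` always lies in (0, 1], neither the exponential nor the `log1p` can blow up. A small sketch of the identity (illustration only):

   ```python
   import torch

   def log_sigmoid_stable(x):
       # exp(-|x|) is in (0, 1], so every intermediate value stays finite
       return torch.minimum(torch.zeros_like(x), x) - torch.log1p(torch.exp(-torch.abs(x)))

   x = torch.tensor([-300.0, -1.0, 0.0, 1.0, 300.0])
   print(log_sigmoid_stable(x))               # finite everywhere
   print(torch.nn.functional.logsigmoid(x))   # matches
   ```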




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1221938896

   I think any data would lead to different results for TVM and PT, if your claim is true.




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1221953601

   > I tried your script on PT 1.11 but got an error from somewhere irrelevant.
   
   I think that was a compatibility error with PT 1.11. I uploaded `data2.pt` to that same link. Can you try with that?
   
   > I think any data would give lead to different results for TVM and PT, if your claim is True.
   
   Not all predictions are different. If you try with data point 0, you will see that the output is different. Just add `data = torch.utils.data.Subset(data, [0])` before running.
   Let me know if you have other suggestions. Thanks!
   




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1221936262

   Do you mean find one or two data points for which the results differ?




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1224692072

   Yes, that fixes it! :) I see the exact same accuracy now! Should I raise a PR?




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1224790611

   You can try, for example, `pytest tests/python/frontend/pytorch/test_forward.py -k log_sigmoid -s`.




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1224778517

   OK, naive question: how do I run a specific test locally?




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1223566223

   Yeah, that PR is very old, but their current implementation in https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu#L32-L35 is notably not a simple composition of `log` and `sigmoid`.

   I tried replacing https://github.com/apache/tvm/blob/f64a3bda253f4220d66eeb3348f93f486392cb8e/python/tvm/relay/frontend/pytorch.py#L912-L914 with
   
   ```python
       def log_sigmoid(self, inputs, input_types):
           data = inputs[0]
           mn = _op.minimum(_op.const(0, dtype=input_types[0]), data)
           z = _op.exp(-_op.abs(data))
           return mn - self.log1p([z], input_types)
   
   ```
   
   following the PT code. The TVM results are now also huge, but they still don't agree with PT. Can you take it from here?
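
   A quick way to exercise the new converter in isolation, without the full model (a sketch; the `input0` name and the llvm target are just placeholder choices):

   ```python
   import numpy as np
   import torch
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor

   # Compile a bare LogSigmoid through the PyTorch frontend and compare outputs.
   model = torch.nn.LogSigmoid().eval()
   inp = torch.tensor([[-300.0, -1.0, 0.0, 1.0, 300.0]])
   with torch.no_grad():
       pt_out = model(inp).numpy()

   scripted = torch.jit.trace(model, inp)
   mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(inp.shape))])
   with tvm.transform.PassContext(opt_level=3):
       lib = relay.build(mod, target="llvm", params=params)

   m = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
   m.set_input("input0", tvm.nd.array(inp.numpy()))
   m.run()
   print(np.max(np.abs(m.get_output(0).asnumpy() - pt_out)))  # ~0 once the converter is stable
   ```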




[GitHub] [tvm] masahi closed issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi closed issue #12524: [Bug] Lower Accuracy with a model compared to pytorch
URL: https://github.com/apache/tvm/issues/12524




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1225060994

   Hi, I added the new test. All the sigmoid tests pass (locally) with the new change. Can you review?




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1223524294

   Yes, that's what I am seeing as well. With argmax, torch's output label is 39, while for TVM it is 33. Seems like some numerical error is happening here?




[GitHub] [tvm] crawlingcub commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

crawlingcub commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1223549827

   Yes, the outputs changed only when I replaced `relu` with `logsigmoid`. They were exactly the same before that.




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1223543413

   You said you replaced `relu` with `logsigmoid`. I have a feeling that the huge values in the PT output come from the logarithm. If you keep `relu`, do the two outputs agree?




[GitHub] [tvm] masahi commented on issue #12524: [Bug] Lower Accuracy with a model compared to pytorch

masahi commented on issue #12524:
URL: https://github.com/apache/tvm/issues/12524#issuecomment-1224757199

   Hmm, maybe the argmax got fixed, but the raw outputs could still be different.
   
   First, please run the test in https://github.com/apache/tvm/blob/f64a3bda253f4220d66eeb3348f93f486392cb8e/tests/python/frontend/pytorch/test_forward.py#L809-L814 to make sure that the existing test passes.
   
   Second, try to find an input that causes the numerical issue and use that in the new test case. The new implementation should pass the test for such an input. If it doesn't, you need to figure out what still differs from PT's `log_sigmoid`.
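
   A sketch of what the extended test could look like, following the `verify_model` convention in `test_forward.py` (it assumes the imports and helpers already present in that file; the existing test body may differ, and the large-magnitude input is the new regression case):

   ```python
   @tvm.testing.uses_gpu
   def test_forward_log_sigmoid():
       torch.set_grad_enabled(False)
       input_shape = [10, 10]
       verify_model(torch.nn.LogSigmoid().eval(), input_data=torch.rand(input_shape).float())
       # Large-magnitude inputs make the naive log(sigmoid(x)) underflow to -inf,
       # so they exercise the numerically stable path.
       large_input = 300.0 * (torch.rand(input_shape).float() - 0.5)
       verify_model(torch.nn.LogSigmoid().eval(), input_data=large_input)
   ```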
   
   > how does the onnx frontend handle it?

   I think PT's ONNX exporter decomposes `log_sigmoid(x)` into `log(sigmoid(x))`, so there is nothing we can do about it on the TVM side.
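
   One way to confirm that decomposition (a sketch; requires the `onnx` package):

   ```python
   import onnx
   import torch

   torch.onnx.export(torch.nn.LogSigmoid().eval(), torch.randn(1, 4), "logsigmoid.onnx")
   print([n.op_type for n in onnx.load("logsigmoid.onnx").graph.node])
   # expected something like ['Sigmoid', 'Log']: the stable form is lost
   # before TVM ever sees the graph
   ```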

