Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/03/10 02:51:28 UTC

[GitHub] [tvm] qingcd opened a new issue #10559: [Bug] TVM model give different results when run multiple times

qingcd opened a new issue #10559:
URL: https://github.com/apache/tvm/issues/10559


   My PyTorch model gives different results when run multiple times with the same input after it has been converted to a TVM model. This happens when the CUDA target format is ptx; if the target format is changed back to cubin, there is no problem.
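   
   For reference, a slightly stronger version of this check (a sketch, not part of the original report; it reuses run_tvm_module, module, and inpt from the reproduction script in the Steps section below) would run the module several times and report the worst-case deviation:
   ```
   # Sketch: compare N repeated runs against the first run instead of only two.
   # run_tvm_module, module, inpt and np come from the reproduction script below.
   results = [run_tvm_module(module, inpt) for _ in range(10)]
   max_diff = max(np.max(np.abs(r - results[0])) for r in results)
   print("max abs diff over 10 runs:", max_diff)
   ```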
   
   ### Expected behavior
   The results of multiple runs using the same input should stay the same; the sample code should print:
   max abs diff is: 0
   
   ### Actual behavior
   max abs diff is: 7.818208
   
   ### Environment
   gpu: rtx 2070
   nvcc: Cuda compilation tools, release 11.1, V11.1.74
   Nvidia Driver Version: 470.86 
   system: Linux shukun-desktop 5.13.0-27-generic #29~20.04.1-Ubuntu SMP Fri Jan 14 00:32:30 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
   TVM commit: 0c836b73ffd9669bcc416515dce6436cbd7d7ebe
   
   ### Steps to reproduce
   
   1. Change the target format from the default "fatbin" (cubin) to "ptx" in python/tvm/contrib/nvcc.py:
   ```
   @tvm._ffi.register_func
   def tvm_callback_cuda_compile(code):
       """use nvcc to generate fatbin code for better optimization"""
       ptx = compile_cuda(code, target_format="ptx")  # changed from "fatbin" to reproduce the bug
       return ptx
   ```
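   
   Alternatively (a sketch, not part of the original report), the same switch can be made without editing nvcc.py by re-registering the callback from user code before relay.build is called:
   ```
   import tvm
   from tvm.contrib.nvcc import compile_cuda
   
   # Override the built-in compile callback so device kernels are shipped as PTX.
   @tvm._ffi.register_func("tvm_callback_cuda_compile", override=True)
   def tvm_callback_cuda_compile(code):
       return compile_cuda(code, target_format="ptx")  # "fatbin" does not show the bug
   ```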
   
   2. Run the following script:
   ```
   import math
   
   import numpy as np
   
   import torch
   import torch.nn.functional as F
   from torch import nn
   
   import tvm
   from tvm import relay
   from tvm.contrib import graph_executor
   
   class BatchActivateConvLayer(nn.Module):
       def __init__(
           self, channel_in, growth_rate, bottleneck_size_basic_factor, drop_ratio=0.8
       ):
   
           super(BatchActivateConvLayer, self).__init__()
   
           self.drop_ratio = drop_ratio
           self.growth_rate = growth_rate
           self.bottleneck_channel_out = bottleneck_size_basic_factor * growth_rate
   
           self.mode_bn = torch.nn.BatchNorm3d(channel_in)
           self.mode_conv = nn.Conv3d(
               channel_in, self.bottleneck_channel_out, kernel_size=1, stride=1, bias=False
           )
   
           self.bn = torch.nn.BatchNorm3d(self.bottleneck_channel_out)
           self.conv = nn.Conv3d(
               self.bottleneck_channel_out,
               growth_rate,
               kernel_size=3,
               stride=1,
               padding=1,
               bias=False,
           )
   
           self.drop_out = nn.Dropout3d(p=self.drop_ratio)
   
       def forward(self, x):
   
           current = x
           current = self.mode_bn(current)
           current = self.mode_conv(current)
   
           current = self.bn(current)
           current = self.conv(current)
   
           if self.drop_ratio > 0:
               current = self.drop_out(current)
   
           return current
   
   
   class DenseBlock(nn.Module):
       def __init__(
           self,
           current_block_layers_number,
           channel_in,
           growth_rate,
           bottleneck_size_basic_factor,
           drop_ratio=0.8,
       ):
   
           super(DenseBlock, self).__init__()
   
           self.channel_in = channel_in
           self.growth_rate = growth_rate
           self.bottleneck_size_basic_factor = bottleneck_size_basic_factor
           self.current_channel_in = self.channel_in
           self.current_block_drop_ratio = drop_ratio
           self.current_block_layer_number = current_block_layers_number
   
           for i in range(self.current_block_layer_number):
               current_block_layers = BatchActivateConvLayer(
                   self.current_channel_in,
                   self.growth_rate,
                   self.bottleneck_size_basic_factor,
                self.current_block_drop_ratio,
               )
   
               setattr(self, "block_layer_" + str(i), current_block_layers)
   
               self.current_channel_in += self.growth_rate
   
       def get_current_block_channel_out(self):
   
           return self.current_channel_in
   
       def forward(self, x):
   
           current = x
   
           for i in range(self.current_block_layer_number):
               current_clone = current.clone()
               tmp = getattr(self, "block_layer_" + str(i))(current_clone)
               current = torch.cat((current, tmp), 1)
   
           return current
   
   
   class DenseNet(nn.Module):
       def __init__(
           self,
           growth_rate=24,
           block_config=(2, 2),
           compression=0.5,
           num_init_features=24,
           bottleneck_size_basic_factor=2,
           drop_rate=0,
           num_classes=2,
           small_inputs=True,
           rnn_units=512,
       ):
           super(DenseNet, self).__init__()
   
           self.features = nn.Conv3d(
               1, num_init_features, kernel_size=3, stride=1, padding=1, bias=False
           )
   
           self.init_feature_channel_number = num_init_features
           self.growth_rate = growth_rate
           self.compression = compression
           self.number_class = num_classes
           self.block_config = block_config
           self.rnn_units = rnn_units
           self.drop_ratio = drop_rate
   
           num_features = num_init_features
   
           self.dense_transition_output_list = []
   
           for i, num_layers in enumerate(self.block_config):
               block = DenseBlock(
                   num_layers,
                   num_features,
                   self.growth_rate,
                   bottleneck_size_basic_factor,
                   drop_rate,
               )
               setattr(self, "block_" + str(i), block)
               num_features = num_features + num_layers * growth_rate
               self.dense_transition_output_list.append(num_features)
   
           for name, param in self.named_parameters():
               if "conv" in name and "weight" in name:
                   n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
                   param.data.normal_().mul_(math.sqrt(2.0 / n))
               elif "norm" in name and "weight" in name:
                   param.data.fill_(1)
               elif "norm" in name and "bias" in name:
                   param.data.fill_(0)
       
   
       def forward(self, x):
           features = self.features(x[:, :1])
           for i in range(len(self.block_config)):
               features = getattr(self, "block_" + str(i))(features)
           return features
   
   def run_tvm_module(module, inpt):
       module.set_input(0, inpt)
       module.run()
       tvm.cuda().sync()
       res = module.get_output(0).numpy()
       return res
       
   if __name__ == "__main__":
       model = DenseNet()
       model.eval()
       model_jit = torch.jit.trace(model, example_inputs=torch.randn((4,2,64,64,64)))
       print("finish gen trace model")
       
       relay_model, params = relay.frontend.from_pytorch(
           model_jit, [('input_0', (4,2,64,64,64))], default_dtype='float32')
       target = tvm.target.cuda()
       with tvm.transform.PassContext(opt_level=3):
           lib = relay.build(relay_model, target=target, params=params)
       lib.export_library('./dense.so')
       del lib
       print("finish compile tvm model")
       
       inpt = np.random.random((4,2,64,64,64)).astype(np.float32)  # match the model's float32 input
       lib = tvm.runtime.load_module('./dense.so')
       module = graph_executor.GraphModule(lib["default"](tvm.cuda()))
       res1 = run_tvm_module(module, inpt)
       res2 = run_tvm_module(module, inpt)
       
       diff = res1 - res2
       print("max abs diff is:", np.max(np.abs(diff)))
   ```
   Maybe there is some problem with how the CUDA kernel functions are launched? Since the cubin target format works, the generated CUDA kernel code is most likely correct.
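   
   One way to narrow this down (a sketch, not part of the original report) would be to dump the CUDA source that relay.build generates and compile it with both target formats, so the two binaries can be compared independently of the runtime launch path. This assumes the relay.build result lib is still in scope (i.e. before the del lib in the script above) and that the first imported module is the CUDA module:
   ```
   from tvm.contrib import nvcc
   
   # Dump the generated CUDA C source from the device module.
   # Assumption: imported_modules[0] is the CUDA module produced by relay.build.
   cuda_src = lib.get_lib().imported_modules[0].get_source()
   
   # Compile the same source with both target formats for offline comparison.
   ptx_bin = nvcc.compile_cuda(cuda_src, target_format="ptx")
   fat_bin = nvcc.compile_cuda(cuda_src, target_format="fatbin")
   print("ptx bytes:", len(ptx_bin), "fatbin bytes:", len(fat_bin))
   ```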

