Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/11/09 06:39:18 UTC

[GitHub] [incubator-mxnet] wy3406 opened a new issue #19498: SyncBN causes the memory to gradually increase with iteration

wy3406 opened a new issue #19498:
URL: https://github.com/apache/incubator-mxnet/issues/19498


   ## Description
   
   - I have a few questions regarding SyncBN.
   When training a custom image segmentation model with BN, GPU memory usage is normal. But when I replace BN with SyncBN, GPU memory grows with every iteration until it fills the entire GPU, and then training hangs. I also tried a smaller batch size than the one I use with BN, but memory still grows until the GPU is full.
   No warning is printed when I use SyncBN.
   Is there something I have missed?
   
   - Environment: Python 3.6.9; 8 × TITAN RTX; CUDA 10.1
   
   - Framework: mxnet-cu101-1.7.0 and gluoncv-0.8.0
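For context, the two layers differ in where the batch statistics come from. Plain `BatchNorm` normalizes each GPU's slice of the batch on its own. `SyncBatchNorm` combines per-device partial sums so that every device normalizes with the statistics of the whole batch. The NumPy sketch below illustrates that aggregation conceptually; `sync_bn_stats` and `per_device_stats` are hypothetical helpers, not MXNet's CUDA implementation.

```python
import numpy as np


def sync_bn_stats(shards):
    """Combine per-device partial sums into global per-channel mean/variance.

    Each shard has shape (N_i, C, H, W). Conceptual sketch only, not
    MXNet's actual SyncBatchNorm kernel.
    """
    total_count = 0
    total_sum = 0.0
    total_sqsum = 0.0
    for x in shards:
        # Each device reduces over batch and spatial axes, keeping channels.
        total_count += x.shape[0] * x.shape[2] * x.shape[3]
        total_sum = total_sum + x.sum(axis=(0, 2, 3))
        total_sqsum = total_sqsum + (x ** 2).sum(axis=(0, 2, 3))
    # On real hardware, these three quantities are what devices exchange.
    mean = total_sum / total_count
    var = total_sqsum / total_count - mean ** 2  # biased variance
    return mean, var


def per_device_stats(shards):
    """Plain BatchNorm: each device uses only its own slice of the batch."""
    return [(x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))) for x in shards]


rng = np.random.default_rng(0)
batch = rng.standard_normal((20, 3, 8, 8))  # batch of 20, as in the reproducer in this thread
shards = np.split(batch, 4)                 # 4 GPUs -> 5 samples per device
mean, var = sync_bn_stats(shards)
```

With a batch of 20 on four GPUs, each device sees only five samples. The cross-device exchange is what makes the synchronized statistics match the full-batch statistics.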
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org
For additional commands, e-mail: issues-help@mxnet.apache.org


[GitHub] [incubator-mxnet] kohillyang commented on issue #19498: SyncBN causes the memory to gradually increase with iteration

Posted by GitBox <gi...@apache.org>.
kohillyang commented on issue #19498:
URL: https://github.com/apache/incubator-mxnet/issues/19498#issuecomment-732689325


   Could this be caused by DataParallelModel? Multi-threaded and multi-process training are not supported by MXNet. To speed up training, I suggest trying Horovod instead.
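Horovod runs one process per GPU instead of driving all GPUs from threads in a single process. It keeps the replicas in step by combining gradients across processes with an allreduce, so every replica applies the same averaged update. The pure-NumPy sketch below only simulates what an averaging allreduce computes. `allreduce_average` is a hypothetical helper and not Horovod's API; a real job would use the `horovod.mxnet` module and be launched with `horovodrun`.

```python
import numpy as np


def allreduce_average(worker_grads):
    """Simulate an averaging allreduce: every worker receives the element-wise mean.

    worker_grads: list of same-shaped arrays, one per worker (process/GPU).
    Returns one averaged array per worker (all identical).
    """
    mean = np.mean(np.stack(worker_grads), axis=0)
    # After the allreduce, each worker holds its own copy of the same result.
    return [mean.copy() for _ in worker_grads]


grads = [np.full((2, 2), float(rank)) for rank in range(4)]  # ranks 0..3
synced = allreduce_average(grads)
```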




[GitHub] [incubator-mxnet] wy3406 commented on issue #19498: SyncBN causes the memory to gradually increase with iteration

Posted by GitBox <gi...@apache.org>.
wy3406 commented on issue #19498:
URL: https://github.com/apache/incubator-mxnet/issues/19498#issuecomment-724387973


   @leezu With the following example, the GPU memory reported by nvidia-smi grows slowly as the iterations progress:
   ```
   from tqdm import tqdm
   
   from mxnet import gluon, autograd
   from mxnet.gluon import nn
   from mxnet.gluon.data import dataset
   from gluoncv.utils.parallel import DataParallelCriterion,DataParallelModel
   
   import mxnet.ndarray as nd
   import mxnet as mx
   import numpy as np
   
   class Activation(nn.HybridBlock):
       """Activation function used in MobileNetV3"""
       def __init__(self, act_func, **kwargs):
           super(Activation, self).__init__(**kwargs)
           # ReLU6 and HardSigmoid are not defined in this snippet, so only
           # activations available in gluon.nn are handled; the model uses 'relu'.
           if act_func == "relu":
               self.act = nn.Activation('relu')
           elif act_func == "swish":
               self.act = nn.Swish()
           elif act_func == "leaky":
               self.act = nn.LeakyReLU(alpha=0.375)
           else:
               raise NotImplementedError(act_func)
       def hybrid_forward(self, F, x):
           return self.act(x)
   
   def ConvBlock(in_channels, out_channels,
                 kernel_size=1, strides=1, padding=0, num_groups=1,
                 use_act=True, act_type='relu',
                 name_prefix='ConvBlock_Act_',
                 use_bias=False,
                 conv2d=nn.Conv2D,
                 norm_layer=nn.BatchNorm, norm_kwargs=None):
       """Conv -> norm layer (BatchNorm or SyncBatchNorm) -> optional activation."""
       out = nn.HybridSequential()
       with out.name_scope():
           out.add(conv2d(in_channels=in_channels, channels=out_channels,
                          kernel_size=kernel_size, strides=strides, padding=padding,
                          use_bias=use_bias, groups=num_groups),
                   norm_layer(in_channels=out_channels,
                              **({} if norm_kwargs is None else norm_kwargs)))
           if use_act:
               out.add(Activation(act_type))
       return out
   
   class Net(nn.HybridBlock):
       def __init__(self, norm_layer, norm_kwargs):
           super(Net, self).__init__(prefix='')
           self.features = nn.HybridSequential()
           # (in_channels, out_channels, strides) of the six 3x3 ConvBlocks
           block_cfg = [(3, 256, 1), (256, 512, 2), (512, 512, 2),
                        (512, 512, 2), (512, 1024, 2), (1024, 1024, 2)]
           for in_c, out_c, stride in block_cfg:
               self.features.add(ConvBlock(in_c, out_c,
                                           kernel_size=3, strides=stride, padding=1,
                                           num_groups=1, use_act=True, act_type='relu',
                                           use_bias=False, conv2d=nn.Conv2D,
                                           norm_layer=norm_layer, norm_kwargs=norm_kwargs))
           self.features.add(nn.GlobalAvgPool2D())
           self.features.add(nn.Flatten())
           self.fc = nn.Dense(1, in_units=1024, use_bias=False)
   
       def hybrid_forward(self, F, x):
           x = self.features(x)
           out = self.fc(x)
           return out
   
   
   class TestData(dataset.Dataset):
       """Random 3x512x512 inputs with a scalar regression target."""
       def __init__(self):
           # __len__ must return an int, so avoid the float literal 1e5
           self.Number = 100000
       def __len__(self):
           return self.Number
   
       def __getitem__(self, idx):
           inp, tag = self.gen_data()
           inp = nd.array(inp, dtype=np.float32)
           tag = nd.array(tag, dtype=np.float32)
           return inp, tag
       
       def gen_data(self):
           X = np.random.randn(3, 512, 512)
           Y = np.random.randn(1)
           return X, Y
   
   ngpus = 4
   _ctx = [mx.gpu(i) for i in range(ngpus)]
   _batch_size = 20  # split across the GPUs: 5 samples per device
   norm_kwargs = {'num_devices': ngpus}  # SyncBatchNorm aggregates statistics over all devices
   usesyncbn = True  # passed as the `sync` argument of DataParallelModel/DataParallelCriterion
   
   model = Net(norm_layer=mx.gluon.contrib.nn.SyncBatchNorm, norm_kwargs=norm_kwargs)
   model.initialize(mx.init.MSRAPrelu(), ctx=_ctx)
   net = DataParallelModel(model, _ctx, usesyncbn)
   criterion = DataParallelCriterion(mx.gluon.loss.L1Loss(), _ctx, usesyncbn)
   update_params = net.module.collect_params()
   optimizer = mx.gluon.Trainer(update_params, 'adam', {'learning_rate': 0.001}, mx.kvstore.create())
   
   train_dataset = TestData()
   train_data = gluon.data.DataLoader(train_dataset, _batch_size,
                                      shuffle=True, last_batch='rollover',
                                      num_workers=4,
                                      pin_memory=False)
   
   for j in range(1000):
       tbar = tqdm(train_data)
       for i, idatas in enumerate(tbar):
           with autograd.record(True):
               ipt, targ = idatas
               oupt = net(ipt)                 # one output per GPU
               losses = criterion(oupt, targ)  # one loss per GPU
               mx.nd.waitall()
               autograd.backward(losses)
           optimizer.step(_batch_size)
           tbar.set_description()
           # Wait for all GPU work to finish; GPU memory in nvidia-smi still keeps growing
           mx.nd.waitall()
   ```
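To turn "memory grows slowly" into a number, one option is to log `nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits` once per iteration and fit a slope to each GPU's readings. The sketch below assumes that output format (one `index, MiB` line per GPU). `parse_memory_used` and `growth_mib_per_iter` are hypothetical helpers, and the sample readings are synthetic.

```python
def parse_memory_used(csv_text):
    """Parse `nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits`.

    Returns {gpu_index: used_MiB}. Assumes one "index, value" pair per line.
    """
    usage = {}
    for line in csv_text.strip().splitlines():
        idx, used = (field.strip() for field in line.split(","))
        usage[int(idx)] = int(used)
    return usage


def growth_mib_per_iter(samples):
    """Least-squares slope of memory usage (MiB) versus iteration number.

    A slope that stays positive over many iterations points to a leak;
    a slope near zero means usage has levelled off.
    """
    n = len(samples)
    if n < 2:
        raise ValueError("need at least two samples")
    mean_x = (n - 1) / 2
    mean_y = sum(samples) / n
    cov = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(samples))
    var = sum((i - mean_x) ** 2 for i in range(n))
    return cov / var


# Synthetic per-iteration snapshots (e.g. captured via subprocess).
snapshots = [
    "0, 9000\n1, 8900",
    "0, 9010\n1, 8900",
    "0, 9020\n1, 8900",
    "0, 9030\n1, 8900",
]
per_gpu = {}
for snap in snapshots:
    for idx, used in parse_memory_used(snap).items():
        per_gpu.setdefault(idx, []).append(used)
slopes = {idx: growth_mib_per_iter(vals) for idx, vals in per_gpu.items()}
```

In this synthetic data, GPU 0 grows by 10 MiB per iteration while GPU 1 stays flat. A real log from the reproducer above would show which pattern each device follows.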




[GitHub] [incubator-mxnet] leezu commented on issue #19498: SyncBN causes the memory to gradually increase with iteration

Posted by GitBox <gi...@apache.org>.
leezu commented on issue #19498:
URL: https://github.com/apache/incubator-mxnet/issues/19498#issuecomment-724130801


   Please provide a short example that reproduces the bug. That will make it easier to identify and fix the memory leak.

