Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/08/11 09:12:30 UTC

[GitHub] [incubator-mxnet] kohillyang opened a new issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

kohillyang opened a new issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902


   ## Description
   Hello, I'm trying to reproduce FCOS. To allow a larger batch size, I'm trying to train my code with FP16. I based my code on <https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py>, but after training started, I got the following error. One possible reason is that I used some operators that AMP does not support, but the only contrib operators I use are Deformable Convolution and BilinearResize2D, and I'm not sure whether AMP supports them.
   
   My code is at <https://github.com/kohillyang/mx-detection/blob/master/scripts/train_fcos.py>. Thanks.
   ### Error Message
   ```bash
     0%|          | 0.00/19.5k [00:04<?, ?it/s]
   Traceback (most recent call last):
     File "/data2/kohill/mx-detection/scripts/train_fcos.py", line 584, in <module>
       main()
     File "/data2/kohill/mx-detection/scripts/train_fcos.py", line 512, in main
       train_net(config)
     File "/data2/kohill/mx-detection/scripts/train_fcos.py", line 399, in train_net
       trainer.step(batch_size)
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/gluon/trainer.py", line 335, in step
       self._update(ignore_stale_grad)
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/contrib/amp/amp.py", line 321, in new_update
       self._old_update(ignore_stale_grad)
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/gluon/trainer.py", line 437, in _update
       updater(i, w, g)
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/optimizer/optimizer.py", line 2071, in __call__
       states)
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/contrib/amp/amp.py", line 313, in new_update_multi_precision
       if not skip_update():
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/contrib/amp/loss_scaler.py", line 64, in wait_and_update
       self._has_overflow = not bool(self.output.asnumpy())
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 2535, in asnumpy
       ctypes.c_size_t(data.size)))
     File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/base.py", line 255, in check_call
       raise MXNetError(py_str(_LIB.MXGetLastError()))
   mxnet.base.MXNetError: [16:57:52] include/mxnet/././tensor_blob.h:215: Check failed: mshadow::DataType<DType>::kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2
   Stack trace:
     [bt] (0) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x6b8b5b) [0x7ffaebc7cb5b]
     [bt] (1) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4295e78) [0x7ffaef859e78]
     [bt] (2) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x42c01d1) [0x7ffaef8841d1]
     [bt] (3) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContex
 t)#1}::operator()(mxnet::RunContext) const+0x307) [0x7ffaeee5d377]
     [bt] (4) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37b68d4) [0x7ffaeed7a8d4]
     [bt] (5) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c4961) [0x7ffaeed88961]
     [bt] (6) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c7ea0) [0x7ffaeed8bea0]
     [bt] (7) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c8136) [0x7ffaeed8c136]
     [bt] (8) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c3114) [0x7ffaeed87114]
   
   
   
   Process finished with exit code 1
   
   ```
   
   ## To Reproduce
   (If you developed your own code, please provide a short script that reproduces the error. For existing examples, please provide link.)
   
   ### Steps to reproduce
   (Paste the commands you ran that produced the error.)
   
   1.
   2.
   
   ## What have you tried to solve it?
   
   1.
   2.
   
   ## Environment
   
    We recommend using our script for collecting the diagnostic information. Run the following command and paste the output below:
   ```
   curl --retry 10 -s https://raw.githubusercontent.com/dmlc/gluon-nlp/master/tools/diagnose.py | python
   
   ----------Python Info----------
   Version      : 3.6.5
   Compiler     : GCC 7.2.0
   Build        : ('default', 'Apr 29 2018 16:14:56')
   Arch         : ('64bit', '')
   ------------Pip Info-----------
   Version      : 10.0.1
   Directory    : /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/pip
   ----------MXNet Info-----------
   /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
     from ._conv import register_converters as _register_converters
   Version      : 1.6.0
   Directory    : /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet
   Num GPUs     : 4
   Commit Hash   : 6eec9da55c5096079355d1f1a5fa58dcf35d6752
   ----------System Info----------
   Platform     : Linux-4.15.0-107-generic-x86_64-with-debian-stretch-sid
   system       : Linux
   node         : ubuntu
   release      : 4.15.0-107-generic
   version      : #108~16.04.1-Ubuntu SMP Fri Jun 12 02:57:13 UTC 2020
   ----------Hardware Info----------
   machine      : x86_64
   processor    : x86_64
   Architecture:          x86_64
   CPU op-mode(s):        32-bit, 64-bit
   Byte Order:            Little Endian
   CPU(s):                48
   On-line CPU(s) list:   0-47
   Thread(s) per core:    2
   Core(s) per socket:    12
   Socket(s):             2
   NUMA node(s):          2
   Vendor ID:             GenuineIntel
   CPU family:            6
   Model:                 63
   Model name:            Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
   Stepping:              2
   CPU MHz:               1735.211
   CPU max MHz:           3300.0000
   CPU min MHz:           1200.0000
   BogoMIPS:              5001.81
   Virtualization:        VT-x
   L1d cache:             32K
   L1i cache:             32K
   L2 cache:              256K
   L3 cache:              30720K
   NUMA node0 CPU(s):     0-11,24-35
   NUMA node1 CPU(s):     12-23,36-47
   Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc dtherm ida arat pln pts md_clear flush_l1d
   ----------Network Test----------
   Setting timeout: 10
   Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 10.0385 sec, LOAD: 8.7852 sec.
   Timing for GluonNLP GitHub: https://github.com/dmlc/gluon-nlp, DNS: 0.0294 sec, LOAD: 9.1138 sec.
   Timing for GluonNLP: http://gluon-nlp.mxnet.io, DNS: 5.0335 sec, LOAD: 22.5200 sec.
   Timing for D2L: http://d2l.ai, DNS: 0.0282 sec, LOAD: 2.4723 sec.
   Timing for D2L (zh-cn): http://zh.d2l.ai, DNS: 0.0285 sec, LOAD: 0.4769 sec.
   Timing for FashionMNIST: https://repo.mxnet.io/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 11.8055 sec, LOAD: 16.3673 sec.
   Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0289 sec, LOAD: 10.1921 sec.
   Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.02894139289855957 sec.
   
   ```
   


[GitHub] [incubator-mxnet] mk-61 commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

mk-61 commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-680623024


   For loading the trainer states, have you tried after https://github.com/apache/incubator-mxnet/pull/18959? It was supposed to fix this.
   
   On Tue, Aug 25, 2020 at 9:42 PM kohillyang <no...@github.com> wrote:
   
   > And it seems that if amp is used, the trainer is not able to load the
   > trainer states it saved previously.
   >
   > —
   > You are receiving this because you were mentioned.
   > Reply to this email directly, view it on GitHub
   > <https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-680604104>,
   > or unsubscribe
   > <https://github.com/notifications/unsubscribe-auth/ANQG5UTW4JI6PYZT7QHCCDTSCSHBZANCNFSM4P22FKJQ>
   > .
   >
   


[GitHub] [incubator-mxnet] szha commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

szha commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-671877732


   Hi @kohillyang. Thanks for reporting the issue. Would you mind creating a small reproducible example? Anyone who wants to help will first need to trigger the issue in order to debug it.


[GitHub] [incubator-mxnet] szha commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

szha commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-679443104


   > I found the reason is that amp.init should be called before the creation of the network
   
   Looks like, at a minimum, we need to document this; better yet, it would be great if AMP could handle the necessary changes transparently. @mk-61 it would be great if you could take this into account in https://github.com/apache/incubator-mxnet/issues/18896


[GitHub] [incubator-mxnet] szha commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

szha commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-675240014


   So the error is saying that it expects float32 but is getting float16. You may want to try the NaiveEngine to pinpoint where this came from: https://mxnet.apache.org/api/dev-guide/debugging_and_performance_optimization_tips
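
   As a minimal sketch (this relies on the standard MXNET_ENGINE_TYPE environment variable, which must be set before mxnet is imported):
   ```python
   # Run the repro under the NaiveEngine so errors surface synchronously,
   # at the operator that raised them, rather than asynchronously.
   import os
   os.environ["MXNET_ENGINE_TYPE"] = "NaiveEngine"  # set before importing mxnet

   import mxnet as mx  # the process now uses the naive (synchronous) engine
   ```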
   
   Type code can be found here: https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L327-L342
   ```cpp
   /*! \brief data type flag */
   enum TypeFlag {
     kFloat32 = 0,
     kFloat64 = 1,
     kFloat16 = 2,
     kUint8 = 3,
     kInt32 = 4,
     kInt8  = 5,
     kInt64 = 6,
     kBool = 7,
     kInt16 = 8,
     kUint16 = 9,
     kUint32 = 10,
     kUint64 = 11,
     kBfloat16 = 12
   };
   ```
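
   Illustrative only, decoding the two flags from the error message against the enum above:
   ```python
   # "Expected: 0 v.s. given 2" decoded with the TypeFlag values above
   TYPE_FLAG = {0: "float32", 1: "float64", 2: "float16", 3: "uint8", 4: "int32"}
   expected, given = 0, 2
   print(TYPE_FLAG[expected], "expected,", TYPE_FLAG[given], "given")  # float32 expected, float16 given
   ```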


[GitHub] [incubator-mxnet] kohillyang commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

kohillyang commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-679441768


   I found that the reason is that amp.init should be called before the network is created, because it patches some functions in the mx.nd and mx.sym namespaces. And since those functions have already been patched, the network does not need to be cast to float16.
   
   However, in <https://github.com/dmlc/gluon-cv/blob/e3513064244f3f987f699ac43781d40ad01e144a/scripts/detection/faster_rcnn/train_faster_rcnn.py#L655>, the network is cast to float16 when using amp, so I'm not sure whether the cast is needed.
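
   For reference, a minimal sketch of that ordering, pieced together from the snippets in this thread (MXNet 1.6 with mxnet.contrib.amp; untested, and it assumes no manual cast is needed):
   ```python
   import mxnet as mx
   import mxnet.autograd as ag
   import gluoncv
   from mxnet.contrib import amp

   amp.init()  # patch mx.nd/mx.sym BEFORE any layer is created

   net = gluoncv.model_zoo.resnet50_v1b(pretrained=False)
   net.initialize(ctx=mx.gpu(0))
   # note: no net.cast("float16") here; amp.init() handles the float16 casts

   trainer = mx.gluon.Trainer(net.collect_params(), 'sgd',
                              {'momentum': 0.9, 'wd': 1e-4, 'multi_precision': True})
   amp.init_trainer(trainer)

   with ag.record():
       out = net(mx.nd.zeros((1, 3, 224, 224), ctx=mx.gpu(0)))
       with amp.scale_loss(out.sum(), trainer) as scaled_loss:
           scaled_loss.backward()
   trainer.step(1)
   ```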


[GitHub] [incubator-mxnet] kohillyang commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

kohillyang commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-675226027


   Can I get some advice? I have no idea how to solve this problem. Thanks.


[GitHub] [incubator-mxnet] kohillyang commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

kohillyang commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-680604104


   And it seems that if amp is used, the trainer is not able to load the trainer states it saved previously.
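
   The save/reload pattern in question, sketched with Gluon's standard Trainer API ("trainer_states.t" is a made-up path; amp.init() and amp.init_trainer(trainer) are assumed to have been applied as in the repro below):
   ```python
   import mxnet as mx

   net = mx.gluon.nn.Dense(2)
   net.initialize()
   trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'multi_precision': True})

   trainer.save_states("trainer_states.t")  # after some training steps
   trainer.load_states("trainer_states.t")  # this reload was reported to fail with amp
   ```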


[GitHub] [incubator-mxnet] kohillyang edited a comment on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

kohillyang edited a comment on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-671957553


   @szha The following codes can reproduce the above error.
    ```python
   from __future__ import print_function
   
   import mxnet as mx
   import mxnet.autograd as ag
   import numpy as np
   import gluoncv
   
   
   class resnet(mx.gluon.nn.HybridBlock):
       def __init__(self):
           super(resnet, self).__init__()
           self.feat = gluoncv.model_zoo.resnet50_v1b(pretrained=False)
   
       def hybrid_forward(self, F, x):
           input = F.transpose(x, (0, 3, 1, 2))
           x = input / 255.0
           x = self.feat.conv1(x)
           x = self.feat.bn1(x)
           x = self.feat.relu(x)
           x = self.feat.maxpool(x)
   
           res2 = self.feat.layer1(x)
           res3 = self.feat.layer2(res2)
           res4 = self.feat.layer3(res3)
           res5 = self.feat.layer4(res4)
   
           return res5
   
   
   def train_net():
       mx.random.seed(3)
       np.random.seed(3)
   
       ctx_list = [mx.gpu(0)]
       net = resnet()
       net.initialize()
       net.collect_params().reset_ctx(list(set(ctx_list)))
       if True:
           from mxnet.contrib import amp
           amp.init()
           net.cast("float16")
           # net.collect_params('.*batchnorm.*').setattr('dtype', 'float32')
   
       trainer = mx.gluon.Trainer(
           net.collect_params(),  # fix batchnorm, fix first stage, etc...
           'sgd',
           {'wd': 1e-4,
            'momentum': .9,
            'clip_gradient': None,
            'lr_scheduler': None,
            'multi_precision': True,
            },
           update_on_kvstore=(False if True else None), kvstore=mx.kvstore.create('local')
       )
       if True:
           amp.init_trainer(trainer)
   
       with ag.record():
           data = mx.nd.zeros(shape=(1, 368, 368, 3), ctx=ctx_list[0])
           fpn_predictions = net(data)
           preds = mx.nd.concat(*[x.reshape((0, 0, -1)) for x in fpn_predictions], dim=2)
           with amp.scale_loss(preds.sum(), trainer) as scaled_losses:
               scaled_losses.backward()
       trainer.step(1, ignore_stale_grad=True)
   
   
   if __name__ == '__main__':
       train_net()
   
   ```


[GitHub] [incubator-mxnet] kohillyang commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

kohillyang commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-671957553


   @szha The following codes can reproduce the above error.
    ```python
   from __future__ import print_function
   
   import mxnet as mx
   import mxnet.autograd as ag
   import numpy as np
   import gluoncv
   
   
   class FCOS_Head(mx.gluon.nn.HybridBlock):
       def __init__(self, num_classes):
           super(FCOS_Head, self).__init__()
           with self.name_scope():
               self.feat_cls = mx.gluon.nn.HybridSequential()
               init = mx.init.Normal(sigma=0.01)
               init.set_verbosity(True)
               init_bias = mx.init.Constant(-1 * np.log((1 - 0.01) / 0.01))
               init_bias.set_verbosity(True)
               for i in range(4):
                   self.feat_cls.add(mx.gluon.nn.Conv2D(channels=256, kernel_size=3, padding=1, weight_initializer=init))
                   self.feat_cls.add(mx.gluon.nn.GroupNorm(num_groups=32))
                   self.feat_cls.add(mx.gluon.nn.Activation(activation="relu"))
               self.feat_cls.add(mx.gluon.nn.Conv2D(channels=num_classes - 1, kernel_size=1, padding=0,
                                                    bias_initializer=init_bias, weight_initializer=init))
   
               self.feat_reg = mx.gluon.nn.HybridSequential()
               for i in range(4):
                   self.feat_reg.add(mx.gluon.nn.Conv2D(channels=256, kernel_size=3, padding=1, weight_initializer=init))
                   self.feat_reg.add(mx.gluon.nn.GroupNorm(num_groups=32))
                   self.feat_reg.add(mx.gluon.nn.Activation(activation="relu"))
   
               # one extra channel for center-ness, four channel for location regression.
               self.feat_reg_loc = mx.gluon.nn.Conv2D(channels=4, kernel_size=1, padding=0, weight_initializer=init)
               self.feat_reg_centerness = mx.gluon.nn.Conv2D(channels=1, kernel_size=1, padding=0, weight_initializer=init)
   
       def hybrid_forward(self, F, x):
           feat_reg = self.feat_reg(x)
           x_loc = self.feat_reg_loc(feat_reg)
           x_centerness = self.feat_reg_centerness(feat_reg)
           x_cls = self.feat_cls(x)
           x = F.concat(x_loc, x_centerness, x_cls, dim=1)
           return x
   
   
   class resnet(mx.gluon.nn.HybridBlock):
       def __init__(self):
           super(resnet, self).__init__()
           self.feat = gluoncv.model_zoo.resnet50_v1b(pretrained=False)
           self.mean = self.params.get('mean', shape=[1, 3, 1, 1],
                                       init=mx.init.Zero(),
                                       allow_deferred_init=False, grad_req='null')
           self.std = self.params.get('std', shape=[1, 3, 1, 1],
                                      init=mx.init.One(),  # mx.nd.array(),
                                      allow_deferred_init=False, grad_req='null')
           self.mean._load_init(mx.nd.array([[[[0.485]], [[0.456]], [[0.406]]]]), ctx=mx.cpu())
           self.std._load_init(mx.nd.array([[[[0.229]], [[0.224]], [[0.225]]]]), ctx=mx.cpu())
   
       def hybrid_forward(self, F, x, mean, std):
           input = F.transpose(x, (0, 3, 1, 2))
           x = input / 255.0
           x = F.broadcast_sub(x, mean)
           x = F.broadcast_div(x, std)
           x = self.feat.conv1(x)
           x = self.feat.bn1(x)
           x = self.feat.relu(x)
           x = self.feat.maxpool(x)
   
           res2 = self.feat.layer1(x)
           res3 = self.feat.layer2(res2)
           res4 = self.feat.layer3(res3)
           res5 = self.feat.layer4(res4)
   
           return res5
   
   
   class FCOSFPNNet(mx.gluon.nn.HybridBlock):
       def __init__(self, num_classes):
           super(FCOSFPNNet, self).__init__()
           self.backbone = resnet()
           self.fcos_head = FCOS_Head(num_classes)
   
       def hybrid_forward(self, F, x):
           # typically the strides are (4, 8, 16, 32, 64)
           x = self.backbone(x)
           if isinstance(x, list) or isinstance(x, tuple):
               return [self.fcos_head(xx) for xx in x]
           else:
               return [self.fcos_head(x)]
   
   
   def train_net():
       mx.random.seed(3)
       np.random.seed(3)
   
       ctx_list = [mx.gpu(0)]
       net = FCOSFPNNet(11)
       # Initialize parameters
       params = net.collect_params()
       for key in params.keys():
           if params[key]._data is None:
               default_init = mx.init.Zero() if "bias" in key or "offset" in key else mx.init.Normal()
               default_init.set_verbosity(True)
               if params[key].init is not None and hasattr(params[key].init, "set_verbosity"):
                   params[key].init.set_verbosity(True)
                   params[key].initialize(init=params[key].init, default_init=params[key].init)
               else:
                   params[key].initialize(default_init=default_init)
   
       net.collect_params().reset_ctx(list(set(ctx_list)))
       if True:
           from mxnet.contrib import amp
           amp.init()
           net.cast("float16")
           net.collect_params('.*batchnorm.*').setattr('dtype', 'float32')
   
       trainer = mx.gluon.Trainer(
           net.collect_params(),  # fix batchnorm, fix first stage, etc...
           'sgd',
           {'wd': 1e-4,
            'momentum': .9,
            'clip_gradient': None,
            'lr_scheduler': None,
            'multi_precision': True,
            },
           update_on_kvstore=(False if True else None), kvstore=mx.kvstore.create('local')
       )
       if True:
           amp.init_trainer(trainer)
   
       with ag.record():
           data = mx.nd.zeros(shape=(1, 368, 368, 3), ctx=ctx_list[0])
           fpn_predictions = net(data)
           preds = mx.nd.concat(*[x.reshape((0, 0, -1)) for x in fpn_predictions], dim=2)
           with amp.scale_loss(preds.sum(), trainer) as scaled_losses:
               ag.backward(scaled_losses)
       trainer.step(1, ignore_stale_grad=True)
   
   
   if __name__ == '__main__':
       train_net()
   ```


[GitHub] [incubator-mxnet] ptrendx closed issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

ptrendx closed issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902


   


[GitHub] [incubator-mxnet] kohillyang commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

kohillyang commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-680788665


   @mk-61 It works, thank you very much.


[GitHub] [incubator-mxnet] ptrendx commented on issue #18902: Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp.

ptrendx commented on issue #18902:
URL: https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-708001324


   Closing based on https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-680788665; the warning for too-late initialization of AMP was introduced in #19036.

