Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/11/04 17:10:08 UTC

[GitHub] [incubator-mxnet] sxjscience opened a new issue #19472: [Bug][AMP][2.0] GluonNLP BART model AMP error

sxjscience opened a new issue #19472:
URL: https://github.com/apache/incubator-mxnet/issues/19472


   ## Description
   I tried to use the GluonNLP BART model with AMP but ran into a strange error.
   
   What I'm doing is the following:
   - Create FP32 model
   - FP32 forward
   - Create FP16 model
   - Share the weights of FP32 model to FP16 model
   - FP16 forward
   - Turn on AMP
   - FP32 forward (Trigger error).
   
   To reproduce, you will first need to install the master version of GluonNLP:
   
   ```
   git clone https://github.com/dmlc/gluon-nlp; cd gluon-nlp
   git checkout master
   python3 -m pip install --quiet -e .[extras]
   ```
   
   ```python
   import mxnet as mx
   import numpy as np
   mx.npx.set_np()
   from gluonnlp.models.bart import BartModel
   cfg = BartModel.get_cfg()
   cfg.defrost()
   cfg.MODEL.vocab_size = 32
   cfg.freeze()
   ctx = mx.gpu()
   
   batch_size = 4
   src_length = 32
   tgt_length = 16
   
   with ctx:
       src_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, src_length),
                                       dtype=np.int32)
       src_valid_length = mx.np.random.randint(src_length // 2, src_length, (batch_size,),
                                               dtype=np.int32)
       tgt_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, tgt_length),
                                       dtype=np.int32)
       tgt_valid_length = mx.np.random.randint(tgt_length // 2, tgt_length, (batch_size, ),
                                               dtype=np.int32)
       model = BartModel.from_cfg(cfg)
       model.initialize(ctx=ctx)
       model.hybridize()
   
       model(src_data, src_valid_length, tgt_data, tgt_valid_length)
       mx.npx.waitall()
       
       model_fp16 = BartModel.from_cfg(cfg)
       model_fp16.share_parameters(model.collect_params())
       model_fp16.cast('float16')
       model_fp16.hybridize()
       model_fp16(src_data, src_valid_length, tgt_data, tgt_valid_length)
       mx.npx.waitall()
       from mxnet import amp
       amp.init()
       model(src_data, src_valid_length, tgt_data, tgt_valid_length)
       mx.npx.waitall()
   
       trainer = mx.gluon.Trainer(model.collect_params(), 'adam',
                                  {'learning_rate': 1E-3, 'wd': 1E-4,
                                   'multi_precision': True},
                                  update_on_kvstore=False)
       amp.init_trainer(trainer)
       with mx.autograd.record():
           outputs_amp = model(src_data, src_valid_length, tgt_data, tgt_valid_length)
           if not isinstance(outputs_amp, (tuple, list)):
               loss = outputs_amp.mean()
           else:
               loss = sum([ele.mean() for ele in outputs_amp])
           with amp.scale_loss(loss, trainer) as scaled_loss:
               mx.autograd.backward(scaled_loss)
       trainer.step(1)
       mx.npx.waitall()
   
   
   ```
   
   
   
   ### Error Message
   ```
   ---------------------------------------------------------------------------
   MXNetError                                Traceback (most recent call last)
   <ipython-input-1-73c1077ec25a> in <module>
        33     model_fp16.cast('float16')
        34     model_fp16.hybridize()
   ---> 35     model_fp16(src_data, src_valid_length, tgt_data, tgt_valid_length)
        36     mx.npx.waitall()
        37     from mxnet import amp
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
      1431 
      1432             with x.ctx:
   -> 1433                 return self._call_cached_op(x, *args)
      1434 
      1435     def forward(self, x, *args):
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _call_cached_op(self, *args)
      1100     def _call_cached_op(self, *args):
      1101         if self._cached_op is None:
   -> 1102             self._build_cache(*args)
      1103         assert self._cached_op, "Gluon failed to build the cache. " \
      1104                                 "This should never happen. " \
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _build_cache(self, *args)
       993 
       994     def _build_cache(self, *args):
   --> 995         data, out = self._get_graph(*args)
       996         data_names = {data.name: i for i, data in enumerate(data)}
       997         params = {p.var().name: p for p in self.collect_params().values()}
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _get_graph(self, *args)
       989                 return self._get_graph_v1(*args)
       990             else:  # Gluon 2 based on deferred compute mode
   --> 991                 return self._get_graph_v2(*args)
       992         return self._cached_graph
       993 
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _get_graph_v2(self, *args)
       978             args = _regroup(flatten_args, self._in_format)
       979             with autograd.pause(), dc.context():
   --> 980                 out = super().__call__(*args)
       981             flatten_out, self._out_format = _flatten(out, "output")
       982             symbol_outputs = dc.get_symbol(flatten_out, sym_cls=type(symbol_inputs[0]))
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
       709             hook(self, args)
       710 
   --> 711         out = self.forward(*args)
       712 
       713         for hook in self._forward_hooks.values():
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/gluon-nlp/src/gluonnlp/models/bart.py in forward(self, src_data, src_valid_length, tgt_data, tgt_valid_length)
       235                     Shape (tgt_length, batch_size, tgt_vocab_size)
       236         """
   --> 237         enc_out = self.encode(src_data, src_valid_length)
       238         contextual_embedding = self.decode_seq(tgt_data, tgt_valid_length, enc_out,
       239                                                src_valid_length)
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/gluon-nlp/src/gluonnlp/models/transformer.py in encode(self, src_data, src_valid_length)
      1177                     npx.arange_like(src_data, axis=0)), axis=1)
      1178 
   -> 1179         enc_out = self.encoder(src_data, src_valid_length)
      1180         return enc_out
      1181 
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
      1428                 # Deferred compute is already enabled. This typically means that the current
      1429                 # HybridBlock is a child block of a HybridBlock that has been hybridized.
   -> 1430                 return super().__call__(x, *args)
      1431 
      1432             with x.ctx:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
       709             hook(self, args)
       710 
   --> 711         out = self.forward(*args)
       712 
       713         for hook in self._forward_hooks.values():
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/gluon-nlp/src/gluonnlp/models/transformer.py in forward(self, data, valid_length)
       372             else:
       373                 layer = self.layers[i]
   --> 374             out, _ = layer(out, attn_mask)
       375         if self._pre_norm:
       376             out = self.ln_final(out)
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
      1428                 # Deferred compute is already enabled. This typically means that the current
      1429                 # HybridBlock is a child block of a HybridBlock that has been hybridized.
   -> 1430                 return super().__call__(x, *args)
      1431 
      1432             with x.ctx:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
       709             hook(self, args)
       710 
   --> 711         out = self.forward(*args)
       712 
       713         for hook in self._forward_hooks.values():
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
       297         def _with_np_shape(*args, **kwargs):
       298             with np_shape(active=True):
   --> 299                 return func(*args, **kwargs)
       300         return _with_np_shape
       301     else:
   
   ~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
       478         def _with_np_array(*args, **kwargs):
       479             with np_array(active=True):
   --> 480                 return func(*args, **kwargs)
       481         return _with_np_array
       482     else:
   
   ~/gluon-nlp/src/gluonnlp/models/transformer.py in forward(self, data, attn_mask)
       257         key = npx.reshape(key, (-2, -2, self._num_heads, -1))
       258         value = npx.reshape(value, (-2, -2, self._num_heads, -1))
   --> 259         out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
       260         out = self.attention_proj(out)
       261         out = self.dropout_layer(out)
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
      1428                 # Deferred compute is already enabled. This typically means that the current
      1429                 # HybridBlock is a child block of a HybridBlock that has been hybridized.
   -> 1430                 return super().__call__(x, *args)
      1431 
      1432             with x.ctx:
   
   ~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
       709             hook(self, args)
       710 
   --> 711         out = self.forward(*args)
       712 
       713         for hook in self._forward_hooks.values():
   
   ~/gluon-nlp/src/gluonnlp/attention_cell.py in forward(self, query, key, value, mask, edge_scores)
       643                                    query_head_units=self._query_head_units,
       644                                    layout=self._layout, use_einsum=self._use_einsum,
   --> 645                                    dtype=self._dtype)
       646 
       647     def __repr__(self):
   
   ~/gluon-nlp/src/gluonnlp/attention_cell.py in multi_head_dot_attn(query, key, value, mask, edge_scores, dropout, scaled, normalized, eps, query_head_units, layout, use_einsum, dtype)
       540         else:
       541             context_vec = npx.batch_dot(attn_weights,
   --> 542                                           np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
       543         context_vec = npx.reshape(context_vec, (-2, -2, -1))
       544     elif layout == 'TNK':
   
   ~/.local/lib/python3.6/site-packages/mxnet/ndarray/register.py in batch_dot(lhs, rhs, transpose_a, transpose_b, forward_stype, out, name, **kwargs)
   
   ~/.local/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list)
        89         c_str_array(keys),
        90         c_str_array([str(s) for s in vals]),
   ---> 91         ctypes.byref(out_stypes)))
        92 
        93     create_ndarray_fn = _global_var._np_ndarray_cls if is_np_op else _global_var._ndarray_cls
   
   ~/.local/lib/python3.6/site-packages/mxnet/base.py in check_call(ret)
       244     """
       245     if ret != 0:
   --> 246         raise get_last_ffi_error()
       247 
       248 
   
   MXNetError: Traceback (most recent call last):
     File "../src/io/../operator/elemwise_op_common.h", line 135
   MXNetError: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node batch_dot at 1-th input: expected float32, got float16
   ```
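The last line of the traceback says that `batch_dot` received a float32 first input and a float16 second input. As a rough illustration only, the NumPy sketch below mimics that kind of strict dtype check. `strict_batch_dot` is a hypothetical stand-in written for this note, not MXNet's operator, and the shapes are arbitrary.

```python
import numpy as np


def strict_batch_dot(lhs, rhs):
    """Hypothetical stand-in: a batched matmul that refuses mixed input dtypes."""
    if lhs.dtype != rhs.dtype:
        raise TypeError(
            f"Incompatible attr at 1-th input: expected {lhs.dtype}, got {rhs.dtype}"
        )
    return np.matmul(lhs, rhs)


# Assumed situation: attention weights left in float32, value already cast to float16.
attn_weights = np.ones((2, 4, 4), dtype=np.float32)
value = np.ones((2, 4, 8), dtype=np.float16)

message = None
try:
    strict_batch_dot(attn_weights, value)
except TypeError as err:
    message = str(err)
print(message)

# Casting the first operand to the dtype of the second lets the call go through.
out = strict_batch_dot(attn_weights.astype(value.dtype), value)
print(out.dtype, out.shape)
```

Plain `np.matmul` would silently promote the mixed inputs to float32, which is why the sketch adds an explicit check to imitate the stricter behavior reported in the error.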
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org
For additional commands, e-mail: issues-help@mxnet.apache.org


[GitHub] [incubator-mxnet] sxjscience edited a comment on issue #19472: [Bug][2.0] GluonNLP BART model float16 error

Posted by GitBox <gi...@apache.org>.
sxjscience edited a comment on issue #19472:
URL: https://github.com/apache/incubator-mxnet/issues/19472#issuecomment-721884922


   In fact, directly casting the original model to float16 does not work either:
   ```python
   import mxnet as mx
   import numpy as np
   mx.npx.set_np()
   from gluonnlp.models.bart import BartModel
   cfg = BartModel.get_cfg()
   cfg.defrost()
   cfg.MODEL.vocab_size = 32
   cfg.freeze()
   ctx = mx.gpu()
   
   batch_size = 4
   src_length = 32
   tgt_length = 16
   
   with ctx:
       src_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, src_length),
                                       dtype=np.int32)
       src_valid_length = mx.np.random.randint(src_length // 2, src_length, (batch_size,),
                                               dtype=np.int32)
       tgt_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, tgt_length),
                                       dtype=np.int32)
       tgt_valid_length = mx.np.random.randint(tgt_length // 2, tgt_length, (batch_size, ),
                                               dtype=np.int32)
       model = BartModel.from_cfg(cfg)
       model.initialize(ctx=ctx)
       model.hybridize()
   
       model(src_data, src_valid_length, tgt_data, tgt_valid_length)
       mx.npx.waitall()
   
       model.cast('float16')
       model.hybridize()
       model(src_data, src_valid_length, tgt_data, tgt_valid_length)
       mx.npx.waitall()
   ```
   
   It fails to run even with hybridization removed.
   




[GitHub] [incubator-mxnet] sxjscience commented on issue #19472: [Bug][2.0] GluonNLP BART model float16 error

Posted by GitBox <gi...@apache.org>.
sxjscience commented on issue #19472:
URL: https://github.com/apache/incubator-mxnet/issues/19472#issuecomment-721883569


   Here is the reproduction with the AMP-related code removed:
   
   ```python
   import mxnet as mx
   import numpy as np
   mx.npx.set_np()
   from gluonnlp.models.bart import BartModel
   cfg = BartModel.get_cfg()
   cfg.defrost()
   cfg.MODEL.vocab_size = 32
   cfg.freeze()
   ctx = mx.gpu()
   
   batch_size = 4
   src_length = 32
   tgt_length = 16
   
   with ctx:
       src_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, src_length),
                                       dtype=np.int32)
       src_valid_length = mx.np.random.randint(src_length // 2, src_length, (batch_size,),
                                               dtype=np.int32)
       tgt_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, tgt_length),
                                       dtype=np.int32)
       tgt_valid_length = mx.np.random.randint(tgt_length // 2, tgt_length, (batch_size, ),
                                               dtype=np.int32)
       model = BartModel.from_cfg(cfg)
       model.initialize(ctx=ctx)
       model.hybridize()
   
       model(src_data, src_valid_length, tgt_data, tgt_valid_length)
       mx.npx.waitall()
       
       model_fp16 = BartModel.from_cfg(cfg)
       model_fp16.share_parameters(model.collect_params())
       model_fp16.cast('float16')
       model_fp16.hybridize()
       model_fp16(src_data, src_valid_length, tgt_data, tgt_valid_length)
       mx.npx.waitall()
   ```




[GitHub] [incubator-mxnet] sxjscience commented on issue #19472: [Bug][2.0] GluonNLP BART model float16 error

Posted by GitBox <gi...@apache.org>.
sxjscience commented on issue #19472:
URL: https://github.com/apache/incubator-mxnet/issues/19472#issuecomment-721888681


   Closing, as I noticed the problem is caused by the GluonNLP implementation.
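The thread does not say what the GluonNLP-side cause was. One plausible reading of the traceback, which shows `dtype=self._dtype` being passed into `multi_head_dot_attn`, is that a dtype stored at construction time stays float32 after the parameters are cast. The NumPy sketch below illustrates that failure mode under this assumption; `AttentionSketch` and its methods are hypothetical names, not GluonNLP code.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over the last axis; keeps the input dtype."""
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


class AttentionSketch:
    """Hypothetical block that remembers a dtype chosen at construction time."""

    def __init__(self, dtype="float32"):
        self._dtype = dtype  # assumption: never updated by a later cast

    def weights_stale(self, scores):
        # Uses the stored dtype, so float16 scores are turned back into float32.
        return softmax(scores).astype(self._dtype)

    def weights_following_input(self, scores):
        # Follows whatever dtype the incoming data has.
        return softmax(scores).astype(scores.dtype)


cell = AttentionSketch()
scores = np.zeros((2, 4, 4), dtype=np.float16)  # data already in float16

stale = cell.weights_stale(scores)
following = cell.weights_following_input(scores)
print(stale.dtype, following.dtype)
```

If this reading is right, either updating the stored dtype when the block is cast or deriving the dtype from the inputs would keep both operands of the later `batch_dot` consistent.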




[GitHub] [incubator-mxnet] szha commented on issue #19472: [Bug][AMP][2.0] GluonNLP BART model AMP error

Posted by GitBox <gi...@apache.org>.
szha commented on issue #19472:
URL: https://github.com/apache/incubator-mxnet/issues/19472#issuecomment-721867827


   The reported line of error is `model_fp16(src_data, src_valid_length, tgt_data, tgt_valid_length)`, which runs before AMP is invoked. What is the expected behavior when sharing parameters and then casting the model to fp16?




[GitHub] [incubator-mxnet] sxjscience closed issue #19472: [Bug][2.0] GluonNLP BART model float16 error

Posted by GitBox <gi...@apache.org>.
sxjscience closed issue #19472:
URL: https://github.com/apache/incubator-mxnet/issues/19472


   

