You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/30 14:34:08 UTC

[incubator-mxnet] branch master updated: [BUGFIX]try avoid the error in operator/tensor/amp_cast.h (#20188)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0946c8e  [BUGFIX]try avoid the error in operator/tensor/amp_cast.h (#20188)
0946c8e is described below

commit 0946c8eb888a33d053cc6e683a9e2c76b4db0392
Author: Neutron3529 <qw...@163.com>
AuthorDate: Fri Apr 30 22:32:54 2021 +0800

    [BUGFIX]try avoid the error in operator/tensor/amp_cast.h (#20188)
    
    * Try to avoid the error in operator/tensor/amp_cast.h
    
    I'm trying to avoid the error raised when AMP is used with bfloat16.
    
    The error output is:
    ```
    /me/prog/prog-amp.py:77: UserWarning: All children of this Sequential layer 'compose1_' are HybridBlocks. Consider using HybridSequential for the best performance.
      transform_test.hybridize(static_alloc=True,static_shape=True)
    Traceback (most recent call last):
      File "/me/prog/prog-amp.py", line 359, in <module>
        loss0   = loss_fn(output, label)
      File "/me/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 314, in __mul__
        return multiply(self, other)
      File "/me/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 3757, in multiply
        return _ufunc_helper(
      File "/me/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 3576, in _ufunc_helper
        return fn_array(lhs, rhs)
      File "/me/incubator-mxnet/python/mxnet/contrib/amp/amp.py", line 109, in _new_fun
        return f(*args, **kwargs)
      File "<string>", line 52, in broadcast_mul
      File "/me/incubator-mxnet/python/mxnet/_ctypes/ndarray.py", line 82, in _imperative_invoke
        check_call(_LIB.MXImperativeInvokeEx(
      File "/me/incubator-mxnet/python/mxnet/base.py", line 246, in check_call
        raise get_last_ffi_error()
    mxnet.base.MXNetError: Traceback (most recent call last):
      File "/me/incubator-mxnet/src/io/../operator/elemwise_op_common.h", line 135
    MXNetError: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected bfloat16, got float32
    Error in atexit._run_exitfuncs:
    Traceback (most recent call last):
      File "/me/incubator-mxnet/python/mxnet/base.py", line 587, in _notify_shutdown
        check_call(_LIB.MXNotifyShutdown())
      File "/me/incubator-mxnet/python/mxnet/base.py", line 246, in check_call
        raise get_last_ffi_error()
    mxnet.base.MXNetError: Traceback (most recent call last):
      File "/me/incubator-mxnet/src/operator/tensor/./amp_cast.h", line 136
    MXNetError: Unknown type enum 12
    ```
    This was tested under MXNet v1.x, but it seems to affect v2.0 as well.
    
    Since RTX 30-series cards support bfloat16, there is no need to explicitly disable it with `#ifndef __NVCC__`.
    
    I don't know whether this fixes the problem, but it cannot make things worse.
    
    * Forgive my messy coding; I'm not a computer scientist.
    
    * Revert all modifications to base.h
    
    Co-authored-by: Neutron3529 <qw...@mail.ustc.edu.cn>
---
 src/operator/tensor/amp_cast.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/operator/tensor/amp_cast.h b/src/operator/tensor/amp_cast.h
index 685a05a..9a45bd0 100644
--- a/src/operator/tensor/amp_cast.h
+++ b/src/operator/tensor/amp_cast.h
@@ -133,9 +133,9 @@ void AMPCastCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DstDType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, {
     Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
-    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
       Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
       if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
           req[0] != kWriteInplace) {
@@ -155,9 +155,9 @@ void AMPMultiCastCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.get_stream<xpu>();
   for (size_t i = 0; i < outputs.size(); ++i) {
-    MSHADOW_TYPE_SWITCH(outputs[i].type_flag_, DstDType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[i].type_flag_, DstDType, {
       Tensor<xpu, 1, DstDType> out = outputs[i].FlatTo1D<xpu, DstDType>(s);
-      MSHADOW_TYPE_SWITCH(inputs[i].type_flag_, SrcDType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[i].type_flag_, SrcDType, {
         Tensor<xpu, 1, SrcDType> data = inputs[i].FlatTo1D<xpu, SrcDType>(s);
         if (outputs[i].type_flag_ != inputs[i].type_flag_ ||
             req[i] != kWriteInplace) {