Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/02/02 05:34:32 UTC

[GitHub] [incubator-mxnet] ceisenach opened a new issue #19817: F.Take Backwards Incorrect Gradient

ceisenach opened a new issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817


   ## Description
   The backward implementation of `F.take` computes an incorrect gradient when the operator is used after a sequence of transpose -> convolution -> transpose: any trainable parameters that receive gradients through the `F.take` operator end up with incorrect gradients. Equivalent implementations using slice operators produce correct results.
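
   For reference, the two selections are meant to be element-for-element equivalent; a quick NDArray-level check of the intended equivalence (not the bug itself, just the intent):

   ```python
   from mxnet import nd

   x = nd.random.normal(shape=(2, 4, 6))
   # taking indices [1, 2, 3] along the last axis ...
   a = nd.take(x, nd.array([1, 2, 3]), axis=-1)
   # ... selects the same elements as slicing [1, 4) along that axis
   b = nd.slice_axis(x, axis=-1, begin=1, end=4)
   assert nd.abs(a - b).sum().asscalar() == 0.0
   ```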
   
   ### Other Details
   I have been unable to find any other scenario in which this happens (for example, if one replaces the Conv layers in the example below with a linear layer, there is no issue with the gradient computation).
   
   I also encounter the bug on MXNet 1.5 and 1.6 (I have not tested earlier versions).
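
   As far as I can tell, the operator in isolation behaves correctly; a minimal autograd check along these lines (no transpose/convolution involved) gives matching input gradients for the two paths:

   ```python
   from mxnet import autograd, nd

   x1 = nd.random.normal(shape=(2, 4, 6))
   x2 = x1.copy()
   x1.attach_grad()
   x2.attach_grad()

   with autograd.record():
       y1 = nd.take(x1, nd.array([1, 2, 3]), axis=-1).sum()
   y1.backward()

   with autograd.record():
       y2 = nd.slice_axis(x2, axis=-1, begin=1, end=4).sum()
   y2.backward()

   # both paths should give the same gradient w.r.t. the input
   print(nd.abs(x1.grad - x2.grad).sum().asscalar())  # expected: 0.0
   ```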
   
   ## To Reproduce
   Below I provide an example of a simple model with two implementations -- one that uses `F.take` (Model A) and one that uses `F.slice_axis` (Model B) instead.
   
   ```py
   from mxnet import nd
   from mxnet.gluon.nn import HybridBlock, Conv1D, HybridSequential, HybridLambda, Dense


   def conv_layer(atrous_rates, num_channels):
       convs = HybridSequential()
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       for rate in atrous_rates:
           convs.add(Conv1D(num_channels, 3, padding=rate, dilation=rate, activation='tanh'))
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       return convs
   
   
   class Model(HybridBlock):
       """
       Model takes tensors of shape N x T x C and produces predictions with shape N x T
       """
   
       def __init__(self, conv_units, atrous_rates, use_take=False, **kwargs):
           super().__init__(prefix=kwargs.get('prefix', None), params=kwargs.get('params', None))
           self.use_take = use_take
           with self.name_scope():
               self.convs = conv_layer(atrous_rates, conv_units)
               self.dense_out = Dense(1, flatten=False, activation='tanh')
   
       def hybrid_forward(self, F, X):
           X1 = X
           X2 = self.convs(X1)
           if self.use_take:
               X3 = F.take(X2, nd.array([1, 2, 3]), axis=-1)
           else:
               X3 = F.slice_axis(X2, begin=1, end=4, axis=-1)
           X4 = self.dense_out(X3)
           X4 = F.squeeze(X4, axis=-1)
           return X4
   ```
   
   The script provided below instantiates both implementations with the same initial weights, computes an L2 loss, and prints the gradients from both models. A random seed is set, so the output should be deterministic (and it is for Model B).
   
   ### Steps to reproduce
   1. Download this script: https://gist.github.com/ceisenach/9ffed8343e5576748ec7d5623ffe6c46
   2. Run the script (`python take_bug.py`)
   
   
   ### Result
   1. As expected, the output of the forward pass is the same for both models.
   2. Gradients (Model A): parameters in Model A that receive gradients through `F.take` are on the order of 1e28 (or in some cases infinite). The results are non-deterministic.
   3. Gradients (Model B): gradient values seem reasonable and are deterministic (same results each time).
   
   Example output from the script I provided:
   
   ```
   ||g_param||_2: INF | Param: model0_conv0_weight
   ||g_param||_2: 7.21E+18 | Param: model0_conv0_bias
   ||g_param||_2: INF | Param: model0_conv1_weight
   ||g_param||_2: INF | Param: model0_conv1_bias
   ||g_param||_2: INF | Param: model0_conv2_weight
   ||g_param||_2: INF | Param: model0_conv2_bias
   ||g_param||_2: 1.38E-04 | Param: model0_dense0_weight
   ||g_param||_2: 1.06E-02 | Param: model0_dense0_bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: INF
       *  ||g||_1: 1.77E+21
       *  ||g||_inf: 5.79E+20
   
       
   ||g_param||_2: 2.37E-04 | Param: model1_conv0_weight
   ||g_param||_2: 2.29E-05 | Param: model1_conv0_bias
   ||g_param||_2: 2.23E-04 | Param: model1_conv1_weight
   ||g_param||_2: 1.50E-04 | Param: model1_conv1_bias
   ||g_param||_2: 4.26E-04 | Param: model1_conv2_weight
   ||g_param||_2: 7.02E-04 | Param: model1_conv2_bias
   ||g_param||_2: 1.38E-04 | Param: model1_dense0_weight
   ||g_param||_2: 1.06E-02 | Param: model1_dense0_bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.06E-02
       *  ||g||_1: 1.75E-02
       *  ||g||_inf: 1.06E-02
   
       
   ==== Same outputs?
   Y_hat1 - Yhat2 = 0.0000
   ```
   
   It appears that there is either an out-of-bounds memory access or some values involved in the calculation are not initialized before they are used. I haven't attempted to track down the root cause.
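
   One quick way to see the non-determinism (which would be consistent with uninitialized memory) is to repeat the identical backward pass on the same model and parameters and compare the gradients between runs. A rough sketch, assuming `Model`, `X`, `Y`, and `loss` are already defined as in the script above and `model` was built with `use_take=True`:

   ```python
   import mxnet as mx
   from mxnet import autograd, nd

   grads = []
   for _ in range(3):
       with autograd.record():
           out = loss(model(X), Y).sum()
       out.backward()
       # snapshot every parameter gradient after this backward pass
       pd = model.collect_params()
       grads.append({name: pd[name].grad(mx.cpu()).copy() for name in pd})

   for name in grads[0]:
       diffs = [nd.abs(grads[0][name] - g[name]).max().asscalar() for g in grads[1:]]
       # non-zero diffs here mean the gradient changed between identical runs
       print(name, diffs)
   ```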
   
   
   ## What have you tried to solve it?
   
   In many cases, one can work around the issue by using one of the slice operators instead; they do not appear to have any issues.
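
   For a contiguous index set, `F.slice_axis` is a drop-in replacement for `F.take`; a sketch of the substitution inside `hybrid_forward` (the non-contiguous variant with indices [0, 2, 3] is hypothetical, just to illustrate the pattern):

   ```python
   # Contiguous indices [1, 2, 3] on the last axis (axis 2 for an N x T x C tensor):
   X3 = F.slice_axis(X2, axis=-1, begin=1, end=4)

   # Hypothetical non-contiguous indices [0, 2, 3]: stitch slices together
   # (at some extra cost) instead of calling F.take.
   X3 = F.concat(F.slice_axis(X2, axis=-1, begin=0, end=1),
                 F.slice_axis(X2, axis=-1, begin=2, end=4), dim=2)
   ```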
   
   ## Environment
   
   OS: Ubuntu 18.04
   Python: 3.8.5
   pip: 20.2.3
   mxnet: 1.7.0 (Commit Hash: 64f737cdd59fe88d2c5b479f25d011c5156b6a8a)


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org
For additional commands, e-mail: issues-help@mxnet.apache.org


[GitHub] [incubator-mxnet] szha commented on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
szha commented on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-778660607


   It's unclear to me. The following commits are only on master and not on v1.x:
   e3d7866e6854a5c11ab2b2c8bfb63de66f79e132
   c1098aa33d6795f84a19601d0319d5bb8e19f317
   344587f295666e4375042d054cd5a134fdeaf517
   50312af58b2ec3e951da0809dd0c800a62dcf1f9
   18a784a3276dab8208be54c0a81b3e85b5495a46



[GitHub] [incubator-mxnet] waytrue17 commented on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
waytrue17 commented on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-922647909


   I think the issue should be fixed by #20166. Shall we close the issue? @ceisenach @szha



[GitHub] [incubator-mxnet] szha edited a comment on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
szha edited a comment on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-778661805


   Actually, this bug appears to be non-deterministic. If I run the script a couple more times, I get weird results such as the following, which happen on both v1.x and on master:
   
   <details><summary>script</summary>
   
   ```
   import numpy as np
   import mxnet as mx
   from mxnet.gluon.nn import HybridBlock, Conv1D, HybridSequential, HybridLambda, Dense
   from mxnet import autograd, nd
   from mxnet.gluon.loss import L2Loss
   
   print(mx.__version__)
   print(mx.runtime.feature_list())
   
   
   def print_grads(model, ctx=mx.cpu()):
       pd = model.collect_params()
       total_grad_l2 = 0
       total_grad_l1 = 0
       total_grad_linf = 0
       for p in pd:
           try:
               g = pd[p].grad(ctx) / N
               g2 = (g**2).sum().as_in_context(mx.cpu()).asscalar()
               g1 = g.abs().sum().as_in_context(mx.cpu()).asscalar()
               ginf = g.max().as_in_context(mx.cpu()).asscalar()
               total_grad_linf = max(total_grad_linf, ginf)
               total_grad_l2 += g2
               total_grad_l1 += g1
               print(f"||g_param||_2: {g2**0.5:.2E} | Param: {p}")
           except Exception:
               pass
       grad_info = f"""
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: {total_grad_l2**0.5:.2E}
       *  ||g||_1: {total_grad_l1:.2E}
       *  ||g||_inf: {total_grad_linf:.2E}
       """
       print(grad_info)
   
   
   def run_model(model, loss, X, Y, num_iters=1):
       for i in range(num_iters):
           with autograd.record():
               Y_hat = model(X)
               ll = loss(Y_hat, Y)
               ll = ll.sum()
               ll.backward()
               print_grads(model)
       return Y_hat
   
   
   def conv_layer(atrous_rates, num_channels):
       convs = HybridSequential()
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       for rate in atrous_rates:
           convs.add(Conv1D(num_channels, 3, padding=rate, dilation=rate, activation='tanh'))
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       return convs
   
   
   class Model(HybridBlock):
       """
       Model takes tensors of shape N x T x C and produces predictions with shape N x T
       """
   
       def __init__(self, conv_units, atrous_rates, use_take=False, **kwargs):
           super().__init__()
           self.use_take = use_take
           self.convs = conv_layer(atrous_rates, conv_units)
           self.dense_out = Dense(1, flatten=False, activation='tanh')
   
       def hybrid_forward(self, F, X):
           X1 = X
           X2 = self.convs(X1)
           if self.use_take:
               X3 = F.take(X2, nd.array([1, 2, 3]), axis=-1)
           else:
               X3 = F.slice_axis(X2, begin=1, end=4, axis=-1)
           X4 = self.dense_out(X3)
           X4 = F.squeeze(X4, axis=-1)
           return X4
   
   
   if __name__ == "__main__":
       N = 30
       T = 20
       C = 8
       conv_units = 5
       atrous_rates = [1, 2, 4]
       np.random.seed(1234)
   
       X = np.random.normal(size=(N, T, C))
       Y = np.random.normal(size=(N, T))
       X, Y = nd.array(X), nd.array(Y)
   
       # Using F.take
       mx.random.seed(12354)
       model = Model(conv_units, atrous_rates, use_take=True)
       model.initialize()
       loss = L2Loss()
       Y_hat1 = run_model(model, loss, X, Y)
   
       # Using F.slice_axis
       mx.random.seed(12354)
       model2 = Model(conv_units, atrous_rates, use_take=False)
       model2.initialize()
       loss2 = L2Loss()
       Y_hat2 = run_model(model2, loss2, X, Y)
   
       delta = nd.abs(Y_hat1-Y_hat2).sum().asscalar()
       print("==== Same outputs?")
       print(f"Y_hat1 - Yhat2 = {delta:.4f}")
   ```
   
   </details>
   
   <details><summary>environment</summary>
   
   from commit bca8de85c6011d145e52ae06639fb2ae129d0480
   
   ```
   ----------Python Info----------
   Version      : 3.8.7
   Compiler     : Clang 12.0.0 (clang-1200.0.32.28)
   Build        : ('default', 'Dec 30 2020 10:14:55')
   Arch         : ('64bit', '')
   ------------Pip Info-----------
   Version      : 20.3.3
   Directory    : /usr/local/lib/python3.8/site-packages/pip
   ----------MXNet Info-----------
   Version      : 2.0.0
   Directory    : /Users/zhasheng/mxnet/python/mxnet
   Commit hash file "/Users/zhasheng/mxnet/python/mxnet/COMMIT_HASH" not found. Not installed from pre-built package or built from source.
   Library      : ['/Users/zhasheng/mxnet/python/mxnet/../../build/libmxnet.dylib']
   Build features:
   ✖ CUDA
   ✖ CUDNN
   ✖ NCCL
   ✖ TENSORRT
   ✖ CUTENSOR
   ✔ CPU_SSE
   ✔ CPU_SSE2
   ✔ CPU_SSE3
   ✔ CPU_SSE4_1
   ✔ CPU_SSE4_2
   ✖ CPU_SSE4A
   ✔ CPU_AVX
   ✖ CPU_AVX2
   ✖ OPENMP
   ✖ SSE
   ✔ F16C
   ✖ JEMALLOC
   ✖ BLAS_OPEN
   ✖ BLAS_ATLAS
   ✖ BLAS_MKL
   ✔ BLAS_APPLE
   ✔ LAPACK
   ✔ MKLDNN
   ✖ OPENCV
   ✖ DIST_KVSTORE
   ✖ INT64_TENSOR_SIZE
   ✔ SIGNAL_HANDLER
   ✔ DEBUG
   ✖ TVM_OP
   ----------System Info----------
   Platform     : macOS-11.2.1-x86_64-i386-64bit
   system       : Darwin
   node         : a483e79ab3ab
   release      : 20.3.0
   version      : Darwin Kernel Version 20.3.0: Thu Jan 21 00:07:06 PST 2021; root:xnu-7195.81.3~1/RELEASE_X86_64
   ----------Hardware Info----------
   machine      : x86_64
   processor    : i386
   b'machdep.cpu.brand_string: Intel(R) Core(TM) i7-8569U CPU @ 2.80GHz'
   b'machdep.cpu.features: FPU VME DE PSE TSC MSR PAE MCE CX8 APIC SEP MTRR PGE MCA CMOV PAT PSE36 CLFSH DS ACPI MMX FXSR SSE SSE2 SS HTT TM PBE SSE3 PCLMULQDQ DTES64 MON DSCPL VMX EST TM2 SSSE3 FMA CX16 TPR PDCM SSE4.1 SSE4.2 x2APIC MOVBE POPCNT AES PCID XSAVE OSXSAVE SEGLIM64 TSCTMR AVX1.0 RDRAND F16C'
   b'machdep.cpu.leaf7_features: RDWRFSGS TSC_THREAD_OFFSET SGX BMI1 AVX2 SMEP BMI2 ERMS INVPCID FPU_CSDS MPX RDSEED ADX SMAP CLFSOPT IPT MDCLEAR TSXFA IBRS STIBP L1DF SSBD'
   b'machdep.cpu.extfeatures: SYSCALL XD 1GBPAGE EM64T LAHF LZCNT PREFETCHW RDTSCP TSCI'
   ----------Network Test----------
   Setting timeout: 10
   Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0137 sec, LOAD: 0.2581 sec.
   Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0852 sec, LOAD: 0.2603 sec.
   Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1125)>, DNS finished in 0.23605990409851074 sec.
   Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0248 sec, LOAD: 0.2969 sec.
   Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0103 sec, LOAD: 0.3477 sec.
   Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.014931201934814453 sec.
   ----------Environment----------
   CC="/usr/local/opt/llvm/bin/clang"
   CXX="/usr/local/opt/llvm/bin/clang++"
   KMP_DUPLICATE_LIB_OK="True"
   KMP_INIT_AT_FORK="FALSE"
   ```
   
   </details>
   
   ```
   2.0.0
   [✖ CUDA, ✖ CUDNN, ✖ NCCL, ✖ TENSORRT, ✖ CUTENSOR, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✔ CPU_SSE4_1, ✔ CPU_SSE4_2, ✖ CPU_SSE4A, ✔ CPU_AVX, ✖ CPU_AVX2, ✖ OPENMP, ✖ SSE, ✔ F16C, ✖ JEMALLOC, ✖ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✔ BLAS_APPLE, ✔ LAPACK, ✔ MKLDNN, ✖ OPENCV, ✖ DIST_KVSTORE, ✖ INT64_TENSOR_SIZE, ✔ SIGNAL_HANDLER, ✔ DEBUG, ✖ TVM_OP]
   [13:57:38] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for CPU
   ||g_param||_2: 2.27E+11 | Param: convs.1.weight
   ||g_param||_2: 2.15E+10 | Param: convs.1.bias
   ||g_param||_2: 2.46E+11 | Param: convs.2.weight
   ||g_param||_2: 4.30E+11 | Param: convs.2.bias
   ||g_param||_2: 2.54E+11 | Param: convs.3.weight
   ||g_param||_2: 2.66E+12 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 2.73E+12
       *  ||g||_1: 1.19E+13
       *  ||g||_inf: 1.86E+12
   
   ||g_param||_2: 2.37E-04 | Param: convs.1.weight
   ||g_param||_2: 2.29E-05 | Param: convs.1.bias
   ||g_param||_2: 2.23E-04 | Param: convs.2.weight
   ||g_param||_2: 1.50E-04 | Param: convs.2.bias
   ||g_param||_2: 4.26E-04 | Param: convs.3.weight
   ||g_param||_2: 7.02E-04 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.06E-02
       *  ||g||_1: 1.75E-02
       *  ||g||_inf: 1.06E-02
   
   ==== Same outputs?
   Y_hat1 - Yhat2 = 0.0000
   ```



[GitHub] [incubator-mxnet] ceisenach commented on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
ceisenach commented on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-923339412


   When I use the latest nightly builds, I no longer observe the bug, so it seems resolved to me. 



[GitHub] [incubator-mxnet] szha commented on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
szha commented on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-778661805


   Actually, this bug appears to be non-deterministic. If I run the script a couple more times, I get weird results such as the following, which happen on both v1.x and on master:
   ```
   python3 snippets/take_bug.py
   2.0.0
   [13:51:16] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for CPU
   ||g_param||_2: 7.51E+16 | Param: convs.1.weight
   ||g_param||_2: 7.08E+15 | Param: convs.1.bias
   ||g_param||_2: 8.05E+16 | Param: convs.2.weight
   ||g_param||_2: 1.39E+17 | Param: convs.2.bias
   ||g_param||_2: 8.18E+16 | Param: convs.3.weight
   ||g_param||_2: 8.63E+17 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 8.85E+17
       *  ||g||_1: 3.82E+18
       *  ||g||_inf: 1.07E+17
   
   ||g_param||_2: 2.37E-04 | Param: convs.1.weight
   ||g_param||_2: 2.29E-05 | Param: convs.1.bias
   ||g_param||_2: 2.23E-04 | Param: convs.2.weight
   ||g_param||_2: 1.50E-04 | Param: convs.2.bias
   ||g_param||_2: 4.26E-04 | Param: convs.3.weight
   ||g_param||_2: 7.02E-04 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.06E-02
       *  ||g||_1: 1.75E-02
       *  ||g||_inf: 1.06E-02
   
   ==== Same outputs?
   Y_hat1 - Yhat2 = 0.0000
   ```



[GitHub] [incubator-mxnet] ceisenach commented on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
ceisenach commented on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-779508902


   Yeah, I observe similar behavior on v1.x -- sometimes the gradient calculation is correct, but most of the time the results are different.



[GitHub] [incubator-mxnet] ceisenach closed issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
ceisenach closed issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817


   



[GitHub] [incubator-mxnet] szha edited a comment on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
szha edited a comment on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-778661805


   Actually, this bug appears to be non-deterministic. If I run the script a couple more times, I get weird results such as the following, which happen on both v1.x and on master:
   
   <details><summary>script</summary>
   
   ```
   import numpy as np
   import mxnet as mx
   from mxnet.gluon.nn import HybridBlock, Conv1D, HybridSequential, HybridLambda, Dense
   from mxnet import autograd, nd
   from mxnet.gluon.loss import L2Loss
   
   print(mx.__version__)
   print(mx.runtime.feature_list())
   
   
   def print_grads(model, ctx=mx.cpu()):
       pd = model.collect_params()
       total_grad_l2 = 0
       total_grad_l1 = 0
       total_grad_linf = 0
       for p in pd:
           try:
               g = pd[p].grad(ctx) / N
               g2 = (g**2).sum().as_in_context(mx.cpu()).asscalar()
               g1 = g.abs().sum().as_in_context(mx.cpu()).asscalar()
               ginf = g.max().as_in_context(mx.cpu()).asscalar()
               total_grad_linf = max(total_grad_linf, ginf)
               total_grad_l2 += g2
               total_grad_l1 += g1
               print(f"||g_param||_2: {g2**0.5:.2E} | Param: {p}")
           except Exception:
               pass
       grad_info = f"""
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: {total_grad_l2**0.5:.2E}
       *  ||g||_1: {total_grad_l1:.2E}
       *  ||g||_inf: {total_grad_linf:.2E}
       """
       print(grad_info)
   
   
   def run_model(model, loss, X, Y, num_iters=1):
       for i in range(num_iters):
           with autograd.record():
               Y_hat = model(X)
               ll = loss(Y_hat, Y)
               ll = ll.sum()
               ll.backward()
               print_grads(model)
       return Y_hat
   
   
   def conv_layer(atrous_rates, num_channels):
       convs = HybridSequential()
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       for rate in atrous_rates:
           convs.add(Conv1D(num_channels, 3, padding=rate, dilation=rate, activation='tanh'))
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       return convs
   
   
   class Model(HybridBlock):
       """
       Model takes tensors of shape N x T x C and produces predictions with shape N x T
       """
   
       def __init__(self, conv_units, atrous_rates, use_take=False, **kwargs):
           super().__init__()
           self.use_take = use_take
           self.convs = conv_layer(atrous_rates, conv_units)
           self.dense_out = Dense(1, flatten=False, activation='tanh')
   
       def hybrid_forward(self, F, X):
           X1 = X
           X2 = self.convs(X1)
           if self.use_take:
               X3 = F.take(X2, nd.array([1, 2, 3]), axis=-1)
           else:
               X3 = F.slice_axis(X2, begin=1, end=4, axis=-1)
           X4 = self.dense_out(X3)
           X4 = F.squeeze(X4, axis=-1)
           return X4
   
   
   if __name__ == "__main__":
       N = 30
       T = 20
       C = 8
       conv_units = 5
       atrous_rates = [1, 2, 4]
       np.random.seed(1234)
   
       X = np.random.normal(size=(N, T, C))
       Y = np.random.normal(size=(N, T))
       X, Y = nd.array(X), nd.array(Y)
   
       # Using F.take
       mx.random.seed(12354)
       model = Model(conv_units, atrous_rates, use_take=True)
       model.initialize()
       loss = L2Loss()
       Y_hat1 = run_model(model, loss, X, Y)
   
       # Using F.slice_axis
       mx.random.seed(12354)
       model2 = Model(conv_units, atrous_rates, use_take=False)
       model2.initialize()
       loss2 = L2Loss()
       Y_hat2 = run_model(model2, loss2, X, Y)
   
       delta = nd.abs(Y_hat1-Y_hat2).sum().asscalar()
       print("==== Same outputs?")
       print(f"Y_hat1 - Yhat2 = {delta:.4f}")
   ```
   
   </details>
   
   <details><summary>environment</summary>
   
   from commit bca8de85c6011d145e52ae06639fb2ae129d0480
   
   ```
   ----------Python Info----------
   Version      : 3.8.7
   Compiler     : Clang 12.0.0 (clang-1200.0.32.28)
   Build        : ('default', 'Dec 30 2020 10:14:55')
   Arch         : ('64bit', '')
   ------------Pip Info-----------
   Version      : 20.3.3
   Directory    : /usr/local/lib/python3.8/site-packages/pip
   ----------MXNet Info-----------
   Version      : 2.0.0
   Directory    : /Users/zhasheng/mxnet/python/mxnet
   Commit hash file "/Users/zhasheng/mxnet/python/mxnet/COMMIT_HASH" not found. Not installed from pre-built package or built from source.
   Library      : ['/Users/zhasheng/mxnet/python/mxnet/../../build/libmxnet.dylib']
   Build features:
   ✖ CUDA
   ✖ CUDNN
   ✖ NCCL
   ✖ TENSORRT
   ✖ CUTENSOR
   ✔ CPU_SSE
   ✔ CPU_SSE2
   ✔ CPU_SSE3
   ✔ CPU_SSE4_1
   ✔ CPU_SSE4_2
   ✖ CPU_SSE4A
   ✔ CPU_AVX
   ✖ CPU_AVX2
   ✖ OPENMP
   ✖ SSE
   ✔ F16C
   ✖ JEMALLOC
   ✖ BLAS_OPEN
   ✖ BLAS_ATLAS
   ✖ BLAS_MKL
   ✔ BLAS_APPLE
   ✔ LAPACK
   ✔ MKLDNN
   ✖ OPENCV
   ✖ DIST_KVSTORE
   ✖ INT64_TENSOR_SIZE
   ✔ SIGNAL_HANDLER
   ✔ DEBUG
   ✖ TVM_OP
   ----------System Info----------
   Platform     : macOS-11.2.1-x86_64-i386-64bit
   system       : Darwin
   node         : a483e79ab3ab
   release      : 20.3.0
   version      : Darwin Kernel Version 20.3.0: Thu Jan 21 00:07:06 PST 2021; root:xnu-7195.81.3~1/RELEASE_X86_64
   ----------Hardware Info----------
   machine      : x86_64
   processor    : i386
   b'machdep.cpu.brand_string: Intel(R) Core(TM) i7-8569U CPU @ 2.80GHz'
   b'machdep.cpu.features: FPU VME DE PSE TSC MSR PAE MCE CX8 APIC SEP MTRR PGE MCA CMOV PAT PSE36 CLFSH DS ACPI MMX FXSR SSE SSE2 SS HTT TM PBE SSE3 PCLMULQDQ DTES64 MON DSCPL VMX EST TM2 SSSE3 FMA CX16 TPR PDCM SSE4.1 SSE4.2 x2APIC MOVBE POPCNT AES PCID XSAVE OSXSAVE SEGLIM64 TSCTMR AVX1.0 RDRAND F16C'
   b'machdep.cpu.leaf7_features: RDWRFSGS TSC_THREAD_OFFSET SGX BMI1 AVX2 SMEP BMI2 ERMS INVPCID FPU_CSDS MPX RDSEED ADX SMAP CLFSOPT IPT MDCLEAR TSXFA IBRS STIBP L1DF SSBD'
   b'machdep.cpu.extfeatures: SYSCALL XD 1GBPAGE EM64T LAHF LZCNT PREFETCHW RDTSCP TSCI'
   ----------Network Test----------
   Setting timeout: 10
   Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0137 sec, LOAD: 0.2581 sec.
   Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0852 sec, LOAD: 0.2603 sec.
   Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1125)>, DNS finished in 0.23605990409851074 sec.
   Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0248 sec, LOAD: 0.2969 sec.
   Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0103 sec, LOAD: 0.3477 sec.
   Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.014931201934814453 sec.
   ----------Environment----------
   CC="/usr/local/opt/llvm/bin/clang"
   CXX="/usr/local/opt/llvm/bin/clang++"
   KMP_DUPLICATE_LIB_OK="True"
   KMP_INIT_AT_FORK="FALSE"
   ```
   
   </details>
   
   ```
   2.0.0
   [✖ CUDA, ✖ CUDNN, ✖ NCCL, ✖ TENSORRT, ✖ CUTENSOR, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✔ CPU_SSE4_1, ✔ CPU_SSE4_2, ✖ CPU_SSE4A, ✔ CPU_AVX, ✖ CPU_AVX2, ✖ OPENMP, ✖ SSE, ✔ F16C, ✖ JEMALLOC, ✖ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✔ BLAS_APPLE, ✔ LAPACK, ✔ MKLDNN, ✖ OPENCV, ✖ DIST_KVSTORE, ✖ INT64_TENSOR_SIZE, ✔ SIGNAL_HANDLER, ✔ DEBUG, ✖ TVM_OP]
   [13:57:38] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for CPU
   ||g_param||_2: 2.27E+11 | Param: convs.1.weight
   ||g_param||_2: 2.15E+10 | Param: convs.1.bias
   ||g_param||_2: 2.46E+11 | Param: convs.2.weight
   ||g_param||_2: 4.30E+11 | Param: convs.2.bias
   ||g_param||_2: 2.54E+11 | Param: convs.3.weight
   ||g_param||_2: 2.66E+12 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 2.73E+12
       *  ||g||_1: 1.19E+13
       *  ||g||_inf: 1.86E+12
   
   ||g_param||_2: 2.37E-04 | Param: convs.1.weight
   ||g_param||_2: 2.29E-05 | Param: convs.1.bias
   ||g_param||_2: 2.23E-04 | Param: convs.2.weight
   ||g_param||_2: 1.50E-04 | Param: convs.2.bias
   ||g_param||_2: 4.26E-04 | Param: convs.3.weight
   ||g_param||_2: 7.02E-04 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.06E-02
       *  ||g||_1: 1.75E-02
       *  ||g||_inf: 1.06E-02
   
   ==== Same outputs?
   Y_hat1 - Yhat2 = 0.0000
   ```
   
   Update: if I turn off MKLDNN, the results are consistently different:
   ```
   2.0.0
   [✖ CUDA, ✖ CUDNN, ✖ NCCL, ✖ TENSORRT, ✖ CUTENSOR, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✔ CPU_SSE4_1, ✔ CPU_SSE4_2, ✖ CPU_SSE4A, ✔ CPU_AVX, ✖ CPU_AVX2, ✖ OPENMP, ✖ SSE, ✔ F16C, ✖ JEMALLOC, ✖ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✔ BLAS_APPLE, ✔ LAPACK, ✖ MKLDNN, ✖ OPENCV, ✖ DIST_KVSTORE, ✖ INT64_TENSOR_SIZE, ✔ SIGNAL_HANDLER, ✔ DEBUG, ✖ TVM_OP]
   ||g_param||_2: 3.91E-03 | Param: convs.1.weight
   ||g_param||_2: 1.57E-04 | Param: convs.1.bias
   ||g_param||_2: 5.76E-03 | Param: convs.2.weight
   ||g_param||_2: 7.88E-04 | Param: convs.2.bias
   ||g_param||_2: 6.51E-03 | Param: convs.3.weight
   ||g_param||_2: 5.04E-03 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.51E-02
       *  ||g||_1: 1.39E-01
       *  ||g||_inf: 1.06E-02
   
   ||g_param||_2: 2.37E-04 | Param: convs.1.weight
   ||g_param||_2: 2.29E-05 | Param: convs.1.bias
   ||g_param||_2: 2.23E-04 | Param: convs.2.weight
   ||g_param||_2: 1.50E-04 | Param: convs.2.bias
   ||g_param||_2: 4.26E-04 | Param: convs.3.weight
   ||g_param||_2: 7.02E-04 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.06E-02
       *  ||g||_1: 1.75E-02
       *  ||g||_inf: 1.06E-02
   
   ==== Same outputs?
   Y_hat1 - Yhat2 = 0.0000
   ```
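
   For reference, whether MKLDNN is compiled into the binary can be checked at runtime (this is how the feature lists above were produced); a small sketch:

   ```python
   from mxnet.runtime import Features

   # True if the libmxnet binary was built with MKLDNN support
   print(Features().is_enabled('MKLDNN'))
   ```

   There is, as far as I know, also an `MXNET_MKLDNN_ENABLED` environment variable for toggling it at runtime, but rebuilding without MKLDNN is the unambiguous way to rule it out.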



[GitHub] [incubator-mxnet] ceisenach commented on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
ceisenach commented on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-778655158


   Thanks for looking into this -- do you know which commit fixed the bug? Also, do you know which upcoming release would contain the bugfix?



[GitHub] [incubator-mxnet] szha commented on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
szha commented on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-775388041


   I can confirm that this bug has been fixed on the master branch. Here are the outputs from the master branch (after adopting the new Gluon interface):
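
   For clarity, "adopting the new Gluon interface" here only means dropping the `prefix`/`params` kwargs and the `name_scope()` block from `Model.__init__`; the rest of the model is unchanged:

   ```python
   class Model(HybridBlock):
       def __init__(self, conv_units, atrous_rates, use_take=False, **kwargs):
           super().__init__()          # no prefix/params kwargs needed
           self.use_take = use_take
           # no self.name_scope() block required
           self.convs = conv_layer(atrous_rates, conv_units)
           self.dense_out = Dense(1, flatten=False, activation='tanh')
   ```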
   
   <details><summary>script</summary>
   
   ```python
   import numpy as np
   import mxnet as mx
   from mxnet.gluon.nn import HybridBlock, Conv1D, HybridSequential, HybridLambda, Dense
   from mxnet import autograd, nd
   from mxnet.gluon.loss import L2Loss
   
   
   def print_grads(model, ctx=mx.cpu()):
       pd = model.collect_params()
       total_grad_l2 = 0
       total_grad_l1 = 0
       total_grad_linf = 0
       for p in pd:
           try:
               g = pd[p].grad(ctx) / N
               g2 = (g**2).sum().as_in_context(mx.cpu()).asscalar()
               g1 = g.abs().sum().as_in_context(mx.cpu()).asscalar()
               ginf = g.max().as_in_context(mx.cpu()).asscalar()
               total_grad_linf = max(total_grad_linf, ginf)
               total_grad_l2 += g2
               total_grad_l1 += g1
               print(f"||g_param||_2: {g2**0.5:.2E} | Param: {p}")
           except Exception:
               pass
       grad_info = f"""
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: {total_grad_l2**0.5:.2E}
       *  ||g||_1: {total_grad_l1:.2E}
       *  ||g||_inf: {total_grad_linf:.2E}
   
       """
       print(grad_info)
   
   
   def run_model(model, loss, X, Y, num_iters=1):
       for i in range(num_iters):
           with autograd.record():
               Y_hat = model(X)
               ll = loss(Y_hat, Y)
               ll = ll.sum()
               ll.backward()
               print_grads(model)
       return Y_hat
   
   
   def conv_layer(atrous_rates, num_channels):
       convs = HybridSequential()
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       for rate in atrous_rates:
           convs.add(Conv1D(num_channels, 3, padding=rate, dilation=rate, activation='tanh'))
       convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
       return convs
   
   
   class Model(HybridBlock):
       """
       Model takes tensors of shape N x T x C and produces predictions with shape N x T
       """
   
       def __init__(self, conv_units, atrous_rates, use_take=False, **kwargs):
           super().__init__()
           self.use_take = use_take
           self.convs = conv_layer(atrous_rates, conv_units)
           self.dense_out = Dense(1, flatten=False, activation='tanh')
   
       def hybrid_forward(self, F, X):
           X1 = X
           X2 = self.convs(X1)
           if self.use_take:
               X3 = F.take(X2, nd.array([1, 2, 3]), axis=-1)
           else:
               X3 = F.slice_axis(X2, begin=1, end=4, axis=-1)
           X4 = self.dense_out(X3)
           X4 = F.squeeze(X4, axis=-1)
           return X4
   
   
   if __name__ == "__main__":
       N = 30
       T = 20
       C = 8
       conv_units = 5
       atrous_rates = [1, 2, 4]
       np.random.seed(1234)
   
       X = np.random.normal(size=(N, T, C))
       Y = np.random.normal(size=(N, T))
       X, Y = nd.array(X), nd.array(Y)
   
       # Using F.take
       mx.random.seed(12354)
       model = Model(conv_units, atrous_rates, use_take=True)
       model.initialize()
       loss = L2Loss()
       Y_hat1 = run_model(model, loss, X, Y)
   
       # Using F.slice_axis
       mx.random.seed(12354)
       model2 = Model(conv_units, atrous_rates, use_take=False)
       model2.initialize()
       loss2 = L2Loss()
       Y_hat2 = run_model(model2, loss2, X, Y)
   
       delta = nd.abs(Y_hat1-Y_hat2).sum().asscalar()
       print("==== Same outputs?")
       print(f"Y_hat1 - Yhat2 = {delta:.4f}")
   ```
   </details>
   
   ```
   ▶ python3 take_bug.py
   [14:28:50] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for CPU
   ||g_param||_2: 2.37E-04 | Param: convs.1.weight
   ||g_param||_2: 2.29E-05 | Param: convs.1.bias
   ||g_param||_2: 2.23E-04 | Param: convs.2.weight
   ||g_param||_2: 1.50E-04 | Param: convs.2.bias
   ||g_param||_2: 4.26E-04 | Param: convs.3.weight
   ||g_param||_2: 7.02E-04 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.06E-02
       *  ||g||_1: 1.75E-02
       *  ||g||_inf: 1.06E-02
   
   
   ||g_param||_2: 2.37E-04 | Param: convs.1.weight
   ||g_param||_2: 2.29E-05 | Param: convs.1.bias
   ||g_param||_2: 2.23E-04 | Param: convs.2.weight
   ||g_param||_2: 1.50E-04 | Param: convs.2.bias
   ||g_param||_2: 4.26E-04 | Param: convs.3.weight
   ||g_param||_2: 7.02E-04 | Param: convs.3.bias
   ||g_param||_2: 1.38E-04 | Param: dense_out.weight
   ||g_param||_2: 1.06E-02 | Param: dense_out.bias
   
       -------------------------------------------
       -------  Grad Info
       *  ||g||_2: 1.06E-02
       *  ||g||_1: 1.75E-02
       *  ||g||_inf: 1.06E-02
   
   
   ==== Same outputs?
   Y_hat1 - Yhat2 = 0.0000
   ```



[GitHub] [incubator-mxnet] ceisenach edited a comment on issue #19817: F.Take Backwards - Incorrect Gradient

Posted by GitBox <gi...@apache.org>.
ceisenach edited a comment on issue #19817:
URL: https://github.com/apache/incubator-mxnet/issues/19817#issuecomment-923339412


   When I use the latest nightly builds, I no longer observe the bug, so it seems resolved to me. Thanks for the fix!

