Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/11 07:00:15 UTC

[GitHub] DabiaoMa commented on issue #9302: nd.contrib.fft issue

URL: https://github.com/apache/incubator-mxnet/issues/9302#issuecomment-356844107
 
 
   @shiyangdaisy23 
   Here is a simple example:
   ######################
   
   from mxnet import nd, gpu, autograd, gluon
   import numpy as np
   import librosa
   
   ctx = gpu(2)
   fft_length = 256
   fft_size = fft_length // 2 + 1
   
   f = gluon.nn.Dense(1, flatten=False)
   f.collect_params().initialize(ctx=ctx)
   
   wavs = nd.array(librosa.load('210001.wav', sr=16000, mono=True)[0], ctx=ctx)
   x = nd.random.uniform(shape=(1, wavs.shape[0], 1), ctx=ctx)
   
   third_dim = x.shape[1] // fft_length + 1
   padding_needed = third_dim * fft_length - wavs.shape[0]  # zeros needed to fill third_dim whole frames
   
   wavs = nd.expand_dims(nd.expand_dims(wavs, axis=0), axis=0)
   
   f(x)
   
   print(f.weight.grad())
   
   with autograd.record():
       y = f(x)
       y = nd.swapaxes(y, dim1=1, dim2=2)
       y = nd.concat(y, nd.zeros(shape=(1, 1, padding_needed), ctx=ctx), dim=2).reshape((1, 1, -1, fft_length))
       wavs = nd.concat(wavs, nd.zeros(shape=(1, 1, padding_needed), ctx=ctx), dim=2).reshape((1, 1, -1, fft_length))
       stft_y = nd.contrib.fft(y, compute_size=fft_length).reshape((1, 1, third_dim, -1, 2))
       stft_wavs = nd.contrib.fft(wavs, compute_size=fft_length).reshape((1, 1, third_dim, -1, 2))
       stft_y = nd.sqrt(nd.sum(nd.square(stft_y), axis=4) + 1e-12)[:, :, :, : fft_size]
       stft_wavs = nd.sqrt(nd.sum(nd.square(stft_wavs), axis=4) + 1e-12)[:, :, :, : fft_size]
       loss = 0 * nd.mean(nd.square(stft_y - stft_wavs), axis=(1, 2, 3))
   
   loss.backward()
   print(f.weight.grad())
   ####################
   
   I am using MXNet 0.12.0, built with CUDA 8.0 (cu80) support.
   Initially, the gradient of f.weight is 0. After the backward pass it should still be 0 (because the loss is multiplied by 0), but I got 81.09 instead.
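
The expectation above can be checked without MXNet. The following is only a rough sketch in pure Python, under several simplifying assumptions that are not part of the original snippet: the Dense layer is reduced to a single scalar weight `w` with no bias, the signal is one 8-sample frame, and a naive DFT stands in for `nd.contrib.fft`. The names `dft_magnitude`, `spectral_loss` and `numeric_grad`, and the sample values, are hypothetical. It estimates the gradient by central finite differences, once with the loss multiplied by 0 and once unscaled for contrast.

```python
import cmath
import math


def dft_magnitude(frame, eps=1e-12):
    """Magnitude spectrum of one real-valued frame via a naive O(n^2) DFT."""
    n = len(frame)
    mags = []
    for k in range(n // 2 + 1):  # keep only the non-redundant bins
        s = sum(frame[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
        mags.append(math.sqrt(s.real ** 2 + s.imag ** 2 + eps))
    return mags


def spectral_loss(w, x, ref, scale):
    """scale * mean squared difference between |DFT(w * x)| and |DFT(ref)|."""
    pred = dft_magnitude([w * v for v in x])
    target = dft_magnitude(ref)
    mse = sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)
    return scale * mse


def numeric_grad(w, x, ref, scale, h=1e-5):
    """Central finite-difference estimate of d(loss)/dw."""
    upper = spectral_loss(w + h, x, ref, scale)
    lower = spectral_loss(w - h, x, ref, scale)
    return (upper - lower) / (2 * h)


# Hypothetical data: one input frame and one reference frame.
x = [0.1, 0.7, 0.3, 0.9, 0.5, 0.2, 0.8, 0.4]
ref = [0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0]

grad_scaled_by_zero = numeric_grad(0.5, x, ref, scale=0.0)
grad_unscaled = numeric_grad(0.5, x, ref, scale=1.0)

print(grad_scaled_by_zero)  # loss multiplied by 0
print(grad_unscaled)        # same loss without the 0 factor
```

In this simplified setting the gradient of the zero-scaled loss is zero while the unscaled one is not, which is the behaviour I expected from `loss.backward()` in the snippet above.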
   
