Posted to commits@mxnet.apache.org by gi...@git.apache.org on 2017/07/31 16:04:49 UTC
[GitHub] leezu opened a new issue #7268: Autograd retain_graph=True bugs
URL: https://github.com/apache/incubator-mxnet/issues/7268
Consider the following example:
```python
import mxnet as mx
from mxnet import autograd
from mxnet import gluon

encoder = gluon.rnn.LSTM(hidden_size=300, num_layers=1)
encoder.collect_params().initialize(mx.init.Xavier(), ctx=mx.cpu())
decoder = gluon.rnn.LSTM(hidden_size=300, num_layers=1)
decoder.collect_params().initialize(mx.init.Xavier(), ctx=mx.cpu())

encoder_begin_state = encoder.begin_state(
    func=mx.nd.zeros, batch_size=8, ctx=mx.cpu())
expected_label = mx.nd.ones((8, 8))
loss = gluon.loss.SoftmaxCrossEntropyLoss()

# Encoder
with autograd.record():
    output, hidden = encoder(mx.nd.ones((8, 8, 300)), encoder_begin_state)
    # for i in hidden:
    #     i.attach_grad()
    hidden_detached = [i.detach() for i in hidden]
    for i in hidden_detached:
        i.attach_grad()
    prediction, _ = decoder(mx.nd.ones((8, 8, 300)), hidden_detached)
    l = loss(prediction, expected_label)

l.backward(retain_graph=True)
hidden[0].backward(hidden_detached[0].grad, retain_graph=True)
hidden[1].backward(hidden_detached[1].grad, retain_graph=True)

# Collect gradients to force execution
params = encoder.collect_params()
params.update(decoder.collect_params())
print([mx.nd.mean(p._grad[mx.cpu()]) for p in params.values()])
```
This fails with `corrupted double-linked list`. Uncommenting

```python
    # for i in hidden:
    #     i.attach_grad()
```

fixes the crash, but it also stops the gradient from flowing back to the encoder parameters (which arguably should be documented?).
Or is this not supposed to work? Essentially I want to decompose the autograd graph into two parts, similar to having two modules and using the input gradients of the second for the backward pass of the first.
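For reference, the decomposition I have in mind is just the chain rule applied in two stages. Here is a minimal, framework-free sketch with NumPy (all names and the toy two-stage pipeline are hypothetical, chosen only to mirror the encoder/decoder split above): the second stage's gradient with respect to the detached intermediate plays the role of `hidden_detached[i].grad`, and feeding it back as the head gradient of the first stage reproduces the full-graph gradient.

```python
import numpy as np

# Hypothetical two-stage pipeline:
#   stage 1 ("encoder"): h = W1 @ x
#   stage 2 ("decoder" + loss): L = 0.5 * sum((W2 @ h) ** 2)
rng = np.random.default_rng(0)
x = rng.standard_normal(4)
W1 = rng.standard_normal((3, 4))
W2 = rng.standard_normal((2, 3))

h = W1 @ x          # stage-1 output (analogous to `hidden`)
y = W2 @ h

# Stage-2 backward: gradient of L w.r.t. the (detached) stage-1 output.
# This corresponds to hidden_detached[i].grad after l.backward().
dL_dh = W2.T @ y

# Stage-1 backward: use dL_dh as the head gradient, exactly like
# hidden[i].backward(hidden_detached[i].grad).
dL_dW1_chained = np.outer(dL_dh, x)

# Reference: differentiate through the whole graph in one pass.
dL_dW1_direct = np.outer(W2.T @ (W2 @ (W1 @ x)), x)

assert np.allclose(dL_dW1_chained, dL_dW1_direct)
```

So mathematically the two-stage backward is well defined; the question is whether autograd's `detach`/`attach_grad`/`retain_graph` combination is supposed to support it.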
@piiswrong