Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/21 11:46:36 UTC
[GitHub] anbrjohn opened a new issue #11352: CRF weights never updated in Bi-LSTM-CRF
URL: https://github.com/apache/incubator-mxnet/issues/11352
When using `incubator-mxnet/example/gluon/lstm_crf.py`, CRF transition matrix weights are never updated during training, defeating the purpose of the CRF layer. Printing `model.transitions.data()` each epoch confirmed this.
Compare the corresponding lines in the MXNet [version](https://github.com/apache/incubator-mxnet/blob/master/example/gluon/lstm_crf.py) and the PyTorch [reference](https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html):
```
self.transitions = nd.random.normal(shape=(self.tagset_size, self.tagset_size)) # MXNet
self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size)) # PyTorch
```
I was able to solve this issue by changing the MXNet line to:
```
self.transitions = gluon.Parameter("crf_transition_matrix",
                                   shape=(self.tagset_size, self.tagset_size))
```
Making this change required adding `.data()` to every other reference to `self.transitions` in the code, e.g.:
```
self.transitions[next_tag].reshape((1, -1)) # Before
self.transitions.data()[next_tag].reshape((1, -1)) # After
```
It also required manually adding the parameter to the model's parameter dictionary, outside of the class and before model initialization:
```
model.params.update({'crf_transition_matrix': model.transitions}) # Added this line
model.initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu())
optimizer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': 1e-4})
```
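As a rough illustration of why the registration step matters, here is a dependency-free toy sketch. It is not MXNet's actual implementation: `ToyParameter`, `ToyBlock`, and `toy_sgd_step` are hypothetical stand-ins, and the gradient is hard-coded to 1. The only point it makes is that an optimizer can update only what the model's parameter collection hands to it, so a plain attribute is never touched.

```python
# Toy model of parameter registration. NOT MXNet code; all names are hypothetical.

class ToyParameter:
    """Stand-in for a trainable parameter with a value and a gradient."""
    def __init__(self, name, values):
        self.name = name
        self.values = list(values)
        self.grad = [1.0] * len(self.values)  # assumed constant gradient

    def data(self):
        return self.values


class ToyBlock:
    """Stand-in for a model that tracks only explicitly registered parameters."""
    def __init__(self):
        self.params = {}

    def collect_params(self):
        return self.params


def toy_sgd_step(params, learning_rate):
    """Apply one SGD step to the parameters the optimizer was given."""
    for p in params.values():
        p.values = [v - learning_rate * g for v, g in zip(p.values, p.grad)]


# Case 1: transitions stored as a plain attribute (analogous to a raw array).
unregistered = ToyBlock()
unregistered.transitions = [0.5, -0.5]
toy_sgd_step(unregistered.collect_params(), learning_rate=0.1)

# Case 2: transitions wrapped in a parameter and added to the parameter dict.
registered = ToyBlock()
registered.transitions = ToyParameter("crf_transition_matrix", [0.5, -0.5])
registered.params.update({"crf_transition_matrix": registered.transitions})
toy_sgd_step(registered.collect_params(), learning_rate=0.1)

print(unregistered.transitions)       # never seen by the optimizer
print(registered.transitions.data())  # moved by one SGD step
```

In the toy, the first model's `transitions` stay at their initial values because `collect_params()` returns an empty dictionary, which mirrors the symptom reported above.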