Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/21 11:46:36 UTC

[GitHub] anbrjohn opened a new issue #11352: CRF weights never updated in Bi-LSTM-CRF

URL: https://github.com/apache/incubator-mxnet/issues/11352
 
 
    When using `incubator-mxnet/example/gluon/lstm_crf.py`, CRF transition matrix weights are never updated during training, defeating the purpose of the CRF layer. Printing `model.transitions.data()` each epoch confirmed this.
   
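   Compare the corresponding lines in the MXNet [version](https://github.com/apache/incubator-mxnet/blob/master/example/gluon/lstm_crf.py) and the PyTorch [reference](https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html). The PyTorch version wraps the tensor in a trainable parameter, while the MXNet version stores a plain NDArray: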
   ```
   self.transitions = nd.random.normal(shape=(self.tagset_size, self.tagset_size)) # MXNet
   self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size)) # PyTorch
   ```
   
   I was able to solve this issue by changing the above line to:
   ```
   self.transitions = gluon.Parameter("crf_transition_matrix", 
       shape=(self.tagset_size, self.tagset_size))
   ```
   Making this change required adding `.data()` to all other references to `self.transitions` in the code, e.g.:
   ```
   self.transitions[next_tag].reshape((1, -1)) # Before
   self.transitions.data()[next_tag].reshape((1, -1)) # After
   ```
   and manually updating the parameter dictionary outside of the class before model initialization:
   ```
   model.params.update({'crf_transition_matrix':model.transitions}) # Added this line
   model.initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu())
   optimizer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': 0.01, 'wd': 1e-4})
   ```
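
The reason the registration step matters can be sketched with a toy model in plain Python. This is only an illustration under stated assumptions: `ToyParameter`, `ToyBlock`, `ToyTrainer`, and `ToyCRF` are hypothetical stand-ins, not MXNet classes, and the gradient values are made up. The point it tries to show is that an optimizer can only update what appears in the parameter dictionary it was given, so a raw array stored as an attribute stays at its initial value.

```python
# Toy stand-ins for illustration only; these are NOT MXNet classes.

class ToyParameter:
    """A named value with a gradient slot that an optimizer may update."""
    def __init__(self, name, value):
        self.name = name
        self.value = list(value)
        self.grad = [0.0] * len(self.value)


class ToyBlock:
    """Holds a parameter dictionary; nothing is registered automatically."""
    def __init__(self):
        self.params = {}

    def collect_params(self):
        return self.params


class ToyTrainer:
    """Plain SGD over exactly the parameters it was handed."""
    def __init__(self, params, learning_rate):
        self.params = params
        self.learning_rate = learning_rate

    def step(self):
        for p in self.params.values():
            p.value = [v - self.learning_rate * g
                       for v, g in zip(p.value, p.grad)]


class ToyCRF(ToyBlock):
    def __init__(self):
        super().__init__()
        # Plain list, analogous to storing a raw array: invisible to the trainer.
        self.raw_transitions = [1.0, -1.0]
        # Wrapped value, analogous to the Parameter used in the fix.
        self.transitions = ToyParameter("crf_transition_matrix", [1.0, -1.0])


model = ToyCRF()
# Mirrors the manual registration step from the fix above.
model.params.update({"crf_transition_matrix": model.transitions})

trainer = ToyTrainer(model.collect_params(), learning_rate=0.5)
model.transitions.grad = [1.0, 1.0]  # made-up gradients from a pretend backward pass
trainer.step()

print(model.raw_transitions)    # still the initial values
print(model.transitions.value)  # moved by one SGD step
```

In this toy, dropping the `model.params.update(...)` line leaves `collect_params()` empty, and `trainer.step()` then changes nothing, which matches the symptom reported in this issue.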
