You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/26 18:46:26 UTC

[incubator-mxnet] branch master updated: Fix bi-lstm-crf to update crf weights (#11375)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4db5a83  Fix bi-lstm-crf to update crf weights (#11375)
4db5a83 is described below

commit 4db5a83abb79296e5a40a29b3ac6db8e21226b3b
Author: anbrjohn <an...@indiana.edu>
AuthorDate: Tue Jun 26 20:46:19 2018 +0200

    Fix bi-lstm-crf to update crf weights (#11375)
    
    * Fix bi-lstm-crf to update crf weights
    
    * Use self.params.get to declare params
---
 example/gluon/lstm_crf.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/example/gluon/lstm_crf.py b/example/gluon/lstm_crf.py
index 561b4c2..3e95c05 100644
--- a/example/gluon/lstm_crf.py
+++ b/example/gluon/lstm_crf.py
@@ -62,8 +62,9 @@ class BiLSTM_CRF(Block):
 
             # Matrix of transition parameters.  Entry i,j is the score of
             # transitioning *to* i *from* j.
-            self.transitions = nd.random.normal(shape=(self.tagset_size, self.tagset_size))
-
+            self.transitions = self.params.get("crf_transition_matrix", 
+                                               shape=(self.tagset_size, self.tagset_size))
+            
             self.hidden = self.init_hidden()
 
     def init_hidden(self):
@@ -85,7 +86,7 @@ class BiLSTM_CRF(Block):
                 emit_score = feat[next_tag].reshape((1, -1))
                 # the ith entry of trans_score is the score of transitioning to
                 # next_tag from i
-                trans_score = self.transitions[next_tag].reshape((1, -1))
+                trans_score = self.transitions.data()[next_tag].reshape((1, -1))
                 # The ith entry of next_tag_var is the value for the
                 # edge (i -> next_tag) before we do log-sum-exp
                 next_tag_var = alphas + trans_score + emit_score
@@ -93,7 +94,7 @@ class BiLSTM_CRF(Block):
                 # scores.
                 alphas_t.append(log_sum_exp(next_tag_var))
             alphas = nd.concat(*alphas_t, dim=0).reshape((1, -1))
-        terminal_var = alphas + self.transitions[self.tag2idx[STOP_TAG]]
+        terminal_var = alphas + self.transitions.data()[self.tag2idx[STOP_TAG]]
         alpha = log_sum_exp(terminal_var)
         return alpha
 
@@ -112,8 +113,8 @@ class BiLSTM_CRF(Block):
         tags = nd.concat(nd.array([self.tag2idx[START_TAG]]), *tags, dim=0)
         for i, feat in enumerate(feats):
             score = score + \
-                self.transitions[to_scalar(tags[i+1]), to_scalar(tags[i])] + feat[to_scalar(tags[i+1])]
-        score = score + self.transitions[self.tag2idx[STOP_TAG],
+                self.transitions.data()[to_scalar(tags[i+1]), to_scalar(tags[i])] + feat[to_scalar(tags[i+1])]
+        score = score + self.transitions.data()[self.tag2idx[STOP_TAG],
                                          to_scalar(tags[int(tags.shape[0]-1)])]
         return score
 
@@ -134,7 +135,7 @@ class BiLSTM_CRF(Block):
                 # from tag i to next_tag.
                 # We don't include the emission scores here because the max
                 # does not depend on them (we add them in below)
-                next_tag_var = vvars + self.transitions[next_tag]
+                next_tag_var = vvars + self.transitions.data()[next_tag]
                 best_tag_id = argmax(next_tag_var)
                 bptrs_t.append(best_tag_id)
                 viterbivars_t.append(next_tag_var[0, best_tag_id])
@@ -144,7 +145,7 @@ class BiLSTM_CRF(Block):
             backpointers.append(bptrs_t)
 
         # Transition to STOP_TAG
-        terminal_var = vvars + self.transitions[self.tag2idx[STOP_TAG]]
+        terminal_var = vvars + self.transitions.data()[self.tag2idx[STOP_TAG]]
         best_tag_id = argmax(terminal_var)
         path_score = terminal_var[0, best_tag_id]