You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/26 18:46:26 UTC
[incubator-mxnet] branch master updated: Fix bi-lstm-crf to update crf weights (#11375)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4db5a83 Fix bi-lstm-crf to update crf weights (#11375)
4db5a83 is described below
commit 4db5a83abb79296e5a40a29b3ac6db8e21226b3b
Author: anbrjohn <an...@indiana.edu>
AuthorDate: Tue Jun 26 20:46:19 2018 +0200
Fix bi-lstm-crf to update crf weights (#11375)
* Fix bi-lstm-crf to update crf weights
* Use self.params.get to declare params
---
example/gluon/lstm_crf.py | 17 +++++++++--------
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/example/gluon/lstm_crf.py b/example/gluon/lstm_crf.py
index 561b4c2..3e95c05 100644
--- a/example/gluon/lstm_crf.py
+++ b/example/gluon/lstm_crf.py
@@ -62,8 +62,9 @@ class BiLSTM_CRF(Block):
# Matrix of transition parameters. Entry i,j is the score of
# transitioning *to* i *from* j.
- self.transitions = nd.random.normal(shape=(self.tagset_size, self.tagset_size))
-
+ self.transitions = self.params.get("crf_transition_matrix",
+ shape=(self.tagset_size, self.tagset_size))
+
self.hidden = self.init_hidden()
def init_hidden(self):
@@ -85,7 +86,7 @@ class BiLSTM_CRF(Block):
emit_score = feat[next_tag].reshape((1, -1))
# the ith entry of trans_score is the score of transitioning to
# next_tag from i
- trans_score = self.transitions[next_tag].reshape((1, -1))
+ trans_score = self.transitions.data()[next_tag].reshape((1, -1))
# The ith entry of next_tag_var is the value for the
# edge (i -> next_tag) before we do log-sum-exp
next_tag_var = alphas + trans_score + emit_score
@@ -93,7 +94,7 @@ class BiLSTM_CRF(Block):
# scores.
alphas_t.append(log_sum_exp(next_tag_var))
alphas = nd.concat(*alphas_t, dim=0).reshape((1, -1))
- terminal_var = alphas + self.transitions[self.tag2idx[STOP_TAG]]
+ terminal_var = alphas + self.transitions.data()[self.tag2idx[STOP_TAG]]
alpha = log_sum_exp(terminal_var)
return alpha
@@ -112,8 +113,8 @@ class BiLSTM_CRF(Block):
tags = nd.concat(nd.array([self.tag2idx[START_TAG]]), *tags, dim=0)
for i, feat in enumerate(feats):
score = score + \
- self.transitions[to_scalar(tags[i+1]), to_scalar(tags[i])] + feat[to_scalar(tags[i+1])]
- score = score + self.transitions[self.tag2idx[STOP_TAG],
+ self.transitions.data()[to_scalar(tags[i+1]), to_scalar(tags[i])] + feat[to_scalar(tags[i+1])]
+ score = score + self.transitions.data()[self.tag2idx[STOP_TAG],
to_scalar(tags[int(tags.shape[0]-1)])]
return score
@@ -134,7 +135,7 @@ class BiLSTM_CRF(Block):
# from tag i to next_tag.
# We don't include the emission scores here because the max
# does not depend on them (we add them in below)
- next_tag_var = vvars + self.transitions[next_tag]
+ next_tag_var = vvars + self.transitions.data()[next_tag]
best_tag_id = argmax(next_tag_var)
bptrs_t.append(best_tag_id)
viterbivars_t.append(next_tag_var[0, best_tag_id])
@@ -144,7 +145,7 @@ class BiLSTM_CRF(Block):
backpointers.append(bptrs_t)
# Transition to STOP_TAG
- terminal_var = vvars + self.transitions[self.tag2idx[STOP_TAG]]
+ terminal_var = vvars + self.transitions.data()[self.tag2idx[STOP_TAG]]
best_tag_id = argmax(terminal_var)
path_score = terminal_var[0, best_tag_id]