You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2018/05/30 15:58:55 UTC
[opennlp-sandbox] branch master updated: Disable dropout for inference
This is an automated email from the ASF dual-hosted git repository.
joern pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp-sandbox.git
The following commit(s) were added to refs/heads/master by this push:
new 77a39ad Disable dropout for inference
77a39ad is described below
commit 77a39ad27f5e78d0bfd5aaaab5b02805ad06444e
Author: Jörn Kottmann <jo...@apache.org>
AuthorDate: Wed May 30 17:58:35 2018 +0200
Disable dropout for inference
---
.../opennlp/tf/guillaumegenthial/SequenceTagging.java | 3 +--
tf-ner-poc/src/main/python/namefinder.py | 15 ++++++++-------
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
index b30509d..7519e84 100644
--- a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
@@ -62,8 +62,7 @@ public class SequenceTagging implements TokenNameFinder, AutoCloseable {
List<Tensor<?>> run = session.runner()
.feed("chars/char_ids:0", fd.getCharIdsTensor())
- // TODO: missing in the python code ...
- //.feed("dropout:0", fd.getDropoutTensor())
+ .feed("dropout_keep_prop:0", fd.getDropoutTensor())
.feed("words/sequence_lengths:0", fd.getSentenceLengthsTensor())
.feed("words/word_ids:0", fd.getWordIdsTensor())
.feed("chars/word_lengths:0", fd.getWordLengthsTensor())
diff --git a/tf-ner-poc/src/main/python/namefinder.py b/tf-ner-poc/src/main/python/namefinder.py
index e4a015e..3bf8405 100644
--- a/tf-ner-poc/src/main/python/namefinder.py
+++ b/tf-ner-poc/src/main/python/namefinder.py
@@ -171,6 +171,7 @@ class NameFinder:
def create_graph(self, nchars, embedding_dict): # probably not necessary to pass in the embedding_dict, can be passed to init directly
+ dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prop")
with tf.variable_scope("chars"):
# shape = (batch size, max length of sentence, max length of word)
@@ -225,7 +226,7 @@ class NameFinder:
# shape = (batch, sentence, 2 x char_hidden_size + word_vector_size)
word_embeddings = tf.concat([token_embeddings, char_rep], axis=-1)
- word_embeddings = tf.nn.dropout(word_embeddings, 0.5)
+ word_embeddings = tf.nn.dropout(word_embeddings, dropout_keep_prob)
hidden_size = 300
@@ -241,7 +242,7 @@ class NameFinder:
context_rep = tf.concat([output_fw, output_bw], axis=-1)
- context_rep = tf.nn.dropout(context_rep, 0.5)
+ context_rep = tf.nn.dropout(context_rep, dropout_keep_prob)
labels = tf.placeholder(tf.int32, shape=[None, None], name="labels")
@@ -264,13 +265,13 @@ class NameFinder:
train_op = tf.train.AdamOptimizer().minimize(loss)
return embedding_placeholder, token_ids, char_ids, word_lengths_ph, \
- sequence_lengths, labels, train_op
+ sequence_lengths, labels, dropout_keep_prob, train_op
def predict_batch(self, sess, token_ids_ph, char_ids_ph, word_lengths_ph,
- sequence_lengths_ph, sentences, char_ids, word_length, lengths):
+ sequence_lengths_ph, sentences, char_ids, word_length, lengths, dropout_keep_prob):
feed_dict = {token_ids_ph: sentences, char_ids_ph: char_ids, word_lengths_ph: word_length,
- sequence_lengths_ph: lengths}
+ sequence_lengths_ph: lengths, dropout_keep_prob: 1}
viterbi_sequences = []
logits, trans_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)
@@ -369,7 +370,7 @@ def main():
name_finder.mini_batch(rev_word_dict, char_dict, sentences, labels, batch_size, batch_index)
feed_dict = {token_ids_ph: sentences_batch, char_ids_ph: chars_batch, word_lengths_ph: word_length_batch, sequence_lengths_ph: lengths,
- labels_ph: labels_batch}
+ labels_ph: labels_batch, dropout_keep_prob: 0.5}
train_op.run(feed_dict, sess)
@@ -387,7 +388,7 @@ def main():
labels_pred, sequence_lengths = name_finder.predict_batch(
sess, token_ids_ph, char_ids_ph, word_lengths_ph, sequence_lengths_ph,
- sentences_test_batch, chars_batch_test, word_length_batch_test, length_test)
+ sentences_test_batch, chars_batch_test, word_length_batch_test, length_test, dropout_keep_prob)
for lab, lab_pred, length in zip(labels_test_batch, labels_pred,
sequence_lengths):
--
To stop receiving notification emails like this one, please contact joern@apache.org.