Posted to commits@opennlp.apache.org by jo...@apache.org on 2018/05/30 15:58:55 UTC

[opennlp-sandbox] branch master updated: Disable dropout for inference

This is an automated email from the ASF dual-hosted git repository.

joern pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp-sandbox.git


The following commit(s) were added to refs/heads/master by this push:
     new 77a39ad  Disable dropout for inference
77a39ad is described below

commit 77a39ad27f5e78d0bfd5aaaab5b02805ad06444e
Author: Jörn Kottmann <jo...@apache.org>
AuthorDate: Wed May 30 17:58:35 2018 +0200

    Disable dropout for inference
---
 .../opennlp/tf/guillaumegenthial/SequenceTagging.java     |  3 +--
 tf-ner-poc/src/main/python/namefinder.py                  | 15 ++++++++-------
 2 files changed, 9 insertions(+), 9 deletions(-)
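
The change threads the dropout keep probability through a placeholder
instead of hard-coding 0.5, so training can feed 0.5 while inference
feeds 1.0, which turns dropout into a no-op. A minimal TF 1.x sketch of
that pattern follows; the variable names are illustrative, not taken
verbatim from namefinder.py:

    # Minimal sketch of placeholder-controlled dropout (TF 1.x API;
    # names here are illustrative).
    import numpy as np
    import tensorflow as tf

    # Named so it can also be fed from the Java side as
    # "dropout_keep_prop:0" (the commit keeps this spelling).
    keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prop")

    x = tf.placeholder(tf.float32, shape=[None, 300])
    h = tf.nn.dropout(x, keep_prob)  # 0.5 while training, 1.0 at inference

    with tf.Session() as sess:
        data = np.ones((2, 300), dtype=np.float32)
        train_out = sess.run(h, {x: data, keep_prob: 0.5})  # dropout active
        infer_out = sess.run(h, {x: data, keep_prob: 1.0})  # dropout disabled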

diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
index b30509d..7519e84 100644
--- a/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/tf/guillaumegenthial/SequenceTagging.java
@@ -62,8 +62,7 @@ public class SequenceTagging implements TokenNameFinder, AutoCloseable {
 
     List<Tensor<?>> run = session.runner()
             .feed("chars/char_ids:0", fd.getCharIdsTensor())
-            // TODO: missing in the python code ...
-            //.feed("dropout:0", fd.getDropoutTensor())
+            .feed("dropout_keep_prop:0", fd.getDropoutTensor())
             .feed("words/sequence_lengths:0", fd.getSentenceLengthsTensor())
             .feed("words/word_ids:0", fd.getWordIdsTensor())
             .feed("chars/word_lengths:0", fd.getWordLengthsTensor())
diff --git a/tf-ner-poc/src/main/python/namefinder.py b/tf-ner-poc/src/main/python/namefinder.py
index e4a015e..3bf8405 100644
--- a/tf-ner-poc/src/main/python/namefinder.py
+++ b/tf-ner-poc/src/main/python/namefinder.py
@@ -171,6 +171,7 @@ class NameFinder:
 
     def create_graph(self, nchars, embedding_dict): # probably not necessary to pass in the embedding_dict, can be passed to init directly
 
+        dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prop")
 
         with tf.variable_scope("chars"):
             # shape = (batch size, max length of sentence, max length of word)
@@ -225,7 +226,7 @@ class NameFinder:
             # shape = (batch, sentence, 2 x char_hidden_size + word_vector_size)
             word_embeddings = tf.concat([token_embeddings, char_rep], axis=-1)
 
-            word_embeddings = tf.nn.dropout(word_embeddings, 0.5)
+            word_embeddings = tf.nn.dropout(word_embeddings, dropout_keep_prob)
 
         hidden_size = 300
 
@@ -241,7 +242,7 @@ class NameFinder:
 
             context_rep = tf.concat([output_fw, output_bw], axis=-1)
 
-            context_rep = tf.nn.dropout(context_rep, 0.5)
+            context_rep = tf.nn.dropout(context_rep, dropout_keep_prob)
 
             labels = tf.placeholder(tf.int32, shape=[None, None], name="labels")
 
@@ -264,13 +265,13 @@ class NameFinder:
         train_op = tf.train.AdamOptimizer().minimize(loss)
 
         return embedding_placeholder, token_ids, char_ids, word_lengths_ph, \
-               sequence_lengths, labels, train_op
+               sequence_lengths, labels, dropout_keep_prob, train_op
 
     def predict_batch(self, sess, token_ids_ph, char_ids_ph, word_lengths_ph,
-                      sequence_lengths_ph, sentences, char_ids, word_length, lengths):
+                      sequence_lengths_ph, sentences, char_ids, word_length, lengths, dropout_keep_prob):
 
         feed_dict = {token_ids_ph: sentences, char_ids_ph: char_ids, word_lengths_ph: word_length,
-                     sequence_lengths_ph: lengths}
+                     sequence_lengths_ph: lengths, dropout_keep_prob: 1}
 
         viterbi_sequences = []
         logits, trans_params = sess.run([self.logits, self.transition_params], feed_dict=feed_dict)
@@ -369,7 +370,7 @@ def main():
                     name_finder.mini_batch(rev_word_dict, char_dict, sentences, labels, batch_size, batch_index)
 
                 feed_dict = {token_ids_ph:  sentences_batch, char_ids_ph: chars_batch, word_lengths_ph: word_length_batch, sequence_lengths_ph: lengths,
-                             labels_ph: labels_batch}
+                             labels_ph: labels_batch, dropout_keep_prob: 0.5}
 
                 train_op.run(feed_dict, sess)
 
@@ -387,7 +388,7 @@ def main():
 
                 labels_pred, sequence_lengths = name_finder.predict_batch(
                     sess, token_ids_ph, char_ids_ph, word_lengths_ph, sequence_lengths_ph,
-                    sentences_test_batch, chars_batch_test, word_length_batch_test, length_test)
+                    sentences_test_batch, chars_batch_test, word_length_batch_test, length_test, dropout_keep_prob)
 
                 for lab, lab_pred, length in zip(labels_test_batch, labels_pred,
                                                  sequence_lengths):
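
Feeding a keep probability of 1 in predict_batch makes both dropout
layers the identity, which is what disabling dropout for inference
amounts to. A self-contained TF 1.x check of that property (names
illustrative):

    # With keep_prob=1.0, tf.nn.dropout keeps every unit and scales by
    # 1/1.0, so the output equals the input exactly.
    import numpy as np
    import tensorflow as tf

    keep_prob = tf.placeholder(tf.float32)
    x = tf.placeholder(tf.float32, shape=[None, 4])
    y = tf.nn.dropout(x, keep_prob)

    with tf.Session() as sess:
        data = np.random.rand(3, 4).astype(np.float32)
        out = sess.run(y, {x: data, keep_prob: 1.0})
        assert np.array_equal(out, data)  # dropout is a no-op at inference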
