You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2018/05/31 09:29:49 UTC
[opennlp-sandbox] branch master updated: added vector size to
NameFinder + only save model if improved + stop training if not improved
for 5 iterations
This is an automated email from the ASF dual-hosted git repository.
joern pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp-sandbox.git
The following commit(s) were added to refs/heads/master by this push:
new 8dc9495 added vector size to NameFinder + only save model if improved + stop training if not improved for 5 iteration
8dc9495 is described below
commit 8dc9495d78017eb24cc7f6cf120c14d0c4964a3c
Author: Peter Thygesen <th...@apache.org>
AuthorDate: Thu May 31 11:20:42 2018 +0200
added vector size to NameFinder + only save model if improved + stop training if not improved for 5 iteration
---
tf-ner-poc/src/main/python/namefinder.py | 42 +++++++++++++++++++++++++-------
1 file changed, 33 insertions(+), 9 deletions(-)
diff --git a/tf-ner-poc/src/main/python/namefinder.py b/tf-ner-poc/src/main/python/namefinder.py
index cd5a464..5a15064 100644
--- a/tf-ner-poc/src/main/python/namefinder.py
+++ b/tf-ner-poc/src/main/python/namefinder.py
@@ -14,7 +14,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the License.
+# under the License.
#
# This poc is based on source code taken from:
@@ -48,10 +48,19 @@ class NameSample:
self.tokens.append(parts[i])
word_index += 1
+class VectorException(Exception):
+ def __init__(self, value):
+ self.value = value
+
+ def __str__(self):
+ return repr(self.value)
+
class NameFinder:
- def __init__(self):
- self.label_dict = {}
+ label_dict = {}
+
+ def __init__(self, vector_size=100):
+ self.__vector_size = vector_size
def load_glove(self, glove_file):
with open(glove_file) as f:
@@ -61,6 +70,10 @@ class NameFinder:
for line in f:
parts = line.strip().split(" ")
+ if len(parts) != self.__vector_size + 1:
+ #print("Bad Vector: ",len(line),len(parts), line)
+ raise VectorException("Bad Vector in line: {}, size: {} vector: {}".format(len(line),len(parts), line))
+ continue
word_dict[parts[0]] = len(word_dict)
embeddings.append(np.array(parts[1:], dtype=np.float32))
@@ -218,7 +231,7 @@ class NameFinder:
# This is a hack to make it load an embedding matrix larger than 2GB
# Don't hardcode this 300
embedding_placeholder = tf.placeholder(dtype=tf.float32, name="embedding_placeholder",
- shape=(len(embedding_dict), 100))
+ shape=(len(embedding_dict), self.__vector_size))
embedding_matrix = tf.Variable(embedding_placeholder, dtype=tf.float32, trainable=False, name="glove_embeddings")
token_embeddings = tf.nn.embedding_lookup(embedding_matrix, token_ids)
@@ -335,7 +348,7 @@ def main():
print("Usage namefinder.py embedding_file train_file dev_file test_file")
return
- name_finder = NameFinder()
+ name_finder = NameFinder(300)
word_dict, rev_word_dict, embeddings = name_finder.load_glove(sys.argv[1])
sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
@@ -353,6 +366,8 @@ def main():
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
log_device_placement=True))
+ best_f1 = 0.0
+ no_improvement = 0
with sess.as_default():
init = tf.global_variables_initializer()
sess.run(init, feed_dict={embedding_ph: embeddings})
@@ -411,10 +426,19 @@ def main():
print("ACC " + str(acc))
print("F1 " + str(f1) + " P " + str(p) + " R " + str(r))
- saver = tf.train.Saver()
- builder = tf.saved_model.builder.SavedModelBuilder("./savedmodel")
- builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
- builder.save()
+ if (f1 > best_f1):
+ best_f1 = f1
+ no_improvement = 0
+ saver = tf.train.Saver()
+ builder = tf.saved_model.builder.SavedModelBuilder("./savedmodel/{}".format(epoch))
+ builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
+ builder.save()
+ else:
+ no_improvement += 1
+
+ if no_improvement > 5:
+ print("No further improvement. Stopping.")
+ break
if __name__ == "__main__":
main()
--
To stop receiving notification emails like this one, please contact
joern@apache.org.