You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2018/10/15 13:17:48 UTC
[opennlp-sandbox] branch master updated: Extract vector size from embeddings file
This is an automated email from the ASF dual-hosted git repository.
joern pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp-sandbox.git
The following commit(s) were added to refs/heads/master by this push:
new cb36083 Extract vector size from embeddings file
cb36083 is described below
commit cb36083c7ecf8c1632a3b3b61c8d49523d9cdecc
Author: Jörn Kottmann <jo...@apache.org>
AuthorDate: Wed Oct 10 15:48:29 2018 +0200
Extract vector size from embeddings file
---
.../src/main/python/namefinder/namefinder.py | 58 +++++++++++++---------
1 file changed, 34 insertions(+), 24 deletions(-)
diff --git a/tf-ner-poc/src/main/python/namefinder/namefinder.py b/tf-ner-poc/src/main/python/namefinder/namefinder.py
index 00560ac..9150bd1 100644
--- a/tf-ner-poc/src/main/python/namefinder/namefinder.py
+++ b/tf-ner-poc/src/main/python/namefinder/namefinder.py
@@ -62,28 +62,6 @@ class NameFinder:
def __init__(self, vector_size=100):
self.__vector_size = vector_size
- def load_glove(self, glove_file):
- with open(glove_file) as f:
-
- word_dict = {}
- embeddings = []
-
- for line in f:
- parts = line.strip().split(" ")
- if len(parts) != self.__vector_size + 1:
- #print("Bad Vector: ",len(line),len(parts), line)
- raise VectorException("Bad Vector in line: {}, size: {} vector: {}".format(len(line),len(parts), line))
- continue
- word_dict[parts[0]] = len(word_dict)
- embeddings.append(np.array(parts[1:], dtype=np.float32))
-
- # Create a reverse word dict
- rev_word_dict = {}
- for word, id in word_dict.items():
- rev_word_dict[id] = word
-
- return word_dict, rev_word_dict, np.asarray(embeddings)
-
def load_data(self, word_dict, file):
with open(file) as f:
raw_data = f.readlines()
@@ -344,15 +322,47 @@ def write_mapping(tags, output_filename):
for i, tag in enumerate(tags):
f.write('{}\n'.format(tag))
+def load_glove(glove_file):
+ with open(glove_file) as f:
+
+ word_dict = {}
+ embeddings = []
+
+ vector_size = -1
+
+ for line in f:
+ parts = line.strip().split(" ")
+
+ if vector_size == -1:
+ if len(parts) == 2:
+ vector_size = int(parts[1])
+ continue
+ vector_size = len(parts) - 1
+
+ if len(parts) != vector_size + 1:
+ #print("Bad Vector: ",len(line),len(parts), line)
+ raise VectorException("Bad Vector in line: {}, size: {} vector: {}".format(len(line),len(parts), line))
+ continue
+ word_dict[parts[0]] = len(word_dict)
+ embeddings.append(np.array(parts[1:], dtype=np.float32))
+
+ # Create a reverse word dict
+ rev_word_dict = {}
+ for word, id in word_dict.items():
+ rev_word_dict[id] = word
+
+ return word_dict, rev_word_dict, np.asarray(embeddings), vector_size
+
def main():
if len(sys.argv) != 5:
print("Usage namefinder.py embedding_file train_file dev_file test_file")
return
- name_finder = NameFinder(100)
+ word_dict, rev_word_dict, embeddings, vector_size = load_glove(sys.argv[1])
+
+ name_finder = NameFinder(vector_size)
- word_dict, rev_word_dict, embeddings = name_finder.load_glove(sys.argv[1])
sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, sys.argv[3])