You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by dl...@apache.org on 2016/12/13 17:48:36 UTC
svn commit: r1774067 -
/ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py
Author: dligach
Date: Tue Dec 13 17:48:36 2016
New Revision: 1774067
URL: http://svn.apache.org/viewvc?rev=1774067&view=rev
Log:
wrote classification script for position features
Modified:
ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py
Modified: ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py?rev=1774067&r1=1774066&r2=1774067&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py Tue Dec 13 17:48:36 2016
@@ -19,6 +19,9 @@ def main(args):
maxlen = pickle.load(open(os.path.join(model_dir, "maxlen.p"), "rb"))
word2int = pickle.load(open(os.path.join(model_dir, "word2int.p"), "rb"))
label2int = pickle.load(open(os.path.join(model_dir, "label2int.p"), "rb"))
+ tdist2int = pickle.load(open(os.path.join(model_dir, "tdist2int.p"), "rb"))
+ edist2int = pickle.load(open(os.path.join(model_dir, "edist2int.p"), "rb"))
+
model = model_from_json(open(os.path.join(model_dir, "model_0.json")).read())
model.load_weights(os.path.join(model_dir, "model_0.h5"))
@@ -32,23 +35,53 @@ def main(args):
if not line:
break
- feats=[]
- for unigram in line.rstrip().split():
- if unigram in word2int:
- feats.append(word2int[unigram])
+ text, tdist, edist = line.strip().split('|')
+
+ tokens = []
+ for token in text.rstrip().split():
+ if token in word2int:
+ tokens.append(word2int[token])
+ else:
+ tokens.append(word2int['none'])
+
+ tdists = []
+ for dist in tdist.rstrip().split():
+ if dist in tdist2int:
+ tdists.append(tdist2int[dist])
else:
- # TODO: 'none' is not in vocabulary!
- feats.append(word2int['none'])
-
- if len(feats) > maxlen:
- feats=feats[0:maxlen]
- test_x = pad_sequences([feats], maxlen=maxlen)
+ tdists.append(tdist2int['none'])
+
+ edists = []
+ for dist in edist.rstrip().split():
+ if dist in edist2int:
+ edists.append(edist2int[dist])
+ else:
+ edists.append(edist2int['none'])
+
+ if len(tokens) > maxlen:
+ tokens = tokens[0:maxlen]
+ if len(tdists) > maxlen:
+ tdists = tdists[0:maxlen]
+ if len(edists) > maxlen:
+ edists = edists[0:maxlen]
+
+ test_x1 = pad_sequences([tokens], maxlen=maxlen)
+ test_x2 = pad_sequences([tdists], maxlen=maxlen)
+ test_x3 = pad_sequences([edists], maxlen=maxlen)
test_xs = []
- test_xs.append(test_x)
- test_xs.append(test_x)
- test_xs.append(test_x)
- test_xs.append(test_x)
+ test_xs.append(test_x1)
+ test_xs.append(test_x2)
+ test_xs.append(test_x3)
+ test_xs.append(test_x1)
+ test_xs.append(test_x2)
+ test_xs.append(test_x3)
+ test_xs.append(test_x1)
+ test_xs.append(test_x2)
+ test_xs.append(test_x3)
+ test_xs.append(test_x1)
+ test_xs.append(test_x2)
+ test_xs.append(test_x3)
out = model.predict(test_xs, batch_size=50)[0]