You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by dl...@apache.org on 2016/12/13 17:48:36 UTC

svn commit: r1774067 - /ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py

Author: dligach
Date: Tue Dec 13 17:48:36 2016
New Revision: 1774067

URL: http://svn.apache.org/viewvc?rev=1774067&view=rev
Log:
Wrote classification script for position features.

Modified:
    ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py

Modified: ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py?rev=1774067&r1=1774066&r2=1774067&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/cnn_classify_position.py Tue Dec 13 17:48:36 2016
@@ -19,6 +19,9 @@ def main(args):
     maxlen   = pickle.load(open(os.path.join(model_dir, "maxlen.p"), "rb"))
     word2int = pickle.load(open(os.path.join(model_dir, "word2int.p"), "rb"))
     label2int = pickle.load(open(os.path.join(model_dir, "label2int.p"), "rb"))
+    tdist2int = pickle.load(open(os.path.join(model_dir, "tdist2int.p"), "rb"))
+    edist2int = pickle.load(open(os.path.join(model_dir, "edist2int.p"), "rb"))
+
     model = model_from_json(open(os.path.join(model_dir, "model_0.json")).read())
     model.load_weights(os.path.join(model_dir, "model_0.h5"))
 
@@ -32,23 +35,53 @@ def main(args):
             if not line:
                 break
 
-            feats=[]
-            for unigram in line.rstrip().split():
-                if unigram in word2int:
-                    feats.append(word2int[unigram])
+            text, tdist, edist = line.strip().split('|')
+
+            tokens = []
+            for token in text.rstrip().split():
+                if token in word2int:
+                    tokens.append(word2int[token])
+                else:
+                    tokens.append(word2int['none'])
+
+            tdists = []
+            for dist in tdist.rstrip().split():
+                if dist in tdist2int:
+                    tdists.append(tdist2int[dist])
                 else:
-                    # TODO: 'none' is not in vocabulary!
-                    feats.append(word2int['none'])
-                    
-            if len(feats) > maxlen:
-                feats=feats[0:maxlen]
-            test_x = pad_sequences([feats], maxlen=maxlen)
+                    tdists.append(tdist2int['none'])
+
+            edists = []
+            for dist in edist.rstrip().split():
+                if dist in edist2int:
+                    edists.append(edist2int[dist])
+                else:
+                    edists.append(edist2int['none'])
+
+            if len(tokens) > maxlen:
+                tokens = tokens[0:maxlen]
+            if len(tdists) > maxlen:
+                tdists = tdist[0:maxlen]
+            if len(edists) > maxlen:
+                edists = edist[0:maxlen]
+
+            test_x1 = pad_sequences([tokens], maxlen=maxlen)
+            test_x2 = pad_sequences([tdists], maxlen=maxlen)
+            test_x3 = pad_sequences([edists], maxlen=maxlen)
 
             test_xs = []
-            test_xs.append(test_x)
-            test_xs.append(test_x)
-            test_xs.append(test_x)
-            test_xs.append(test_x)
+            test_xs.append(test_x1)
+            test_xs.append(test_x2)
+            test_xs.append(test_x3)
+            test_xs.append(test_x1)
+            test_xs.append(test_x2)
+            test_xs.append(test_x3)
+            test_xs.append(test_x1)
+            test_xs.append(test_x2)
+            test_xs.append(test_x3)
+            test_xs.append(test_x1)
+            test_xs.append(test_x2)
+            test_xs.append(test_x3)
 
             out = model.predict(test_xs, batch_size=50)[0]