You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ctakes.apache.org by dl...@apache.org on 2016/09/21 19:21:06 UTC

svn commit: r1761798 - in /ctakes/trunk/ctakes-temporal/scripts/nn: predict.py train_and_package.py

Author: dligach
Date: Wed Sep 21 19:21:06 2016
New Revision: 1761798

URL: http://svn.apache.org/viewvc?rev=1761798&view=rev
Log:
minor cosmetic fixes

Modified:
    ctakes/trunk/ctakes-temporal/scripts/nn/predict.py
    ctakes/trunk/ctakes-temporal/scripts/nn/train_and_package.py

Modified: ctakes/trunk/ctakes-temporal/scripts/nn/predict.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/predict.py?rev=1761798&r1=1761797&r2=1761798&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/predict.py (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/predict.py Wed Sep 21 19:21:06 2016
@@ -34,22 +34,22 @@ def main(args):
 
             feats=[]
             for unigram in line.rstrip().split():
-                if(word2int.has_key(unigram)):
+                if(unigram in word2int):
                     feats.append(word2int[unigram])
                 else:
-                    feats.append(word2int["none"])
+                    feats.append(word2int['oov_word'])
                     
-            if(len(feats) > maxlen):
+            if len(feats) > maxlen:
                 feats=feats[0:maxlen]
             test_x = pad_sequences([feats], maxlen=maxlen)
 
-            X_dup = []
-            X_dup.append(test_x)
-            X_dup.append(test_x)
-            X_dup.append(test_x)
-            X_dup.append(test_x)
+            test_xs = []
+            test_xs.append(test_x)
+            test_xs.append(test_x)
+            test_xs.append(test_x)
+            test_xs.append(test_x)
 
-            out = model.predict(X_dup, batch_size=50)[0]
+            out = model.predict(test_xs, batch_size=50)[0]
 
         except KeyboardInterrupt:
             sys.stderr.write("Caught keyboard interrupt\n")
@@ -60,7 +60,7 @@ def main(args):
             break
 
         out_str = int2label[out.argmax()]
-        print(out_str)
+        print out_str
         sys.stdout.flush()
 
     sys.exit(0)

Modified: ctakes/trunk/ctakes-temporal/scripts/nn/train_and_package.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/train_and_package.py?rev=1761798&r1=1761797&r2=1761798&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/train_and_package.py (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/train_and_package.py Wed Sep 21 19:21:06 2016
@@ -32,8 +32,7 @@ def main(args):
     train_x, train_y = provider.load(data_file)
     # turn x and y into numpy array among other things
     maxlen = max([len(seq) for seq in train_x])
-    outcomes = set(train_y)
-    classes = len(outcomes)
+    classes = len(set(train_y))
 
     train_x = pad_sequences(train_x, maxlen=maxlen)
     train_y = to_categorical(np.array(train_y), classes)
@@ -49,7 +48,6 @@ def main(args):
     train_xs = [] # train x for each branch
 
     for filter_len in '2,3,4,5'.split(','):
-      
         branch = Sequential()
         branch.add(Embedding(len(provider.word2int),
                              300,
@@ -77,8 +75,7 @@ def main(args):
     model.add(Dense(classes))
     model.add(Activation('softmax'))
 
-    optimizer = RMSprop(lr=0.0001,
-                        rho=0.9, epsilon=1e-08)
+    optimizer = RMSprop(lr=0.0001, rho=0.9, epsilon=1e-08)
     model.compile(loss='categorical_crossentropy',
                   optimizer=optimizer,
                   metrics=['accuracy'])
@@ -87,8 +84,7 @@ def main(args):
               nb_epoch=3,
               batch_size=50,
               verbose=1,
-              validation_split=0.1,
-              class_weight=None)
+              validation_split=0.1)
 
     json_string = model.to_json()
     open(os.path.join(working_dir, 'model_0.json'), 'w').write(json_string)