You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not included in this plain-text copy).
Posted to commits@ctakes.apache.org by dl...@apache.org on 2016/09/24 16:09:18 UTC
svn commit: r1762139 - in /ctakes/trunk/ctakes-temporal/scripts/nn:
classify.sh lstm_classify.py lstm_train.py train.sh
Author: dligach
Date: Sat Sep 24 16:09:18 2016
New Revision: 1762139
URL: http://svn.apache.org/viewvc?rev=1762139&view=rev
Log:
now trying lstm on pos tags
Added:
ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py
ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py
Modified:
ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh
ctakes/trunk/ctakes-temporal/scripts/nn/train.sh
Modified: ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh?rev=1762139&r1=1762138&r2=1762139&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh Sat Sep 24 16:09:18 2016
@@ -1,7 +1,7 @@
#!/bin/bash
source $(dirname $0)/env/bin/activate  # activate the Python environment in env/ beside this script (presumably a virtualenv)
-python $(dirname $0)/cnn_classify.py $*
+python $(dirname $0)/lstm_classify.py $*  # NOTE(review): unquoted $* re-splits arguments containing spaces; "$@" (and "$(dirname "$0")") would preserve them
ret=$?  # capture the classifier's status before deactivate overwrites $?
deactivate
exit $ret  # propagate the classifier's exit status to the caller
Added: ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py?rev=1762139&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py (added)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py Sat Sep 24 16:09:18 2016
@@ -0,0 +1,63 @@
+#!python
+# NOTE(review): '#!python' is not an absolute interpreter path; lstm_train.py uses the portable '#!/usr/bin/env python'.
+from keras.models import Sequential, model_from_json
+import numpy as np
+import et_cleartk_io as ctk_io
+import sys
+import os.path
+import pickle
+from keras.preprocessing.sequence import pad_sequences
+# Reads one instance per stdin line (space-separated tokens) and prints one predicted label per line.
+def main(args):
+ if len(args) < 1:
+ sys.stderr.write("Error - one required argument: <model directory>\n")
+ sys.exit(-1)
+ working_dir = args[0]  # NOTE(review): never used; the model is loaded from $CTAKES_ROOT/<target_dir> below instead
+ # Load the artifacts stored under the same file names lstm_train.py writes.
+ target_dir = 'ctakes-temporal/target/eval/thyme/train_and_test/event-time/'  # NOTE(review): hard-coded to the event-time model
+ model_dir = os.path.join(os.environ['CTAKES_ROOT'], target_dir)  # KeyError if CTAKES_ROOT is unset
+ maxlen = pickle.load(open(os.path.join(model_dir, "maxlen.p"), "rb"))  # NOTE(review): open() handles here are never explicitly closed
+ word2int = pickle.load(open(os.path.join(model_dir, "word2int.p"), "rb"))
+ label2int = pickle.load(open(os.path.join(model_dir, "label2int.p"), "rb"))
+ model = model_from_json(open(os.path.join(model_dir, "model_0.json")).read())
+ model.load_weights(os.path.join(model_dir, "model_0.h5"))
+ # Invert label2int so a predicted class index maps back to its label string.
+ int2label = {}
+ for label, integer in label2int.items():
+ int2label[integer] = label
+ # Classify stdin line by line; EOF or a blank line ends the loop.
+ while True:
+ try:
+ line = sys.stdin.readline().rstrip()
+ if not line:
+ break
+ # Encode each token as its vocabulary index, falling back to the 'none' entry for unknown tokens.
+ feats = []
+ for unigram in line.rstrip().split():
+ if unigram in word2int:
+ feats.append(word2int[unigram])
+ else:
+ # TODO: 'none' is not in vocabulary! If it is missing from word2int, the lookup below raises KeyError for every unknown token.
+ feats.append(word2int['none'])
+ # Keep the first maxlen tokens, then pad to the fixed input length used in training.
+ if len(feats) > maxlen:
+ feats=feats[0:maxlen]
+ test_x = pad_sequences([feats], maxlen=maxlen)
+ out = model.predict(test_x, batch_size=50)[0]  # one-instance batch; [0] selects its row of class scores
+
+ except KeyboardInterrupt:
+ sys.stderr.write("Caught keyboard interrupt\n")
+ break
+ # NOTE(review): unreachable; an empty line already hit the break inside the try block.
+ if line == '':
+ sys.stderr.write("Encountered empty string so exiting\n")
+ break
+ # Print the highest-scoring label and flush so it is not held in the stdout buffer.
+ out_str = int2label[out.argmax()]
+ print out_str  # Python 2 print statement
+ sys.stdout.flush()
+
+ sys.exit(0)
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
Added: ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py?rev=1762139&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py (added)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py Sat Sep 24 16:09:18 2016
@@ -0,0 +1,75 @@
+#!/usr/bin/env python
+# Trains an LSTM classifier on <data directory>/training-data.liblinear and saves the model and preprocessing state to that directory.
+import sklearn as sk
+import numpy as np
+np.random.seed(1337)  # seeded before the keras imports below, presumably for reproducible training
+import et_cleartk_io as ctk_io
+import nn_models
+import sys
+import os.path
+import dataset
+import keras as k
+from keras.utils.np_utils import to_categorical
+from keras.optimizers import RMSprop
+from keras.preprocessing.sequence import pad_sequences
+from keras.models import Sequential
+from keras.layers.core import Dense, Dropout, Activation
+from keras.layers.embeddings import Embedding
+from keras.layers import LSTM
+import pickle
+# NOTE(review): sk, ctk_io, nn_models, k and Dropout are imported but never used in this script.
+def main(args):
+ if len(args) < 1:
+ sys.stderr.write("Error - one required argument: <data directory>\n")
+ sys.exit(-1)
+ working_dir = args[0]
+ data_file = os.path.join(working_dir, 'training-data.liblinear')
+
+ # learn alphabet from training data
+ provider = dataset.DatasetProvider(data_file)
+ # now load training examples and labels
+ train_x, train_y = provider.load(data_file)
+ # turn x and y into numpy arrays: pad x to the longest training sequence, one-hot encode y
+ maxlen = max([len(seq) for seq in train_x])  # ValueError if the data file yields no examples
+ classes = len(set(train_y))  # NOTE(review): assumes labels are 0..classes-1; otherwise to_categorical needs a larger count - confirm in dataset.py
+
+ train_x = pad_sequences(train_x, maxlen=maxlen)
+ train_y = to_categorical(np.array(train_y), classes)
+ # Persist the preprocessing state; lstm_classify.py loads these same file names.
+ pickle.dump(maxlen, open(os.path.join(working_dir, 'maxlen.p'),"wb"))  # NOTE(review): these file handles are never explicitly closed
+ pickle.dump(provider.word2int, open(os.path.join(working_dir, 'word2int.p'),"wb"))
+ pickle.dump(provider.label2int, open(os.path.join(working_dir, 'label2int.p'),"wb"))
+
+ print 'train_x shape:', train_x.shape
+ print 'train_y shape:', train_y.shape
+ # Model: word embedding -> LSTM -> dense softmax over the label set.
+ model = Sequential()
+
+ model.add(Embedding(len(provider.word2int),  # NOTE(review): if dataset.py reserves index 0 for padding, input_dim may need to be len(...) + 1 - confirm
+ 300,
+ input_length=maxlen,
+ dropout=0.25))
+ model.add(LSTM(128,
+ dropout_W = 0.20,  # dropout_W / dropout_U: Keras 1.x names for input and recurrent dropout
+ dropout_U = 0.20))
+ model.add(Dense(classes))
+ model.add(Activation('softmax'))
+
+ optimizer = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08)
+ model.compile(loss='categorical_crossentropy',
+ optimizer=optimizer,
+ metrics=['accuracy'])
+ model.fit(train_x,
+ train_y,
+ nb_epoch=1,  # a single pass over the training data
+ batch_size=50,
+ verbose=1,
+ validation_split=0.1)  # hold out 10% of the examples for validation metrics
+ # Save architecture (JSON) and weights (HDF5) under the names lstm_classify.py loads.
+ json_string = model.to_json()
+ open(os.path.join(working_dir, 'model_0.json'), 'w').write(json_string)
+ model.save_weights(os.path.join(working_dir, 'model_0.h5'), overwrite=True)
+ sys.exit(0)
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
Modified: ctakes/trunk/ctakes-temporal/scripts/nn/train.sh
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/train.sh?rev=1762139&r1=1762138&r2=1762139&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/train.sh (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/train.sh Sat Sep 24 16:09:18 2016
@@ -1,7 +1,7 @@
#!/bin/bash
source $(dirname $0)/env/bin/activate  # activate the Python environment in env/ beside this script (presumably a virtualenv)
-python $(dirname $0)/cnn_train.py $*
+python $(dirname $0)/lstm_train.py $*  # NOTE(review): unquoted $* re-splits arguments containing spaces; "$@" (and "$(dirname "$0")") would preserve them
ret=$?  # capture the trainer's status before deactivate overwrites $?
deactivate
exit $ret  # propagate the trainer's exit status to the caller