You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2017/12/20 05:45:21 UTC
[incubator-mxnet] branch master updated: Usability improvement bi
lstm sort (#8944)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c71a2f3 Usability improvement bi lstm sort (#8944)
c71a2f3 is described below
commit c71a2f3bc9a815b8ef9b6caa615013c74deb10ad
Author: Anirudh Subramanian <an...@gmail.com>
AuthorDate: Tue Dec 19 21:45:17 2017 -0800
Usability improvement bi lstm sort (#8944)
* Improve usability for the bilstm example
* Remove argparse from infer_sort since it changes existing usage
---
example/bi-lstm-sort/README.md | 48 +++++++++++-------------
example/bi-lstm-sort/infer_sort.py | 25 ++++++++++---
example/bi-lstm-sort/lstm.py | 1 -
example/bi-lstm-sort/lstm_sort.py | 75 +++++++++++++++++++++++++++++++++-----
example/bi-lstm-sort/rnn_model.py | 1 -
example/bi-lstm-sort/sort_io.py | 1 -
6 files changed, 106 insertions(+), 45 deletions(-)
diff --git a/example/bi-lstm-sort/README.md b/example/bi-lstm-sort/README.md
index a590a18..3bacc86 100644
--- a/example/bi-lstm-sort/README.md
+++ b/example/bi-lstm-sort/README.md
@@ -1,28 +1,24 @@
This is an example of using a bidirectional LSTM to sort an array.
-Firstly, generate data by:
-
- python gen_data.py
-
-Move generated txt files to data directory
-
- mkdir data
- mv *.txt data
-
-Then, train the model by:
-
- python lstm_sort.py
-
-At last, test model by:
-
- python infer_sort.py 234 189 785 763 231
-
-and will output sorted seq
-
- 189
- 231
- 234
- 763
- 785
-
-
+Train the model by running the training script, which first generates the data files under `data/` in the current directory:
+
+```
+python lstm_sort.py --start-range 100 --end-range 1000 --cpu
+```
+Use `--start-range` and `--end-range` to choose the range of numbers to sort (the end of the range is exclusive), and `--cpu` to train on the CPU.
+By default the script trains on the GPU. The default start-range is 100 and the default end-range is 1000.
+
+Finally, test the model by passing five numbers from the range used for training:
+
+```
+python infer_sort.py 234 189 785 763 231
+```
+
+This should output the sorted sequence:
+```
+189
+231
+234
+763
+785
+```
diff --git a/example/bi-lstm-sort/infer_sort.py b/example/bi-lstm-sort/infer_sort.py
index b074c03..f81c6c0 100644
--- a/example/bi-lstm-sort/infer_sort.py
+++ b/example/bi-lstm-sort/infer_sort.py
@@ -18,20 +18,29 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
-sys.path.insert(0, "../../python")
+import os
+import argparse
import numpy as np
import mxnet as mx
from sort_io import BucketSentenceIter, default_build_vocab
from rnn_model import BiLSTMInferenceModel
+# Data file names; these mirror the constants in lstm_sort.py, whose gen_data()
+# writes the files under DATA_DIR.
+# NOTE(review): TEST_FILE and VALID_FILE are not referenced in the visible hunks
+# of this script; only TRAIN_FILE is used (to rebuild the vocabulary).
+TRAIN_FILE = "sort.train.txt"
+TEST_FILE = "sort.test.txt"
+VALID_FILE = "sort.valid.txt"
+# Resolved against the current working directory (not the script's directory),
+# so this must be run from the directory where lstm_sort.py created data/.
+DATA_DIR = os.path.join(os.getcwd(), "data")
+# Sequence length the inference model is built for (passed to BiLSTMInferenceModel).
+SEQ_LEN = 5
+
def MakeInput(char, vocab, arr):
    """Write the vocabulary id of ``char`` into ``arr`` in place.

    ``vocab`` is looked up with ``vocab[char]`` (presumably a token -> id
    dict; a missing token propagates the lookup error). The id is wrapped in a
    one-element float64 array and assigned with ``arr[:] = ...``, so it
    broadcasts across every element of ``arr``. Returns None.
    """
    token_id = vocab[char]
    arr[:] = np.array([token_id], dtype=np.float64)
-if __name__ == '__main__':
+def main():
+ tks = sys.argv[1:]
+ assert len(tks) >= 5, "Please provide 5 numbers for sorting as sequence length is 5"
batch_size = 1
buckets = []
num_hidden = 300
@@ -42,20 +51,21 @@ if __name__ == '__main__':
learning_rate = 0.1
momentum = 0.9
- contexts = [mx.context.gpu(i) for i in range(1)]
+ contexts = [mx.context.cpu(i) for i in range(1)]
- vocab = default_build_vocab("./data/sort.train.txt")
+ vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))
rvocab = {}
for k, v in vocab.items():
rvocab[v] = k
_, arg_params, __ = mx.model.load_checkpoint("sort", 1)
+ for tk in tks:
+ assert (tk in vocab), "{} not in range of numbers that the model trained for.".format(tk)
- model = BiLSTMInferenceModel(5, len(vocab),
+ model = BiLSTMInferenceModel(SEQ_LEN, len(vocab),
num_hidden=num_hidden, num_embed=num_embed,
num_label=len(vocab), arg_params=arg_params, ctx=contexts, dropout=0.0)
- tks = sys.argv[1:]
data = np.zeros((1, len(tks)))
for k in range(len(tks)):
data[0][k] = vocab[tks[k]]
@@ -65,3 +75,6 @@ if __name__ == '__main__':
for k in range(len(tks)):
print(rvocab[np.argmax(prob, axis = 1)[k]])
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/example/bi-lstm-sort/lstm.py b/example/bi-lstm-sort/lstm.py
index a082092..362481d 100644
--- a/example/bi-lstm-sort/lstm.py
+++ b/example/bi-lstm-sort/lstm.py
@@ -17,7 +17,6 @@
# pylint:skip-file
import sys
-sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
diff --git a/example/bi-lstm-sort/lstm_sort.py b/example/bi-lstm-sort/lstm_sort.py
index aef88b8..3fd4a2a 100644
--- a/example/bi-lstm-sort/lstm_sort.py
+++ b/example/bi-lstm-sort/lstm_sort.py
@@ -17,14 +17,65 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
+import os
import sys
-sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
+import random
+import argparse
from lstm import bi_lstm_unroll
from sort_io import BucketSentenceIter, default_build_vocab
+import logging
+# Logging is configured at import time; level DEBUG means the logging.info
+# calls in gen_data() are emitted.
+head = '%(asctime)-15s %(message)s'
+logging.basicConfig(level=logging.DEBUG, format=head)
+
+
+# Names of the generated train/test/validation splits, written by gen_data().
+TRAIN_FILE = "sort.train.txt"
+TEST_FILE = "sort.test.txt"
+VALID_FILE = "sort.valid.txt"
+# NOTE(review): resolved against the current working directory, not the
+# script's own directory.
+DATA_DIR = os.path.join(os.getcwd(), "data")
+# Number of integers per generated sequence (main() passes it to gen_data).
+SEQ_LEN = 5
+
def gen_data(seq_len, start_range, end_range, num_samples=1000000):
    """Generate random integer sequences and write the train/test/valid splits.

    Each line holds ``seq_len`` integers drawn uniformly (with replacement)
    from ``range(start_range, end_range)``, joined by single spaces. Line
    ``i`` goes to the test file when ``i % 50 == 0``, to the validation file
    when ``i % 50 == 1`` and to the training file otherwise. Files are written
    under ``DATA_DIR`` as ``TRAIN_FILE``, ``TEST_FILE`` and ``VALID_FILE``;
    existing files are overwritten.

    Args:
        seq_len: number of integers per line.
        start_range: smallest integer that may appear (inclusive).
        end_range: upper bound of the integers (exclusive).
        num_samples: total number of lines written across the three files.

    Raises:
        ValueError: if ``end_range <= start_range`` (nothing to sample from).
        OSError: if ``DATA_DIR`` is missing and cannot be created.
    """
    # errno is only needed for the EEXIST check below; the module-level
    # imports do not provide it.
    import errno

    vocab = [str(x) for x in range(start_range, end_range)]
    # Validate before touching the filesystem; an empty vocabulary would
    # otherwise fail deep inside the random module with an unclear message.
    if not vocab:
        raise ValueError('end_range (%s) must be greater than start_range (%s)'
                         % (end_range, start_range))

    if not os.path.exists(DATA_DIR):
        try:
            logging.info('create directory %s', DATA_DIR)
            os.makedirs(DATA_DIR)
        except OSError as exc:
            # Another process may create the directory between exists() and
            # makedirs(); that race is harmless, anything else is fatal.
            if exc.errno != errno.EEXIST:
                raise OSError('failed to create ' + DATA_DIR)

    # A single "with" guarantees all three files (including the validation
    # split) are flushed and closed before main() reads them back.
    with open(os.path.join(DATA_DIR, TRAIN_FILE), "w") as sw_train, \
            open(os.path.join(DATA_DIR, TEST_FILE), "w") as sw_test, \
            open(os.path.join(DATA_DIR, VALID_FILE), "w") as sw_valid:
        for i in range(num_samples):
            seq = " ".join(random.choice(vocab) for _ in range(seq_len))
            k = i % 50
            if k == 0:
                sw_test.write(seq + "\n")
            elif k == 1:
                sw_valid.write(seq + "\n")
            else:
                sw_train.write(seq + "\n")
+
def parse_args():
    """Parse the command-line options of the lstm_sort training script.

    Returns:
        argparse.Namespace with ``start_range`` (int, default 100),
        ``end_range`` (int, default 1000, exclusive upper bound) and
        ``cpu`` (bool, default False).

    Exits through ``parser.error`` (status 2, usage on stderr) when
    ``--end-range`` is not greater than ``--start-range``, because the data
    generator would then have no numbers to sample from.
    """
    parser = argparse.ArgumentParser(description="Parse args for lstm_sort example",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--start-range', type=int, default=100,
                        help='starting number of the range')
    parser.add_argument('--end-range', type=int, default=1000,
                        help='Ending number of the range (exclusive)')
    parser.add_argument('--cpu', action='store_true',
                        help='To use CPU for training')
    args = parser.parse_args()
    if args.end_range <= args.start_range:
        parser.error('--end-range (%d) must be greater than --start-range (%d)'
                     % (args.end_range, args.start_range))
    return args
+
+
def Perplexity(label, pred):
label = label.T.reshape((-1,))
loss = 0.
@@ -32,7 +83,9 @@ def Perplexity(label, pred):
loss += -np.log(max(1e-10, pred[i][int(label[i])]))
return np.exp(loss / label.size)
-if __name__ == '__main__':
+def main():
+ args = parse_args()
+ gen_data(SEQ_LEN, args.start_range, args.end_range)
batch_size = 100
buckets = []
num_hidden = 300
@@ -43,9 +96,12 @@ if __name__ == '__main__':
learning_rate = 0.1
momentum = 0.9
- contexts = [mx.context.gpu(i) for i in range(1)]
+ if args.cpu:
+ contexts = [mx.context.cpu(i) for i in range(1)]
+ else:
+ contexts = [mx.context.gpu(i) for i in range(1)]
- vocab = default_build_vocab("./data/sort.train.txt")
+ vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))
def sym_gen(seq_len):
return bi_lstm_unroll(seq_len, len(vocab),
@@ -56,9 +112,9 @@ if __name__ == '__main__':
init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_states = init_c + init_h
- data_train = BucketSentenceIter("./data/sort.train.txt", vocab,
+ data_train = BucketSentenceIter(os.path.join(DATA_DIR, TRAIN_FILE), vocab,
buckets, batch_size, init_states)
- data_val = BucketSentenceIter("./data/sort.valid.txt", vocab,
+ data_val = BucketSentenceIter(os.path.join(DATA_DIR, VALID_FILE), vocab,
buckets, batch_size, init_states)
if len(buckets) == 1:
@@ -74,12 +130,11 @@ if __name__ == '__main__':
wd=0.00001,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
- import logging
- head = '%(asctime)-15s %(message)s'
- logging.basicConfig(level=logging.DEBUG, format=head)
-
model.fit(X=data_train, eval_data=data_val,
eval_metric = mx.metric.np(Perplexity),
batch_end_callback=mx.callback.Speedometer(batch_size, 50),)
model.save("sort")
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/example/bi-lstm-sort/rnn_model.py b/example/bi-lstm-sort/rnn_model.py
index 202aae6..1079e90 100644
--- a/example/bi-lstm-sort/rnn_model.py
+++ b/example/bi-lstm-sort/rnn_model.py
@@ -18,7 +18,6 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
-sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
diff --git a/example/bi-lstm-sort/sort_io.py b/example/bi-lstm-sort/sort_io.py
index 8cb44c6..853d0ee 100644
--- a/example/bi-lstm-sort/sort_io.py
+++ b/example/bi-lstm-sort/sort_io.py
@@ -19,7 +19,6 @@
# pylint: disable=superfluous-parens, no-member, invalid-name
from __future__ import print_function
import sys
-sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].