Posted to commits@mxnet.apache.org by ns...@apache.org on 2017/12/20 05:45:21 UTC

[incubator-mxnet] branch master updated: Usability improvement bi lstm sort (#8944)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c71a2f3  Usability improvement bi lstm sort (#8944)
c71a2f3 is described below

commit c71a2f3bc9a815b8ef9b6caa615013c74deb10ad
Author: Anirudh Subramanian <an...@gmail.com>
AuthorDate: Tue Dec 19 21:45:17 2017 -0800

    Usability improvement bi lstm sort (#8944)
    
    * Improve usability for the bilstm example
    
    * Remove argparse from infer_sort since it changes existing usage
---
 example/bi-lstm-sort/README.md     | 48 +++++++++++-------------
 example/bi-lstm-sort/infer_sort.py | 25 ++++++++++---
 example/bi-lstm-sort/lstm.py       |  1 -
 example/bi-lstm-sort/lstm_sort.py  | 75 +++++++++++++++++++++++++++++++++-----
 example/bi-lstm-sort/rnn_model.py  |  1 -
 example/bi-lstm-sort/sort_io.py    |  1 -
 6 files changed, 106 insertions(+), 45 deletions(-)

diff --git a/example/bi-lstm-sort/README.md b/example/bi-lstm-sort/README.md
index a590a18..3bacc86 100644
--- a/example/bi-lstm-sort/README.md
+++ b/example/bi-lstm-sort/README.md
@@ -1,28 +1,24 @@
 This is an example of using a bidirectional LSTM to sort an array.
 
-Firstly, generate data by:
-
-	python gen_data.py
-
-Move generated txt files to data directory
-
-	mkdir data
-	mv *.txt data
-
-Then, train the model by:
-
-    python lstm_sort.py
-
-At last, test model by:
-
-    python infer_sort.py 234 189 785 763 231
-
-and will output sorted seq
-
-    189
-	231
-	234
-	763
-	785
-
-
+Run the training script as follows:
+
+```
+python lstm_sort.py --start-range 100 --end-range 1000 --cpu
+```
+You can set the start and end of the number range and choose whether to train on the CPU.
+By default the script trains on the GPU, with a start-range of 100 and an end-range of 1000.
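+
+For example, a run on the GPU with the default range needs no flags at all:
+
+```
+python lstm_sort.py
+```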
+
+Finally, test the model by running:
+
+```
+python infer_sort.py 234 189 785 763 231
+```
+
+This should output the sorted sequence, for example:
+```
+189
+231
+234
+763
+785
+```
diff --git a/example/bi-lstm-sort/infer_sort.py b/example/bi-lstm-sort/infer_sort.py
index b074c03..f81c6c0 100644
--- a/example/bi-lstm-sort/infer_sort.py
+++ b/example/bi-lstm-sort/infer_sort.py
@@ -18,20 +18,29 @@
 # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
 # pylint: disable=superfluous-parens, no-member, invalid-name
 import sys
-sys.path.insert(0, "../../python")
+import os
+import argparse
 import numpy as np
 import mxnet as mx
 
 from sort_io import BucketSentenceIter, default_build_vocab
 from rnn_model import BiLSTMInferenceModel
 
+TRAIN_FILE = "sort.train.txt"
+TEST_FILE = "sort.test.txt"
+VALID_FILE = "sort.valid.txt"
+DATA_DIR = os.path.join(os.getcwd(), "data")
+SEQ_LEN = 5
+
 def MakeInput(char, vocab, arr):
     idx = vocab[char]
     tmp = np.zeros((1,))
     tmp[0] = idx
     arr[:] = tmp
 
-if __name__ == '__main__':
+def main():
+    tks = sys.argv[1:]
+    assert len(tks) >= 5, "Please provide at least 5 numbers to sort, since the sequence length is 5"
     batch_size = 1
     buckets = []
     num_hidden = 300
@@ -42,20 +51,21 @@ if __name__ == '__main__':
     learning_rate = 0.1
     momentum = 0.9
 
-    contexts = [mx.context.gpu(i) for i in range(1)]
+    contexts = [mx.context.cpu(i) for i in range(1)]
 
-    vocab = default_build_vocab("./data/sort.train.txt")
+    vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))
     rvocab = {}
     for k, v in vocab.items():
         rvocab[v] = k
 
     _, arg_params, __ = mx.model.load_checkpoint("sort", 1)
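+    # Every input number must be in the vocabulary built from the training data.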
+    for tk in tks:
+        assert (tk in vocab), "{} is not in the range of numbers the model was trained on.".format(tk)
 
-    model = BiLSTMInferenceModel(5, len(vocab),
+    model = BiLSTMInferenceModel(SEQ_LEN, len(vocab),
                                  num_hidden=num_hidden, num_embed=num_embed,
                                  num_label=len(vocab), arg_params=arg_params, ctx=contexts, dropout=0.0)
 
-    tks = sys.argv[1:]
     data = np.zeros((1, len(tks)))
     for k in range(len(tks)):
         data[0][k] = vocab[tks[k]]
@@ -65,3 +75,6 @@ if __name__ == '__main__':
     for k in range(len(tks)):
         print(rvocab[np.argmax(prob, axis = 1)[k]])
 
+
+if __name__ == '__main__':
+    sys.exit(main())
diff --git a/example/bi-lstm-sort/lstm.py b/example/bi-lstm-sort/lstm.py
index a082092..362481d 100644
--- a/example/bi-lstm-sort/lstm.py
+++ b/example/bi-lstm-sort/lstm.py
@@ -17,7 +17,6 @@
 
 # pylint:skip-file
 import sys
-sys.path.insert(0, "../../python")
 import mxnet as mx
 import numpy as np
 from collections import namedtuple
diff --git a/example/bi-lstm-sort/lstm_sort.py b/example/bi-lstm-sort/lstm_sort.py
index aef88b8..3fd4a2a 100644
--- a/example/bi-lstm-sort/lstm_sort.py
+++ b/example/bi-lstm-sort/lstm_sort.py
@@ -17,14 +17,65 @@
 
 # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
 # pylint: disable=superfluous-parens, no-member, invalid-name
+import os
 import sys
-sys.path.insert(0, "../../python")
 import numpy as np
 import mxnet as mx
+import random
+import argparse
+import errno
 
 from lstm import bi_lstm_unroll
 from sort_io import BucketSentenceIter, default_build_vocab
 
+import logging
+head = '%(asctime)-15s %(message)s'
+logging.basicConfig(level=logging.DEBUG, format=head)
+
+
+TRAIN_FILE = "sort.train.txt"
+TEST_FILE = "sort.test.txt"
+VALID_FILE = "sort.valid.txt"
+DATA_DIR = os.path.join(os.getcwd(), "data")
+SEQ_LEN = 5
+
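+# Generate random sequences of seq_len numbers drawn from [start_range, end_range)
+# and split them across the train/valid/test files under DATA_DIR.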
+def gen_data(seq_len, start_range, end_range):
+    if not os.path.exists(DATA_DIR):
+        try:
+            logging.info('create directory %s', DATA_DIR)
+            os.makedirs(DATA_DIR)
+        except OSError as exc:
+            if exc.errno != errno.EEXIST:
+                raise OSError('failed to create ' + DATA_DIR)
+    vocab = [str(x) for x in range(start_range, end_range)]
+    sw_train = open(os.path.join(DATA_DIR, TRAIN_FILE), "w")
+    sw_test = open(os.path.join(DATA_DIR, TEST_FILE), "w")
+    sw_valid = open(os.path.join(DATA_DIR, VALID_FILE), "w")
+
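+    # Write 1,000,000 sequences: every 50th goes to the test file, the one
+    # after it to the validation file, and the rest to the training file.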
+    for i in range(1000000):
+        seq = " ".join([vocab[random.randint(0, len(vocab) - 1)] for j in range(seq_len)])
+        k = i % 50
+        if k == 0:
+            sw_test.write(seq + "\n")
+        elif k == 1:
+            sw_valid.write(seq + "\n")
+        else:
+            sw_train.write(seq + "\n")
+
+    sw_train.close()
+    sw_test.close()
+    sw_valid.close()
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Parse args for lstm_sort example",
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--start-range', type=int, default=100,
+                        help='Starting number of the range')
+    parser.add_argument('--end-range', type=int, default=1000,
+                        help='Ending number of the range')
+    parser.add_argument('--cpu', action='store_true',
+                        help='Use CPU for training (default is GPU)')
+    return parser.parse_args()
+
+
 def Perplexity(label, pred):
     label = label.T.reshape((-1,))
     loss = 0.
@@ -32,7 +83,9 @@ def Perplexity(label, pred):
         loss += -np.log(max(1e-10, pred[i][int(label[i])]))
     return np.exp(loss / label.size)
 
-if __name__ == '__main__':
+def main():
+    args = parse_args()
+    gen_data(SEQ_LEN, args.start_range, args.end_range)
     batch_size = 100
     buckets = []
     num_hidden = 300
@@ -43,9 +96,12 @@ if __name__ == '__main__':
     learning_rate = 0.1
     momentum = 0.9
 
-    contexts = [mx.context.gpu(i) for i in range(1)]
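+    # Select a single CPU or GPU context depending on the --cpu flag.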
+    if args.cpu:
+        contexts = [mx.context.cpu(i) for i in range(1)]
+    else:
+        contexts = [mx.context.gpu(i) for i in range(1)]
 
-    vocab = default_build_vocab("./data/sort.train.txt")
+    vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))
 
     def sym_gen(seq_len):
         return bi_lstm_unroll(seq_len, len(vocab),
@@ -56,9 +112,9 @@ if __name__ == '__main__':
     init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
     init_states = init_c + init_h
 
-    data_train = BucketSentenceIter("./data/sort.train.txt", vocab,
+    data_train = BucketSentenceIter(os.path.join(DATA_DIR, TRAIN_FILE), vocab,
                                     buckets, batch_size, init_states)
-    data_val = BucketSentenceIter("./data/sort.valid.txt", vocab,
+    data_val = BucketSentenceIter(os.path.join(DATA_DIR, VALID_FILE), vocab,
                                   buckets, batch_size, init_states)
 
     if len(buckets) == 1:
@@ -74,12 +130,11 @@ if __name__ == '__main__':
                                  wd=0.00001,
                                  initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
 
-    import logging
-    head = '%(asctime)-15s %(message)s'
-    logging.basicConfig(level=logging.DEBUG, format=head)
-
     model.fit(X=data_train, eval_data=data_val,
               eval_metric = mx.metric.np(Perplexity),
               batch_end_callback=mx.callback.Speedometer(batch_size, 50),)
 
     model.save("sort")
+
+if __name__ == '__main__':
+    sys.exit(main())
diff --git a/example/bi-lstm-sort/rnn_model.py b/example/bi-lstm-sort/rnn_model.py
index 202aae6..1079e90 100644
--- a/example/bi-lstm-sort/rnn_model.py
+++ b/example/bi-lstm-sort/rnn_model.py
@@ -18,7 +18,6 @@
 # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
 # pylint: disable=superfluous-parens, no-member, invalid-name
 import sys
-sys.path.insert(0, "../../python")
 import numpy as np
 import mxnet as mx
 
diff --git a/example/bi-lstm-sort/sort_io.py b/example/bi-lstm-sort/sort_io.py
index 8cb44c6..853d0ee 100644
--- a/example/bi-lstm-sort/sort_io.py
+++ b/example/bi-lstm-sort/sort_io.py
@@ -19,7 +19,6 @@
 # pylint: disable=superfluous-parens, no-member, invalid-name
 from __future__ import print_function
 import sys
-sys.path.insert(0, "../../python")
 import numpy as np
 import mxnet as mx
 

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].