Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/10/25 21:41:59 UTC
[incubator-mxnet] branch master updated: Update bilstm integer
array sorting example (#12929)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a39152b Update bilstm integer array sorting example (#12929)
a39152b is described below
commit a39152b4ff00f8db0ad88f555c35941393afb289
Author: Thomas Delteil <th...@gmail.com>
AuthorDate: Thu Oct 25 14:41:42 2018 -0700
Update bilstm integer array sorting example (#12929)
* Update the bilstm example to Gluon
* Update formatting
* Update example/vae/VAE_example.ipynb
Co-Authored-By: ThomasDelteil <th...@gmail.com>
---
example/bi-lstm-sort/README.md | 28 +-
example/bi-lstm-sort/bi-lstm-sort.ipynb | 607 ++++++++++++++++++++++++++++++++
example/bi-lstm-sort/gen_data.py | 37 --
example/bi-lstm-sort/infer_sort.py | 80 -----
example/bi-lstm-sort/lstm.py | 175 ---------
example/bi-lstm-sort/lstm_sort.py | 142 --------
example/bi-lstm-sort/rnn_model.py | 73 ----
example/bi-lstm-sort/sort_io.py | 255 --------------
8 files changed, 616 insertions(+), 781 deletions(-)
diff --git a/example/bi-lstm-sort/README.md b/example/bi-lstm-sort/README.md
index 3bacc86..f00cc85 100644
--- a/example/bi-lstm-sort/README.md
+++ b/example/bi-lstm-sort/README.md
@@ -1,24 +1,14 @@
-This is an example of using bidirection lstm to sort an array.
+# Bidirectional LSTM to sort an array
-Run the training script by doing the following:
+This is an example of using a bidirectional LSTM to sort an array. Please refer to the notebook `bi-lstm-sort.ipynb`.
-```
-python lstm_sort.py --start-range 100 --end-range 1000 --cpu
-```
-You can provide the start-range and end-range for the numbers and whether to train on the cpu or not.
-By default the script tries to train on the GPU. The default start-range is 100 and end-range is 1000.
+We train a bidirectional LSTM to sort an array of integers.
-At last, test model by doing the following:
+For example:
-```
-python infer_sort.py 234 189 785 763 231
-```
+`500 30 999 10 130` should give us `10 30 130 500 999`
-This should output the sorted seq like the following:
-```
-189
-231
-234
-763
-785
-```
+![](https://cdn-images-1.medium.com/max/1200/1*6QnPUSv_t9BY9Fv8_aLb-Q.png)
+
+
+([Diagram source](http://colah.github.io/posts/2015-09-NN-Types-FP/))
\ No newline at end of file
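The character-level representation that the notebook uses for the README example (`500 30 999 10 130` -> `10 30 130 500 999`) can be sketched without MXNet. This is an illustration only, assuming NumPy; `encode`, `one_hot` and `decode` are hypothetical helpers that mirror the notebook's `transform` and `get_pred`, not code from this commit.

```python
import string

import numpy as np

# Same character vocabulary as the notebook: ten digits plus a space (11 classes).
VOCAB = string.digits + " "
VOCAB_IDX = {c: i for i, c in enumerate(VOCAB)}
SEQ_LEN, MAX_DIGITS = 5, 3
# 5 numbers of up to 3 digits plus 4 separating spaces -> 19 characters.
MAX_LEN = MAX_DIGITS * SEQ_LEN + (SEQ_LEN - 1)


def encode(numbers):
    """Join the ints with spaces, right-pad with spaces, map chars to indices."""
    s = " ".join(map(str, numbers))
    s = s + " " * (MAX_LEN - len(s))
    return np.array([VOCAB_IDX[c] for c in s], dtype=np.int64)


def one_hot(indices):
    """One-hot encode character indices into a (MAX_LEN, len(VOCAB)) matrix."""
    return np.eye(len(VOCAB), dtype=np.float32)[indices]


def decode(indices):
    """Map per-character class indices (e.g. an argmax) back to a list of ints."""
    s = "".join(VOCAB[int(i)] for i in indices)
    return [int(tok) for tok in s.split()]


x = [500, 30, 999, 10, 130]
x_matrix = one_hot(encode(x))   # network input, one row per character
y_idx = encode(sorted(x))       # sparse target, one class index per character
print(x_matrix.shape, decode(y_idx))
```

The input is one-hot encoded while the target stays as class indices, which is what a softmax cross-entropy loss with sparse labels expects.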
diff --git a/example/bi-lstm-sort/bi-lstm-sort.ipynb b/example/bi-lstm-sort/bi-lstm-sort.ipynb
new file mode 100644
index 0000000..0851176
--- /dev/null
+++ b/example/bi-lstm-sort/bi-lstm-sort.ipynb
@@ -0,0 +1,607 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Using a bi-lstm to sort a sequence of integers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "import string\n",
+ "\n",
+ "import mxnet as mx\n",
+ "from mxnet import gluon, nd\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Data Preparation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "max_num = 999\n",
+ "dataset_size = 60000\n",
+ "seq_len = 5\n",
+ "split = 0.8\n",
+ "batch_size = 512\n",
+ "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We generate a dataset of **dataset_size** random integer sequences of length **seq_len**, with values between **0** and **max_num**. We use a fraction **split** of them for training and the rest for testing.\n",
+ "\n",
+ "\n",
+ "For example:\n",
+ "\n",
+ "50 10 200 999 30\n",
+ "\n",
+ "Should return\n",
+ "\n",
+ "10 30 50 200 999"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X = mx.random.uniform(low=0, high=max_num, shape=(dataset_size, seq_len)).astype('int32').asnumpy()\n",
+ "Y = X.copy()\n",
+ "Y.sort() #Let's sort X to get the target"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input [548, 592, 714, 843, 602]\n",
+ "Target [548, 592, 602, 714, 843]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Input {}\\nTarget {}\".format(X[0].tolist(), Y[0].tolist()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For training, we encode each sequence as a string of characters (digits and spaces) rather than as numbers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0123456789 \n",
+ "{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, ' ': 10}\n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab = string.digits + \" \"\n",
+ "print(vocab)\n",
+ "vocab_idx = { c:i for i,c in enumerate(vocab)}\n",
+ "print(vocab_idx)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We write a transform that converts each sequence of numbers into a space-padded string of length **max_len** and maps every character to its index in the vocabulary.\n",
+ "For example:\n",
+ "\n",
+ "\"30 10\" maps to the indices [3, 0, 10, 1, 0]\n",
+ "\n",
+ "We then one-hot encode these indices to get a matrix representation of the input. The target does not need to be one-hot encoded, because the loss we use, `SoftmaxCELoss`, supports sparse labels (class indices)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum length of the string: 19\n"
+ ]
+ }
+ ],
+ "source": [
+ "max_len = len(str(max_num))*seq_len+(seq_len-1)\n",
+ "print(\"Maximum length of the string: %s\" % max_len)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def transform(x, y):\n",
+ " x_string = ' '.join(map(str, x.tolist()))\n",
+ " x_string_padded = x_string + ' '*(max_len-len(x_string))\n",
+ " x = [vocab_idx[c] for c in x_string_padded]\n",
+ " y_string = ' '.join(map(str, y.tolist()))\n",
+ " y_string_padded = y_string + ' '*(max_len-len(y_string))\n",
+ " y = [vocab_idx[c] for c in y_string_padded]\n",
+ " return mx.nd.one_hot(mx.nd.array(x), len(vocab)), mx.nd.array(y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "split_idx = int(split*len(X))\n",
+ "train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)\n",
+ "test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input [548 592 714 843 602]\n",
+ "Transformed data Input \n",
+ "[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
+ " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
+ " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
+ " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
+ " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
+ " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
+ " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]\n",
+ "<NDArray 19x11 @cpu(0)>\n",
+ "Target [548 592 602 714 843]\n",
+ "Transformed data Target \n",
+ "[ 5. 4. 8. 10. 5. 9. 2. 10. 6. 0. 2. 10. 7. 1. 4. 10. 8. 4.\n",
+ " 3.]\n",
+ "<NDArray 19 @cpu(0)>\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Input {}\".format(X[0]))\n",
+ "print(\"Transformed data Input {}\".format(train_dataset[0][0]))\n",
+ "print(\"Target {}\".format(Y[0]))\n",
+ "print(\"Transformed data Target {}\".format(train_dataset[0][1]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')\n",
+ "test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='rollover')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating the network"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = gluon.nn.HybridSequential()\n",
+ "with net.name_scope():\n",
+ " net.add(\n",
+ " gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),\n",
+ " gluon.nn.Dense(len(vocab), flatten=False)\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net.initialize(mx.init.Xavier(), ctx=ctx)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss = gluon.loss.SoftmaxCELoss()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We use a learning rate schedule that multiplies the learning rate by 0.75 every `len(train_data)*10` updates (about 10 epochs), which helps the model converge."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)\n",
+ "schedule.base_lr = 0.01"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.01, 'lr_scheduler':schedule})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch [0] Loss: 1.6627886372227823, LR 0.01\n",
+ "Epoch [1] Loss: 1.210370733382854, LR 0.01\n",
+ "Epoch [2] Loss: 0.9692377131035987, LR 0.01\n",
+ "Epoch [3] Loss: 0.7976046623067653, LR 0.01\n",
+ "Epoch [4] Loss: 0.5714595343476983, LR 0.01\n",
+ "Epoch [5] Loss: 0.4458411196444897, LR 0.01\n",
+ "Epoch [6] Loss: 0.36039798817736035, LR 0.01\n",
+ "Epoch [7] Loss: 0.32665719377233626, LR 0.01\n",
+ "Epoch [8] Loss: 0.262064205702915, LR 0.01\n",
+ "Epoch [9] Loss: 0.22285924059279422, LR 0.0075\n",
+ "Epoch [10] Loss: 0.19018426854559717, LR 0.0075\n",
+ "Epoch [11] Loss: 0.1718730723604243, LR 0.0075\n",
+ "Epoch [12] Loss: 0.15736752171670237, LR 0.0075\n",
+ "Epoch [13] Loss: 0.14579375246737866, LR 0.0075\n",
+ "Epoch [14] Loss: 0.13546599733068587, LR 0.0075\n",
+ "Epoch [15] Loss: 0.12490207590955368, LR 0.0075\n",
+ "Epoch [16] Loss: 0.11803316300915133, LR 0.0075\n",
+ "Epoch [17] Loss: 0.10653189395336395, LR 0.0075\n",
+ "Epoch [18] Loss: 0.10514750379197141, LR 0.0075\n",
+ "Epoch [19] Loss: 0.09590611559279422, LR 0.005625\n",
+ "Epoch [20] Loss: 0.08146028108494256, LR 0.005625\n",
+ "Epoch [21] Loss: 0.07707348782965477, LR 0.005625\n",
+ "Epoch [22] Loss: 0.07206193436967566, LR 0.005625\n",
+ "Epoch [23] Loss: 0.07001185417175293, LR 0.005625\n",
+ "Epoch [24] Loss: 0.06797058351578252, LR 0.005625\n",
+ "Epoch [25] Loss: 0.0649358110224947, LR 0.005625\n",
+ "Epoch [26] Loss: 0.06219124286732775, LR 0.005625\n",
+ "Epoch [27] Loss: 0.06075144828634059, LR 0.005625\n",
+ "Epoch [28] Loss: 0.05711334495134251, LR 0.005625\n",
+ "Epoch [29] Loss: 0.054747099572039666, LR 0.00421875\n",
+ "Epoch [30] Loss: 0.0441775271233092, LR 0.00421875\n",
+ "Epoch [31] Loss: 0.041551097910454936, LR 0.00421875\n",
+ "Epoch [32] Loss: 0.04095017269093503, LR 0.00421875\n",
+ "Epoch [33] Loss: 0.04045371045457556, LR 0.00421875\n",
+ "Epoch [34] Loss: 0.038867686657195394, LR 0.00421875\n",
+ "Epoch [35] Loss: 0.038131744303601854, LR 0.00421875\n",
+ "Epoch [36] Loss: 0.039834817250569664, LR 0.00421875\n",
+ "Epoch [37] Loss: 0.03669035941996473, LR 0.00421875\n",
+ "Epoch [38] Loss: 0.03373505967728635, LR 0.00421875\n",
+ "Epoch [39] Loss: 0.03164981273894615, LR 0.0031640625\n",
+ "Epoch [40] Loss: 0.025532766055035336, LR 0.0031640625\n",
+ "Epoch [41] Loss: 0.022659448867148543, LR 0.0031640625\n",
+ "Epoch [42] Loss: 0.02307056112492338, LR 0.0031640625\n",
+ "Epoch [43] Loss: 0.02236944056571798, LR 0.0031640625\n",
+ "Epoch [44] Loss: 0.022204211963120328, LR 0.0031640625\n",
+ "Epoch [45] Loss: 0.02262336903430046, LR 0.0031640625\n",
+ "Epoch [46] Loss: 0.02253308448385685, LR 0.0031640625\n",
+ "Epoch [47] Loss: 0.025286573044797207, LR 0.0031640625\n",
+ "Epoch [48] Loss: 0.02439300988310127, LR 0.0031640625\n",
+ "Epoch [49] Loss: 0.017976388018181983, LR 0.002373046875\n",
+ "Epoch [50] Loss: 0.014343131095805067, LR 0.002373046875\n",
+ "Epoch [51] Loss: 0.013039355582379281, LR 0.002373046875\n",
+ "Epoch [52] Loss: 0.011884741885687715, LR 0.002373046875\n",
+ "Epoch [53] Loss: 0.011438189668858305, LR 0.002373046875\n",
+ "Epoch [54] Loss: 0.011447292693117832, LR 0.002373046875\n",
+ "Epoch [55] Loss: 0.014212571560068334, LR 0.002373046875\n",
+ "Epoch [56] Loss: 0.019900493724371797, LR 0.002373046875\n",
+ "Epoch [57] Loss: 0.02102568301748722, LR 0.002373046875\n",
+ "Epoch [58] Loss: 0.01346214400961044, LR 0.002373046875\n",
+ "Epoch [59] Loss: 0.010107964911359422, LR 0.0017797851562500002\n",
+ "Epoch [60] Loss: 0.008353193600972494, LR 0.0017797851562500002\n",
+ "Epoch [61] Loss: 0.007678258292218472, LR 0.0017797851562500002\n",
+ "Epoch [62] Loss: 0.007262124660167288, LR 0.0017797851562500002\n",
+ "Epoch [63] Loss: 0.00705223578087827, LR 0.0017797851562500002\n",
+ "Epoch [64] Loss: 0.006788556293774677, LR 0.0017797851562500002\n",
+ "Epoch [65] Loss: 0.006473606571238091, LR 0.0017797851562500002\n",
+ "Epoch [66] Loss: 0.006206096486842378, LR 0.0017797851562500002\n",
+ "Epoch [67] Loss: 0.00584477313021396, LR 0.0017797851562500002\n",
+ "Epoch [68] Loss: 0.005648705267137097, LR 0.0017797851562500002\n",
+ "Epoch [69] Loss: 0.006481769871204458, LR 0.0013348388671875003\n",
+ "Epoch [70] Loss: 0.008430448618341, LR 0.0013348388671875003\n",
+ "Epoch [71] Loss: 0.006877245421105242, LR 0.0013348388671875003\n",
+ "Epoch [72] Loss: 0.005671108281740578, LR 0.0013348388671875003\n",
+ "Epoch [73] Loss: 0.004832422162624116, LR 0.0013348388671875003\n",
+ "Epoch [74] Loss: 0.004441103402604448, LR 0.0013348388671875003\n",
+ "Epoch [75] Loss: 0.004216198591475791, LR 0.0013348388671875003\n",
+ "Epoch [76] Loss: 0.004041922989711967, LR 0.0013348388671875003\n",
+ "Epoch [77] Loss: 0.003937713643337818, LR 0.0013348388671875003\n",
+ "Epoch [78] Loss: 0.010251983049068046, LR 0.0013348388671875003\n",
+ "Epoch [79] Loss: 0.01829354052848004, LR 0.0010011291503906252\n",
+ "Epoch [80] Loss: 0.006723233448561802, LR 0.0010011291503906252\n",
+ "Epoch [81] Loss: 0.004397524798170049, LR 0.0010011291503906252\n",
+ "Epoch [82] Loss: 0.0038475305476087206, LR 0.0010011291503906252\n",
+ "Epoch [83] Loss: 0.003591177945441388, LR 0.0010011291503906252\n",
+ "Epoch [84] Loss: 0.003425112014175743, LR 0.0010011291503906252\n",
+ "Epoch [85] Loss: 0.0032633850549129728, LR 0.0010011291503906252\n",
+ "Epoch [86] Loss: 0.0031762316505959693, LR 0.0010011291503906252\n",
+ "Epoch [87] Loss: 0.0030452777096565734, LR 0.0010011291503906252\n",
+ "Epoch [88] Loss: 0.002950224184220837, LR 0.0010011291503906252\n",
+ "Epoch [89] Loss: 0.002821172171450676, LR 0.0007508468627929689\n",
+ "Epoch [90] Loss: 0.002725780961361337, LR 0.0007508468627929689\n",
+ "Epoch [91] Loss: 0.002660556359493986, LR 0.0007508468627929689\n",
+ "Epoch [92] Loss: 0.0026011724946319414, LR 0.0007508468627929689\n",
+ "Epoch [93] Loss: 0.0025355776256703317, LR 0.0007508468627929689\n",
+ "Epoch [94] Loss: 0.0024825221997626283, LR 0.0007508468627929689\n",
+ "Epoch [95] Loss: 0.0024245587435174497, LR 0.0007508468627929689\n",
+ "Epoch [96] Loss: 0.002365282145879602, LR 0.0007508468627929689\n",
+ "Epoch [97] Loss: 0.0023112583984719946, LR 0.0007508468627929689\n",
+ "Epoch [98] Loss: 0.002257173682780976, LR 0.0007508468627929689\n",
+ "Epoch [99] Loss: 0.002162747085094452, LR 0.0005631351470947267\n"
+ ]
+ }
+ ],
+ "source": [
+ "epochs = 100\n",
+ "for e in range(epochs):\n",
+ " epoch_loss = 0.\n",
+ " for i, (data, label) in enumerate(train_data):\n",
+ " data = data.as_in_context(ctx)\n",
+ " label = label.as_in_context(ctx)\n",
+ "\n",
+ " with mx.autograd.record():\n",
+ " output = net(data)\n",
+ " l = loss(output, label)\n",
+ "\n",
+ " l.backward()\n",
+ " trainer.step(data.shape[0])\n",
+ " \n",
+ " epoch_loss += l.mean()\n",
+ " \n",
+ " print(\"Epoch [{}] Loss: {}, LR {}\".format(e, epoch_loss.asscalar()/(i+1), trainer.learning_rate))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Testing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We pick a random element from the test set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n = random.randint(0, len(test_dataset)-1)  # index into samples, not batches\n",
+ "\n",
+ "x_orig = X[split_idx+n]\n",
+ "y_orig = Y[split_idx+n]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_pred(x):\n",
+ " x, _ = transform(x, x)\n",
+ " output = net(x.as_in_context(ctx).expand_dims(axis=0))\n",
+ "\n",
+ " # Convert output back to string\n",
+ " pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])\n",
+ " return pred"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Printing the result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "X 611 671 275 871 944\n",
+ "Predicted 275 611 671 871 944\n",
+ "Label 275 611 671 871 944\n"
+ ]
+ }
+ ],
+ "source": [
+ "x_ = ' '.join(map(str,x_orig))\n",
+ "label = ' '.join(map(str,y_orig))\n",
+ "print(\"X {}\\nPredicted {}\\nLabel {}\".format(x_, get_pred(x_orig), label))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also try our own example, which the network sorts correctly:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10 30 130 500 999 \n"
+ ]
+ }
+ ],
+ "source": [
+ "print(get_pred(np.array([500, 30, 999, 10, 130])))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The model even generalizes to some inputs that differ from the training data, such as a sequence of only four numbers:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Only four numbers: 105 202 302 501 \n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Only four numbers:\", get_pred(np.array([105, 302, 501, 202])))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "However, it struggles with other edge cases, such as sequences of very small numbers or of more than five numbers:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Small digits: 8 0 42 28 \n",
+ "Small digits, 6 numbers: 10 0 20 82 71 115 \n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Small digits:\", get_pred(np.array([10, 3, 5, 2, 8])))\n",
+ "print(\"Small digits, 6 numbers:\", get_pred(np.array([10, 33, 52, 21, 82, 10])))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
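+ "This could likely be improved by adjusting the training dataset accordingly, for example by including more sequences of small numbers and sequences of different lengths."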
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
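The learning rates printed in the notebook's training log follow from `FactorScheduler(step=len(train_data)*10, factor=0.75)` with a base rate of 0.01. Below is a minimal pure-Python sketch of that rule. It assumes the MXNet 1.x behaviour, where the rate is multiplied by `factor` each time the update count strictly exceeds another multiple of `step` and is floored at `stop_factor_lr`. `factor_lr` is a hypothetical helper, and `step=930` is only an illustrative stand-in for `len(train_data)*10` (48,000 training samples / 512 is roughly 93 batches per epoch).

```python
def factor_lr(num_update, base_lr=0.01, step=930, factor=0.75, stop_factor_lr=1e-8):
    """Learning rate after `num_update` optimizer updates (FactorScheduler-style).

    The rate is multiplied by `factor` every time `num_update` strictly exceeds
    the next multiple of `step`, and never drops below `stop_factor_lr`.
    """
    lr, count = base_lr, 0
    while num_update > count + step:
        count += step
        lr = max(lr * factor, stop_factor_lr)
    return lr


# One update past each boundary triggers one more decay.
for k in range(4):
    print(k, factor_lr(930 * k + 1))
```

After k decays the rate is 0.01 * 0.75**k, which matches the sequence in the log (0.0075, 0.005625, 0.00421875, ...).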
diff --git a/example/bi-lstm-sort/gen_data.py b/example/bi-lstm-sort/gen_data.py
deleted file mode 100644
index 55af1b4..0000000
--- a/example/bi-lstm-sort/gen_data.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import random
-
-vocab = [str(x) for x in range(100, 1000)]
-sw_train = open("sort.train.txt", "w")
-sw_test = open("sort.test.txt", "w")
-sw_valid = open("sort.valid.txt", "w")
-
-for i in range(1000000):
- seq = " ".join([vocab[random.randint(0, len(vocab) - 1)] for j in range(5)])
- k = i % 50
- if k == 0:
- sw_test.write(seq + "\n")
- elif k == 1:
- sw_valid.write(seq + "\n")
- else:
- sw_train.write(seq + "\n")
-
-sw_train.close()
-sw_test.close()
-sw_valid.close()
diff --git a/example/bi-lstm-sort/infer_sort.py b/example/bi-lstm-sort/infer_sort.py
deleted file mode 100644
index f81c6c0..0000000
--- a/example/bi-lstm-sort/infer_sort.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-import sys
-import os
-import argparse
-import numpy as np
-import mxnet as mx
-
-from sort_io import BucketSentenceIter, default_build_vocab
-from rnn_model import BiLSTMInferenceModel
-
-TRAIN_FILE = "sort.train.txt"
-TEST_FILE = "sort.test.txt"
-VALID_FILE = "sort.valid.txt"
-DATA_DIR = os.path.join(os.getcwd(), "data")
-SEQ_LEN = 5
-
-def MakeInput(char, vocab, arr):
- idx = vocab[char]
- tmp = np.zeros((1,))
- tmp[0] = idx
- arr[:] = tmp
-
-def main():
- tks = sys.argv[1:]
- assert len(tks) >= 5, "Please provide 5 numbers for sorting as sequence length is 5"
- batch_size = 1
- buckets = []
- num_hidden = 300
- num_embed = 512
- num_lstm_layer = 2
-
- num_epoch = 1
- learning_rate = 0.1
- momentum = 0.9
-
- contexts = [mx.context.cpu(i) for i in range(1)]
-
- vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))
- rvocab = {}
- for k, v in vocab.items():
- rvocab[v] = k
-
- _, arg_params, __ = mx.model.load_checkpoint("sort", 1)
- for tk in tks:
- assert (tk in vocab), "{} not in range of numbers that the model trained for.".format(tk)
-
- model = BiLSTMInferenceModel(SEQ_LEN, len(vocab),
- num_hidden=num_hidden, num_embed=num_embed,
- num_label=len(vocab), arg_params=arg_params, ctx=contexts, dropout=0.0)
-
- data = np.zeros((1, len(tks)))
- for k in range(len(tks)):
- data[0][k] = vocab[tks[k]]
-
- data = mx.nd.array(data)
- prob = model.forward(data)
- for k in range(len(tks)):
- print(rvocab[np.argmax(prob, axis = 1)[k]])
-
-
-if __name__ == '__main__':
- sys.exit(main())
diff --git a/example/bi-lstm-sort/lstm.py b/example/bi-lstm-sort/lstm.py
deleted file mode 100644
index 362481d..0000000
--- a/example/bi-lstm-sort/lstm.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint:skip-file
-import sys
-import mxnet as mx
-import numpy as np
-from collections import namedtuple
-import time
-import math
-LSTMState = namedtuple("LSTMState", ["c", "h"])
-LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
- "h2h_weight", "h2h_bias"])
-LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
- "init_states", "last_states", "forward_state", "backward_state",
- "seq_data", "seq_labels", "seq_outputs",
- "param_blocks"])
-
-def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
- """LSTM Cell symbol"""
- if dropout > 0.:
- indata = mx.sym.Dropout(data=indata, p=dropout)
- i2h = mx.sym.FullyConnected(data=indata,
- weight=param.i2h_weight,
- bias=param.i2h_bias,
- num_hidden=num_hidden * 4,
- name="t%d_l%d_i2h" % (seqidx, layeridx))
- h2h = mx.sym.FullyConnected(data=prev_state.h,
- weight=param.h2h_weight,
- bias=param.h2h_bias,
- num_hidden=num_hidden * 4,
- name="t%d_l%d_h2h" % (seqidx, layeridx))
- gates = i2h + h2h
- slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
- name="t%d_l%d_slice" % (seqidx, layeridx))
- in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
- in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
- forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
- out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
- next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
- next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
- return LSTMState(c=next_c, h=next_h)
-
-
-def bi_lstm_unroll(seq_len, input_size,
- num_hidden, num_embed, num_label, dropout=0.):
-
- embed_weight = mx.sym.Variable("embed_weight")
- cls_weight = mx.sym.Variable("cls_weight")
- cls_bias = mx.sym.Variable("cls_bias")
- last_states = []
- last_states.append(LSTMState(c = mx.sym.Variable("l0_init_c"), h = mx.sym.Variable("l0_init_h")))
- last_states.append(LSTMState(c = mx.sym.Variable("l1_init_c"), h = mx.sym.Variable("l1_init_h")))
- forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
- i2h_bias=mx.sym.Variable("l0_i2h_bias"),
- h2h_weight=mx.sym.Variable("l0_h2h_weight"),
- h2h_bias=mx.sym.Variable("l0_h2h_bias"))
- backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
- i2h_bias=mx.sym.Variable("l1_i2h_bias"),
- h2h_weight=mx.sym.Variable("l1_h2h_weight"),
- h2h_bias=mx.sym.Variable("l1_h2h_bias"))
-
- # embeding layer
- data = mx.sym.Variable('data')
- label = mx.sym.Variable('softmax_label')
- embed = mx.sym.Embedding(data=data, input_dim=input_size,
- weight=embed_weight, output_dim=num_embed, name='embed')
- wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
-
- forward_hidden = []
- for seqidx in range(seq_len):
- hidden = wordvec[seqidx]
- next_state = lstm(num_hidden, indata=hidden,
- prev_state=last_states[0],
- param=forward_param,
- seqidx=seqidx, layeridx=0, dropout=dropout)
- hidden = next_state.h
- last_states[0] = next_state
- forward_hidden.append(hidden)
-
- backward_hidden = []
- for seqidx in range(seq_len):
- k = seq_len - seqidx - 1
- hidden = wordvec[k]
- next_state = lstm(num_hidden, indata=hidden,
- prev_state=last_states[1],
- param=backward_param,
- seqidx=k, layeridx=1,dropout=dropout)
- hidden = next_state.h
- last_states[1] = next_state
- backward_hidden.insert(0, hidden)
-
- hidden_all = []
- for i in range(seq_len):
- hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
-
- hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
- pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
- weight=cls_weight, bias=cls_bias, name='pred')
-
- label = mx.sym.transpose(data=label)
- label = mx.sym.Reshape(data=label, target_shape=(0,))
- sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
-
- return sm
-
-
-def bi_lstm_inference_symbol(input_size, seq_len,
- num_hidden, num_embed, num_label, dropout=0.):
- seqidx = 0
- embed_weight=mx.sym.Variable("embed_weight")
- cls_weight = mx.sym.Variable("cls_weight")
- cls_bias = mx.sym.Variable("cls_bias")
- last_states = [LSTMState(c = mx.sym.Variable("l0_init_c"), h = mx.sym.Variable("l0_init_h")),
- LSTMState(c = mx.sym.Variable("l1_init_c"), h = mx.sym.Variable("l1_init_h"))]
- forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
- i2h_bias=mx.sym.Variable("l0_i2h_bias"),
- h2h_weight=mx.sym.Variable("l0_h2h_weight"),
- h2h_bias=mx.sym.Variable("l0_h2h_bias"))
- backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
- i2h_bias=mx.sym.Variable("l1_i2h_bias"),
- h2h_weight=mx.sym.Variable("l1_h2h_weight"),
- h2h_bias=mx.sym.Variable("l1_h2h_bias"))
- data = mx.sym.Variable("data")
- embed = mx.sym.Embedding(data=data, input_dim=input_size,
- weight=embed_weight, output_dim=num_embed, name='embed')
- wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
- forward_hidden = []
- for seqidx in range(seq_len):
- next_state = lstm(num_hidden, indata=wordvec[seqidx],
- prev_state=last_states[0],
- param=forward_param,
- seqidx=seqidx, layeridx=0, dropout=0.0)
- hidden = next_state.h
- last_states[0] = next_state
- forward_hidden.append(hidden)
-
- backward_hidden = []
- for seqidx in range(seq_len):
- k = seq_len - seqidx - 1
- next_state = lstm(num_hidden, indata=wordvec[k],
- prev_state=last_states[1],
- param=backward_param,
- seqidx=k, layeridx=1, dropout=0.0)
- hidden = next_state.h
- last_states[1] = next_state
- backward_hidden.insert(0, hidden)
-
- hidden_all = []
- for i in range(seq_len):
- hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
- hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
- fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
- weight=cls_weight, bias=cls_bias, name='pred')
- sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
- output = [sm]
- for state in last_states:
- output.append(state.c)
- output.append(state.h)
- return mx.sym.Group(output)
-
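The removed `lstm.py` unrolled the LSTM by hand: each step sums two fully connected projections (input-to-hidden and hidden-to-hidden), slices the result into four gates, and updates the cell and hidden states. One layer runs forward over the sequence, a second layer with its own weights runs backward, and the two hidden states are concatenated at each step. A minimal NumPy sketch of that computation follows, assuming zero initial states and no embedding or dropout; `make_params` is a hypothetical random initialiser used only for illustration. The notebook replaces all of this with `gluon.rnn.LSTM(..., bidirectional=True)`, whose internal weight layout may differ from this sketch.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h_prev, c_prev, w_i2h, b_i2h, w_h2h, b_h2h):
    """One LSTM step; gate order follows the removed lstm(): in, transform, forget, out."""
    # Equivalent to FullyConnected(x) + FullyConnected(h_prev), weights (4*hidden, in_dim).
    gates = x @ w_i2h.T + b_i2h + h_prev @ w_h2h.T + b_h2h
    in_gate, in_transform, forget_gate, out_gate = np.split(gates, 4, axis=1)
    c = sigmoid(forget_gate) * c_prev + sigmoid(in_gate) * np.tanh(in_transform)
    h = sigmoid(out_gate) * np.tanh(c)
    return h, c


def bi_lstm(xs, fwd_params, bwd_params, hidden):
    """xs: (seq_len, batch, in_dim). Returns (seq_len, batch, 2*hidden)."""
    seq_len, batch, _ = xs.shape
    # Forward direction: t = 0 .. seq_len-1, starting from zero states.
    h, c = np.zeros((batch, hidden)), np.zeros((batch, hidden))
    fwd = []
    for t in range(seq_len):
        h, c = lstm_step(xs[t], h, c, *fwd_params)
        fwd.append(h)
    # Backward direction: t = seq_len-1 .. 0, with separate weights.
    h, c = np.zeros((batch, hidden)), np.zeros((batch, hidden))
    bwd = [None] * seq_len
    for t in reversed(range(seq_len)):
        h, c = lstm_step(xs[t], h, c, *bwd_params)
        bwd[t] = h
    # Concatenate both directions at every time step, like Concat(..., dim=1).
    return np.stack([np.concatenate([f, b], axis=1) for f, b in zip(fwd, bwd)])


def make_params(rng, in_dim, hidden):
    """Hypothetical random initialisation, for illustration only."""
    return (rng.normal(scale=0.1, size=(4 * hidden, in_dim)), np.zeros(4 * hidden),
            rng.normal(scale=0.1, size=(4 * hidden, hidden)), np.zeros(4 * hidden))


rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 2, 8))  # seq_len=5, batch=2, in_dim=8
fwd_params = make_params(rng, 8, 16)
bwd_params = make_params(rng, 8, 16)
out = bi_lstm(xs, fwd_params, bwd_params, hidden=16)
print(out.shape)
```

Because every output position sees the whole sequence (the forward half from the left, the backward half from the right), a per-character classifier on top of it can place each digit of the sorted output.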
diff --git a/example/bi-lstm-sort/lstm_sort.py b/example/bi-lstm-sort/lstm_sort.py
deleted file mode 100644
index 3d7090a..0000000
--- a/example/bi-lstm-sort/lstm_sort.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-import argparse
-import errno
-import logging
-import os
-import random
-import sys
-
-import numpy as np
-
-import mxnet as mx
-from lstm import bi_lstm_unroll
-from sort_io import BucketSentenceIter, default_build_vocab
-
-head = '%(asctime)-15s %(message)s'
-logging.basicConfig(level=logging.DEBUG, format=head)
-
-TRAIN_FILE = "sort.train.txt"
-TEST_FILE = "sort.test.txt"
-VALID_FILE = "sort.valid.txt"
-DATA_DIR = os.path.join(os.getcwd(), "data")
-SEQ_LEN = 5
-
-
-def gen_data(seq_len, start_range, end_range):
-    if not os.path.exists(DATA_DIR):
-        try:
-            logging.info('create directory %s', DATA_DIR)
-            os.makedirs(DATA_DIR)
-        except OSError as exc:
-            if exc.errno != errno.EEXIST:
-                raise OSError('failed to create ' + DATA_DIR)
-    vocab = [str(x) for x in range(start_range, end_range)]
-    sw_train = open(os.path.join(DATA_DIR, TRAIN_FILE), "w")
-    sw_test = open(os.path.join(DATA_DIR, TEST_FILE), "w")
-    sw_valid = open(os.path.join(DATA_DIR, VALID_FILE), "w")
-
-    for i in range(1000000):
-        seq = " ".join([vocab[random.randint(0, len(vocab) - 1)] for j in range(seq_len)])
-        k = i % 50
-        if k == 0:
-            sw_test.write(seq + "\n")
-        elif k == 1:
-            sw_valid.write(seq + "\n")
-        else:
-            sw_train.write(seq + "\n")
-
-    sw_train.close()
-    sw_test.close()
-
-def parse_args():
-    parser = argparse.ArgumentParser(description="Parse args for lstm_sort example",
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--start-range', type=int, default=100,
-                        help='starting number of the range')
-    parser.add_argument('--end-range', type=int, default=1000,
-                        help='Ending number of the range')
-    parser.add_argument('--cpu', action='store_true',
-                        help='To use CPU for training')
-    return parser.parse_args()
-
-
-def Perplexity(label, pred):
-    label = label.T.reshape((-1,))
-    loss = 0.
-    for i in range(pred.shape[0]):
-        loss += -np.log(max(1e-10, pred[i][int(label[i])]))
-    return np.exp(loss / label.size)
-
-def main():
-    args = parse_args()
-    gen_data(SEQ_LEN, args.start_range, args.end_range)
-    batch_size = 100
-    buckets = []
-    num_hidden = 300
-    num_embed = 512
-    num_lstm_layer = 2
-
-    num_epoch = 1
-    learning_rate = 0.1
-    momentum = 0.9
-
-    if args.cpu:
-        contexts = [mx.context.cpu(i) for i in range(1)]
-    else:
-        contexts = [mx.context.gpu(i) for i in range(1)]
-
-    vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))
-
-    def sym_gen(seq_len):
-        return bi_lstm_unroll(seq_len, len(vocab),
-                              num_hidden=num_hidden, num_embed=num_embed,
-                              num_label=len(vocab))
-
-    init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
-    init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
-    init_states = init_c + init_h
-
-    data_train = BucketSentenceIter(os.path.join(DATA_DIR, TRAIN_FILE), vocab,
-                                    buckets, batch_size, init_states)
-    data_val = BucketSentenceIter(os.path.join(DATA_DIR, VALID_FILE), vocab,
-                                  buckets, batch_size, init_states)
-
-    if len(buckets) == 1:
-        symbol = sym_gen(buckets[0])
-    else:
-        symbol = sym_gen
-
-    model = mx.model.FeedForward(ctx=contexts,
-                                 symbol=symbol,
-                                 num_epoch=num_epoch,
-                                 learning_rate=learning_rate,
-                                 momentum=momentum,
-                                 wd=0.00001,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
-
-    model.fit(X=data_train, eval_data=data_val,
-              eval_metric = mx.metric.np(Perplexity),
-              batch_end_callback=mx.callback.Speedometer(batch_size, 50),)
-
-    model.save("sort")
-
-if __name__ == '__main__':
-    sys.exit(main())
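The removed `Perplexity` metric is the exponential of the mean negative log-probability assigned to the correct label. Each probability is clipped below at `1e-10`. Its `label.T.reshape((-1,))` step lines the labels up with the prediction rows. Those rows are time-major (every sample for timestep 0, then every sample for timestep 1, and so on) because of the `mx.sym.Concat(*hidden_all, dim=0)` step in `lstm.py`, whereas the labels arrive as `(batch_size, seq_len)`.

The sketch below is a plain-Python version of the same computation. It assumes `labels` is a batch-major list of rows and `probs` is a time-major list of probability rows. The name `perplexity` is illustrative.

```python
import math
from typing import Sequence


def perplexity(labels: Sequence[Sequence[int]], probs: Sequence[Sequence[float]]) -> float:
    """exp(mean(-log p[label])), with labels given batch-major.

    labels: batch_size rows of seq_len class ids.
    probs:  seq_len * batch_size rows of class probabilities, time-major.
    """
    batch_size = len(labels)
    seq_len = len(labels[0]) if batch_size else 0
    # Equivalent of label.T.reshape((-1,)): walk timesteps first, then samples.
    flat = [labels[b][t] for t in range(seq_len) for b in range(batch_size)]
    if not flat or len(flat) != len(probs):
        raise ValueError("need one probability row per (timestep, sample) pair")
    loss = 0.0
    for row, lab in zip(probs, flat):
        # Clip tiny probabilities so log() never sees zero.
        loss += -math.log(max(1e-10, row[lab]))
    return math.exp(loss / len(flat))
```

Uniform probabilities over `V` classes give a perplexity of `V`, and perfect predictions give 1. If batch-major labels were paired with time-major rows, most rows would be scored against the wrong label. The transpose prevents that mismatch.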
diff --git a/example/bi-lstm-sort/rnn_model.py b/example/bi-lstm-sort/rnn_model.py
deleted file mode 100644
index 1079e90..0000000
--- a/example/bi-lstm-sort/rnn_model.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-import sys
-import numpy as np
-import mxnet as mx
-
-from lstm import LSTMState, LSTMParam, lstm, bi_lstm_inference_symbol
-
-class BiLSTMInferenceModel(object):
-    def __init__(self,
-                 seq_len,
-                 input_size,
-                 num_hidden,
-                 num_embed,
-                 num_label,
-                 arg_params,
-                 ctx=mx.cpu(),
-                 dropout=0.):
-        self.sym = bi_lstm_inference_symbol(input_size, seq_len,
-                                            num_hidden,
-                                            num_embed,
-                                            num_label,
-                                            dropout)
-        batch_size = 1
-        init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(2)]
-        init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(2)]
-
-        data_shape = [("data", (batch_size, seq_len, ))]
-
-        input_shapes = dict(init_c + init_h + data_shape)
-        self.executor = self.sym.simple_bind(ctx=mx.cpu(), **input_shapes)
-
-        for key in self.executor.arg_dict.keys():
-            if key in arg_params:
-                arg_params[key].copyto(self.executor.arg_dict[key])
-
-        state_name = []
-        for i in range(2):
-            state_name.append("l%d_init_c" % i)
-            state_name.append("l%d_init_h" % i)
-
-        self.states_dict = dict(zip(state_name, self.executor.outputs[1:]))
-        self.input_arr = mx.nd.zeros(data_shape[0][1])
-
-    def forward(self, input_data, new_seq=False):
-        if new_seq == True:
-            for key in self.states_dict.keys():
-                self.executor.arg_dict[key][:] = 0.
-        input_data.copyto(self.executor.arg_dict["data"])
-        self.executor.forward()
-        for key in self.states_dict.keys():
-            self.states_dict[key].copyto(self.executor.arg_dict[key])
-        prob = self.executor.outputs[0].asnumpy()
-        return prob
-
-
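`BiLSTMInferenceModel.forward` carries LSTM state from one call to the next. Outputs `[1:]` of the inference symbol are the final `c` and `h` of each layer. They are zipped by position with the names `l0_init_c`, `l0_init_h`, `l1_init_c`, `l1_init_h`, and after every step they are copied back into the matching input arguments. Passing `new_seq=True` resets them to zero.

The MXNet-free sketch below reproduces only that bookkeeping. `StatefulRunner`, `counting_step`, and the `step_fn` contract (return the output, then the new states in name order) are hypothetical.

```python
from typing import Callable, Dict, List, Sequence, Tuple

STATE_NAMES = ["l0_init_c", "l0_init_h", "l1_init_c", "l1_init_h"]

# Hypothetical contract: step_fn(data, states) -> (output, new_states in STATE_NAMES order)
StepFn = Callable[[float, Sequence[float]], Tuple[float, List[float]]]


class StatefulRunner:
    """Hypothetical sketch of the state carry-over done in forward()."""

    def __init__(self, step_fn: StepFn) -> None:
        self.step_fn = step_fn
        self.states: Dict[str, float] = {name: 0.0 for name in STATE_NAMES}

    def forward(self, data: float, new_seq: bool = False) -> float:
        if new_seq:
            # Start of a new sequence: zero every carried state.
            for name in self.states:
                self.states[name] = 0.0
        inputs = [self.states[name] for name in STATE_NAMES]
        output, new_states = self.step_fn(data, inputs)
        # New states are matched to names purely by position,
        # like dict(zip(state_name, executor.outputs[1:])).
        self.states = dict(zip(STATE_NAMES, new_states))
        return output


def counting_step(x: float, states: Sequence[float]) -> Tuple[float, List[float]]:
    """Toy step: every state accumulates the input; output is the first state."""
    new = [s + x for s in states]
    return new[0], new


runner = StatefulRunner(counting_step)
print(runner.forward(1.0, new_seq=True))  # 1.0
print(runner.forward(2.0))                # 3.0 (state carried over)
print(runner.forward(5.0, new_seq=True))  # 5.0 (state reset)
```

Matching by position only works if the symbol emits its states in exactly the order the names are listed. If that order changed, states would be swapped silently rather than raising an error.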
diff --git a/example/bi-lstm-sort/sort_io.py b/example/bi-lstm-sort/sort_io.py
deleted file mode 100644
index 853d0ee..0000000
--- a/example/bi-lstm-sort/sort_io.py
+++ /dev/null
@@ -1,255 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-from __future__ import print_function
-import sys
-import numpy as np
-import mxnet as mx
-
-# The interface of a data iter that works for bucketing
-#
-# DataIter
-# - default_bucket_key: the bucket key for the default symbol.
-#
-# DataBatch
-# - provide_data: same as DataIter, but specific to this batch
-# - provide_label: same as DataIter, but specific to this batch
-# - bucket_key: the key for the bucket that should be used for this batch
-
-def default_read_content(path):
-    with open(path) as ins:
-        content = ins.read()
-        content = content.replace('\n', ' <eos> ').replace('. ', ' <eos> ')
-        return content
-
-def default_build_vocab(path):
-    content = default_read_content(path)
-    content = content.split(' ')
-
-    words = set([x for x in content if len(x) > 0])
-    words = [x for x in words]
-    words = sorted(words)
-    the_vocab = {}
-    idx = 1 # 0 is left for zero-padding
-    the_vocab[' '] = 0 # put a dummy element here so that len(vocab) is correct
-    for word in words:
-        if len(word) == 0:
-            continue
-        if not word in the_vocab:
-            the_vocab[word] = idx
-            idx += 1
-    return the_vocab
-
-def default_text2id(sentence, the_vocab):
-    words = sentence.split(' ')
-    words = [the_vocab[w] for w in words if len(w) > 0]
-    return words
-
-def default_gen_buckets(sentences, batch_size, the_vocab):
-    len_dict = {}
-    max_len = -1
-    for sentence in sentences:
-        words = default_text2id(sentence, the_vocab)
-        lw = len(words)
-        if lw == 0:
-            continue
-        if lw > max_len:
-            max_len = lw
-        if lw in len_dict:
-            len_dict[lw] += 1
-        else:
-            len_dict[lw] = 1
-    print(len_dict)
-
-    tl = 0
-    buckets = []
-    for l, n in len_dict.items(): # TODO: There are better heuristic ways to do this
-        if n + tl >= batch_size:
-            buckets.append(l)
-            tl = 0
-        else:
-            tl += n
-    if tl > 0:
-        buckets.append(max_len)
-    return buckets
-
-
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label, bucket_key):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-        self.bucket_key = bucket_key
-
-        self.pad = 0
-        self.index = None # TODO: what is index?
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-class DummyIter(mx.io.DataIter):
-    "A dummy iterator that always return the same batch, used for speed testing"
-    def __init__(self, real_iter):
-        super(DummyIter, self).__init__()
-        self.real_iter = real_iter
-        self.provide_data = real_iter.provide_data
-        self.provide_label = real_iter.provide_label
-        self.batch_size = real_iter.batch_size
-
-        for batch in real_iter:
-            self.the_batch = batch
-            break
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        return self.the_batch
-
-class BucketSentenceIter(mx.io.DataIter):
-    def __init__(self, path, vocab, buckets, batch_size,
-                 init_states, data_name='data', label_name='label',
-                 seperate_char=' <eos> ', text2id=None, read_content=None):
-        super(BucketSentenceIter, self).__init__()
-
-        if text2id is None:
-            self.text2id = default_text2id
-        else:
-            self.text2id = text2id
-        if read_content is None:
-            self.read_content = default_read_content
-        else:
-            self.read_content = read_content
-        content = self.read_content(path)
-        sentences = content.split(seperate_char)
-
-        if len(buckets) == 0:
-            buckets = default_gen_buckets(sentences, batch_size, vocab)
-            print(buckets)
-        self.vocab_size = len(vocab)
-        self.data_name = data_name
-        self.label_name = label_name
-
-        buckets.sort()
-        self.buckets = buckets
-        self.data = [[] for _ in buckets]
-
-        # pre-allocate with the largest bucket for better memory sharing
-        self.default_bucket_key = max(buckets)
-
-        for sentence in sentences:
-            sentence = self.text2id(sentence, vocab)
-            if len(sentence) == 0:
-                continue
-            for i, bkt in enumerate(buckets):
-                if bkt >= len(sentence):
-                    self.data[i].append(sentence)
-                    break
-            # we just ignore the sentence it is longer than the maximum
-            # bucket size here
-
-        # convert data into ndarrays for better speed during training
-        data = [np.zeros((len(x), buckets[i])) for i, x in enumerate(self.data)]
-        for i_bucket in range(len(self.buckets)):
-            for j in range(len(self.data[i_bucket])):
-                sentence = self.data[i_bucket][j]
-                data[i_bucket][j, :len(sentence)] = sentence
-        self.data = data
-
-        # Get the size of each bucket, so that we could sample
-        # uniformly from the bucket
-        bucket_sizes = [len(x) for x in self.data]
-
-        print("Summary of dataset ==================")
-        for bkt, size in zip(buckets, bucket_sizes):
-            print("bucket of len %3d : %d samples" % (bkt, size))
-
-        self.batch_size = batch_size
-        self.make_data_iter_plan()
-
-        self.init_states = init_states
-        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
-
-        self.provide_data = [('data', (batch_size, self.default_bucket_key))] + init_states
-        self.provide_label = [('softmax_label', (self.batch_size, self.default_bucket_key))]
-
-    def make_data_iter_plan(self):
-        "make a random data iteration plan"
-        # truncate each bucket into multiple of batch-size
-        bucket_n_batches = []
-        for i in range(len(self.data)):
-            bucket_n_batches.append(len(self.data[i]) / self.batch_size)
-            self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)]
-
-        bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in enumerate(bucket_n_batches)])
-        np.random.shuffle(bucket_plan)
-
-        bucket_idx_all = [np.random.permutation(len(x)) for x in self.data]
-
-        self.bucket_plan = bucket_plan
-        self.bucket_idx_all = bucket_idx_all
-        self.bucket_curr_idx = [0 for x in self.data]
-
-        self.data_buffer = []
-        self.label_buffer = []
-        for i_bucket in range(len(self.data)):
-            data = np.zeros((self.batch_size, self.buckets[i_bucket]))
-            label = np.zeros((self.batch_size, self.buckets[i_bucket]))
-            self.data_buffer.append(data)
-            self.label_buffer.append(label)
-
-    def __iter__(self):
-        init_state_names = [x[0] for x in self.init_states]
-
-        for i_bucket in self.bucket_plan:
-            data = self.data_buffer[i_bucket]
-            label = self.label_buffer[i_bucket]
-
-            i_idx = self.bucket_curr_idx[i_bucket]
-            idx = self.bucket_idx_all[i_bucket][i_idx:i_idx+self.batch_size]
-            self.bucket_curr_idx[i_bucket] += self.batch_size
-            data[:] = self.data[i_bucket][idx]
-
-            for k in range(len(data)):
-                label[k] = sorted(data[k])
-                #count = len(data[k]) / 2
-                #for j in range(count):
-                #    data[j+count] = data[j]
-
-            #label[:, :-1] = data[:, 1:]
-            #label[:, -1] = 0
-
-            data_all = [mx.nd.array(data)] + self.init_state_arrays
-            label_all = [mx.nd.array(label)]
-            data_names = ['data'] + init_state_names
-            label_names = ['softmax_label']
-
-            data_batch = SimpleBatch(data_names, data_all, label_names, label_all,
-                                     self.buckets[i_bucket])
-
-            yield data_batch
-
-    def reset(self):
-        self.bucket_curr_idx = [0 for x in self.data]
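`BucketSentenceIter` prepares training batches in five steps:

1. Place each id sequence in the smallest bucket at least as long as the sequence, dropping longer ones.
2. Zero-pad it to that bucket's length.
3. Keep only a whole number of batches per bucket.
4. Shuffle the order in which buckets are visited.
5. Label every row with `sorted(row)`.

The plain-Python sketch below follows those steps but skips the within-bucket index permutation. `bucketize`, `batch_plan`, and `sort_labels` are hypothetical names. The sketch computes the batch count with integer division (`//`) so that it stays an int, whereas the removed code's `/` yields a float under Python 3.

```python
import random
from typing import List, Sequence


def bucketize(seqs: Sequence[Sequence[int]], buckets: Sequence[int]) -> List[List[List[int]]]:
    """Place each sequence in the smallest bucket >= its length, zero-padded."""
    sizes = sorted(buckets)
    data: List[List[List[int]]] = [[] for _ in sizes]
    for seq in seqs:
        if not seq:
            continue
        for i, size in enumerate(sizes):
            if size >= len(seq):
                data[i].append(list(seq) + [0] * (size - len(seq)))
                break
        # Sequences longer than the largest bucket are silently dropped.
    return data


def batch_plan(data: List[List[List[int]]], batch_size: int,
               rng: random.Random) -> List[int]:
    """Truncate each bucket to whole batches; return shuffled bucket ids, one per batch."""
    plan: List[int] = []
    for i in range(len(data)):
        n_batches = len(data[i]) // batch_size  # integer count of full batches
        data[i] = data[i][:n_batches * batch_size]
        plan.extend([i] * n_batches)
    rng.shuffle(plan)
    return plan


def sort_labels(batch: Sequence[Sequence[int]]) -> List[List[int]]:
    """The target for each row is the row itself in ascending id order."""
    return [sorted(row) for row in batch]


print(bucketize([[7, 3, 5], [4, 1], [9, 8, 2, 6]], buckets=[2, 3]))  # [[[4, 1]], [[7, 3, 5]]]
print(sort_labels([[7, 3, 5, 0]]))  # [[0, 3, 5, 7]]: padding zeros sort first
```

The labels are sorted vocabulary ids, not sorted numbers. `default_build_vocab` assigns ids in sorted string order, so id order matches numeric order only when all tokens have the same number of digits, as with the default range 100 to 999.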