Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/12 21:42:10 UTC
[incubator-mxnet] branch master updated: Add truncated bptt RNN example using symbol & module API (#9038)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 06ed4dc Add truncated bptt RNN example using symbol & module API (#9038)
06ed4dc is described below
commit 06ed4dcb1ab6aa1c506dbd3d7bb052b123f63589
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue Dec 12 13:42:05 2017 -0800
Add truncated bptt RNN example using symbol & module API (#9038)
* restructure folder
* initial commit
* update doc
* one more fix
* update default args and readme
* update data
---
example/rnn/README.md | 21 ++--
example/rnn/{ => bucketing}/README.md | 2 -
.../rnn/{ => bucketing}/cudnn_lstm_bucketing.py | 0
example/rnn/{ => bucketing}/get_ptb_data.sh | 0
example/rnn/{ => bucketing}/lstm_bucketing.py | 0
example/rnn/word_lm/README.md | 49 ++++++++
example/rnn/word_lm/data.py | 114 +++++++++++++++++
example/rnn/{ => word_lm}/get_ptb_data.sh | 13 +-
example/rnn/word_lm/model.py | 67 ++++++++++
example/rnn/word_lm/module.py | 134 ++++++++++++++++++++
example/rnn/word_lm/train.py | 136 +++++++++++++++++++++
11 files changed, 519 insertions(+), 17 deletions(-)
diff --git a/example/rnn/README.md b/example/rnn/README.md
index 8a6f29d..f0d80c3 100644
--- a/example/rnn/README.md
+++ b/example/rnn/README.md
@@ -1,15 +1,14 @@
-RNN Example
+Recurrent Neural Network Examples
===========
-This folder contains RNN examples using high level mxnet.rnn interface.
-Examples using low level symbol interface have been deprecated and moved to old/
+This directory contains functions for creating recurrent neural network
+models using the high-level mxnet.rnn interface.
-## Data
-Run `get_ptb_data.sh` to download PenTreeBank data.
+Here is a short overview of what is in this directory.
-## Python
-
-- [lstm_bucketing.py](lstm_bucketing.py) PennTreeBank language model by using LSTM
-
-Performance Note:
-More ```MXNET_GPU_WORKER_NTHREADS``` may lead to better performance. For setting ```MXNET_GPU_WORKER_NTHREADS```, please refer to [Environment Variables](https://mxnet.readthedocs.org/en/latest/how_to/env_var.html).
+Directory | What's in it?
+--- | ---
+`word_lm/` | Language model trained on the PTB dataset achieving state of the art performance
+`bucketing/` | Language model using the bucketing API in Python
+`bucket_R/` | Language model using the bucketing API in R
+`old/` | Language model trained with low level symbol interface (deprecated)
diff --git a/example/rnn/README.md b/example/rnn/bucketing/README.md
similarity index 85%
copy from example/rnn/README.md
copy to example/rnn/bucketing/README.md
index 8a6f29d..6baf1ec 100644
--- a/example/rnn/README.md
+++ b/example/rnn/bucketing/README.md
@@ -2,8 +2,6 @@ RNN Example
===========
This folder contains RNN examples using high level mxnet.rnn interface.
-Examples using low level symbol interface have been deprecated and moved to old/
-
## Data
Run `get_ptb_data.sh` to download PenTreeBank data.
diff --git a/example/rnn/cudnn_lstm_bucketing.py b/example/rnn/bucketing/cudnn_lstm_bucketing.py
similarity index 100%
rename from example/rnn/cudnn_lstm_bucketing.py
rename to example/rnn/bucketing/cudnn_lstm_bucketing.py
diff --git a/example/rnn/get_ptb_data.sh b/example/rnn/bucketing/get_ptb_data.sh
similarity index 100%
copy from example/rnn/get_ptb_data.sh
copy to example/rnn/bucketing/get_ptb_data.sh
diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/bucketing/lstm_bucketing.py
similarity index 100%
rename from example/rnn/lstm_bucketing.py
rename to example/rnn/bucketing/lstm_bucketing.py
diff --git a/example/rnn/word_lm/README.md b/example/rnn/word_lm/README.md
new file mode 100644
index 0000000..c498032
--- /dev/null
+++ b/example/rnn/word_lm/README.md
@@ -0,0 +1,49 @@
+Word Level Language Modeling
+===========
+This example trains a multi-layer LSTM on the Penn Treebank (PTB) language modeling benchmark.
+
+The following techniques have been adopted for SOTA results:
+- [LSTM for LM](https://arxiv.org/pdf/1409.2329.pdf)
+- [Weight tying](https://arxiv.org/abs/1608.05859) between word vectors and softmax output embeddings
+
+## Prerequisite
+The example requires MXNet built with CUDA.
+
+## Data
+The PTB data is the processed version from [(Mikolov et al, 2010)](http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf).
+
+## Usage
+Example runs and the results:
+
+```
+python train.py --tied --nhid 650 --emsize 650 --dropout 0.5 # Test ppl of 75.4
+```
+
+```
+usage: train.py [-h] [--data DATA] [--emsize EMSIZE] [--nhid NHID]
+                [--nlayers NLAYERS] [--lr LR] [--clip CLIP] [--epochs EPOCHS]
+                [--batch_size BATCH_SIZE] [--dropout DROPOUT] [--tied]
+                [--bptt BPTT] [--log-interval LOG_INTERVAL] [--seed SEED]
+
+PennTreeBank LSTM Language Model
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --data DATA           location of the data corpus
+  --emsize EMSIZE       size of word embeddings
+  --nhid NHID           number of hidden units per layer
+  --nlayers NLAYERS     number of layers
+  --lr LR               initial learning rate
+  --clip CLIP           gradient clipping by global norm
+  --epochs EPOCHS       upper epoch limit
+  --batch_size BATCH_SIZE
+                        batch size
+  --dropout DROPOUT     dropout applied to layers (0 = no dropout)
+  --tied                tie the word embedding and softmax weights
+  --bptt BPTT           sequence length
+  --log-interval LOG_INTERVAL
+                        report interval
+  --seed SEED           random seed
+```
+
+
diff --git a/example/rnn/word_lm/data.py b/example/rnn/word_lm/data.py
new file mode 100644
index 0000000..ff67088
--- /dev/null
+++ b/example/rnn/word_lm/data.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os, gzip
+import sys
+import mxnet as mx
+import numpy as np
+
+class Dictionary(object):
+    def __init__(self):
+        self.word2idx = {}
+        self.idx2word = []
+        self.word_count = []
+
+    def add_word(self, word):
+        if word not in self.word2idx:
+            self.idx2word.append(word)
+            self.word2idx[word] = len(self.idx2word) - 1
+            self.word_count.append(0)
+        index = self.word2idx[word]
+        self.word_count[index] += 1
+        return index
+
+    def __len__(self):
+        return len(self.idx2word)
+
+class Corpus(object):
+    def __init__(self, path):
+        self.dictionary = Dictionary()
+        self.train = self.tokenize(path + 'train.txt')
+        self.valid = self.tokenize(path + 'valid.txt')
+        self.test = self.tokenize(path + 'test.txt')
+
+    def tokenize(self, path):
+        """Tokenizes a text file."""
+        assert os.path.exists(path)
+        # Add words to the dictionary
+        with open(path, 'r') as f:
+            tokens = 0
+            for line in f:
+                words = line.split() + ['<eos>']
+                tokens += len(words)
+                for word in words:
+                    self.dictionary.add_word(word)
+
+        # Tokenize file content
+        with open(path, 'r') as f:
+            ids = np.zeros((tokens,), dtype='int32')
+            token = 0
+            for line in f:
+                words = line.split() + ['<eos>']
+                for word in words:
+                    ids[token] = self.dictionary.word2idx[word]
+                    token += 1
+
+        return mx.nd.array(ids, dtype='int32')
+
+def batchify(data, batch_size):
+    """Reshape data into (num_example, batch_size)"""
+    nbatch = data.shape[0] // batch_size
+    data = data[:nbatch * batch_size]
+    data = data.reshape((batch_size, nbatch)).T
+    return data
+
+class CorpusIter(mx.io.DataIter):
+    "An iterator that returns a batch of sequences each time"
+    def __init__(self, source, batch_size, bptt):
+        super(CorpusIter, self).__init__()
+        self.batch_size = batch_size
+        self.provide_data = [('data', (bptt, batch_size), np.int32)]
+        self.provide_label = [('label', (bptt, batch_size))]
+        self._index = 0
+        self._bptt = bptt
+        self._source = batchify(source, batch_size)
+
+    def iter_next(self):
+        i = self._index
+        if i+self._bptt > self._source.shape[0] - 1:
+            return False
+        self._next_data = self._source[i:i+self._bptt]
+        self._next_label = self._source[i+1:i+1+self._bptt].astype(np.float32)
+        self._index += self._bptt
+        return True
+
+    def next(self):
+        if self.iter_next():
+            return mx.io.DataBatch(data=self.getdata(), label=self.getlabel())
+        else:
+            raise StopIteration
+
+    def reset(self):
+        self._index = 0
+        self._next_data = None
+        self._next_label = None
+
+    def getdata(self):
+        return [self._next_data]
+
+    def getlabel(self):
+        return [self._next_label]
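
Note on data.py: the batching scheme above can be illustrated without MXNet. The following is a minimal pure-Python sketch, assuming plain lists in place of NDArrays; `bptt_windows` is a hypothetical helper name that is not part of this commit, and it only mirrors the index arithmetic of `batchify` and `CorpusIter.iter_next`.

```python
def batchify(tokens, batch_size):
    """Arrange a flat token list as rows of time steps, one column per stream."""
    nbatch = len(tokens) // batch_size
    tokens = tokens[:nbatch * batch_size]  # drop the remainder that does not fit
    # Column b holds the b-th contiguous chunk of the corpus.
    columns = [tokens[b * nbatch:(b + 1) * nbatch] for b in range(batch_size)]
    # Transpose so that each row is one time step across all columns.
    return [[col[t] for col in columns] for t in range(nbatch)]


def bptt_windows(source, bptt):
    """Yield (data, label) windows; labels are the data shifted by one step."""
    i = 0
    # Same stopping rule as iter_next: one extra row is needed for the labels.
    while i + bptt <= len(source) - 1:
        yield source[i:i + bptt], source[i + 1:i + 1 + bptt]
        i += bptt


source = batchify(list(range(26)), batch_size=4)
windows = list(bptt_windows(source, bptt=2))
print(windows[0])
```

Because each column is a contiguous stretch of the corpus, consecutive windows continue the same text stream, which is what makes it meaningful to carry the LSTM state from one batch to the next.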
diff --git a/example/rnn/get_ptb_data.sh b/example/rnn/word_lm/get_ptb_data.sh
similarity index 59%
rename from example/rnn/get_ptb_data.sh
rename to example/rnn/word_lm/get_ptb_data.sh
index d2641cb..0a0c705 100755
--- a/example/rnn/get_ptb_data.sh
+++ b/example/rnn/word_lm/get_ptb_data.sh
@@ -17,6 +17,11 @@
# specific language governing permissions and limitations
# under the License.
+echo ""
+echo "NOTE: Please review the licensing of the datasets in this script before proceeding"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+echo "Once that is done, please uncomment the wget commands in this script"
+echo ""
RNN_DIR=$(cd `dirname $0`; pwd)
DATA_DIR="${RNN_DIR}/data/"
@@ -26,7 +31,7 @@ if [[ ! -d "${DATA_DIR}" ]]; then
mkdir -p ${DATA_DIR}
fi
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
diff --git a/example/rnn/word_lm/model.py b/example/rnn/word_lm/model.py
new file mode 100644
index 0000000..aa3710a
--- /dev/null
+++ b/example/rnn/word_lm/model.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+def rnn(bptt, vocab_size, num_embed, nhid,
+        num_layers, dropout, batch_size, tied):
+    # encoder
+    data = mx.sym.Variable('data')
+    weight = mx.sym.var("encoder_weight", init=mx.init.Uniform(0.1))
+    embed = mx.sym.Embedding(data=data, weight=weight, input_dim=vocab_size,
+                             output_dim=num_embed, name='embed')
+
+    # stacked rnn layers
+    states = []
+    state_names = []
+    outputs = mx.sym.Dropout(embed, p=dropout)
+    for i in range(num_layers):
+        prefix = 'lstm_l%d_' % i
+        cell = mx.rnn.FusedRNNCell(num_hidden=nhid, prefix=prefix, get_next_state=True,
+                                   forget_bias=0.0, dropout=dropout)
+        state_shape = (1, batch_size, nhid)
+        begin_cell_state_name = prefix + 'cell'
+        begin_hidden_state_name = prefix + 'hidden'
+        begin_cell_state = mx.sym.var(begin_cell_state_name, shape=state_shape)
+        begin_hidden_state = mx.sym.var(begin_hidden_state_name, shape=state_shape)
+        state_names += [begin_cell_state_name, begin_hidden_state_name]
+        outputs, next_states = cell.unroll(bptt, inputs=outputs,
+                                           begin_state=[begin_cell_state, begin_hidden_state],
+                                           merge_outputs=True, layout='TNC')
+        outputs = mx.sym.Dropout(outputs, p=dropout)
+        states += next_states
+
+    # decoder
+    pred = mx.sym.Reshape(outputs, shape=(-1, nhid))
+    if tied:
+        assert(nhid == num_embed), \
+               "the number of hidden units and the embedding size must match for weight tying"
+        pred = mx.sym.FullyConnected(data=pred, weight=weight,
+                                     num_hidden=vocab_size, name='pred')
+    else:
+        pred = mx.sym.FullyConnected(data=pred, num_hidden=vocab_size, name='pred')
+    pred = mx.sym.Reshape(pred, shape=(-1, vocab_size))
+    return pred, [mx.sym.stop_gradient(s) for s in states], state_names
+
+def softmax_ce_loss(pred):
+    # softmax cross-entropy loss
+    label = mx.sym.Variable('label')
+    label = mx.sym.Reshape(label, shape=(-1,))
+    logits = mx.sym.log_softmax(pred, axis=-1)
+    loss = -mx.sym.pick(logits, label, axis=-1, keepdims=True)
+    loss = mx.sym.mean(loss, axis=0, exclude=True)
+    return mx.sym.make_loss(loss, name='nll')
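
Note on model.py: `softmax_ce_loss` takes the log-softmax of each row of predictions and picks the entry at the label index, which gives one negative log-likelihood value per token. The sketch below reproduces that arithmetic in pure Python under the assumption of small list inputs; `nll_per_token` is a hypothetical name used only for illustration and is not part of this commit.

```python
import math


def nll_per_token(logits, labels):
    """Negative log-likelihood of each label under a softmax over its row."""
    losses = []
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the row maximum for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        # log_softmax(row)[y] is row[y] - log_z; the loss is its negation.
        losses.append(log_z - row[y])
    return losses


# Uniform logits over a 4-word vocabulary give a loss of log(4) per token.
losses = nll_per_token([[0.0, 0.0, 0.0, 0.0], [2.0, 0.0]], [2, 0])
print(losses)
```

Summing these per-token values and dividing by the number of tokens gives the average loss that train.py exponentiates to report perplexity.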
diff --git a/example/rnn/word_lm/module.py b/example/rnn/word_lm/module.py
new file mode 100644
index 0000000..864700c
--- /dev/null
+++ b/example/rnn/word_lm/module.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import logging
+
+class CustomStatefulModule():
+    """CustomStatefulModule is a module that takes a custom loss symbol and state symbols.
+    The custom loss is typically composed by `mx.sym.make_loss` or `mx.sym.MakeLoss`.
+    The states listed in `state_names` will be carried between iterations.
+
+    Parameters
+    ----------
+    loss : Symbol
+        The custom loss symbol
+    states: list of Symbol
+        The symbols of next states
+    state_names : list of str
+        states are similar to data and label, but not provided by data iterator.
+        Instead they are initialized to `initial_states` and can be carried between iterations.
+    data_names : list of str
+        Defaults to `('data',)`, matching the data name provided by `CorpusIter`.
+    label_names : list of str
+        Defaults to `('label',)`, matching the label name provided by
+        `CorpusIter`.
+    logger : Logger
+        Defaults to `logging`.
+    context : Context or list of Context
+        Defaults to ``mx.cpu()``.
+    initial_states: float or list of NDArray
+        Defaults to 0.0.
+    """
+    def __init__(self, loss, states, state_names, data_names=('data',), label_names=('label',),
+                 context=mx.cpu(), initial_states=0.0, **kwargs):
+        if isinstance(states, mx.symbol.Symbol):
+            states = [states]
+        self._net = mx.sym.Group(states + [loss])
+        self._next_states = initial_states
+        self._module = mx.module.Module(self._net, data_names=data_names, label_names=label_names,
+                                        context=context, state_names=state_names, **kwargs)
+
+    def backward(self, out_grads=None):
+        """Backward computation.
+        """
+        self._module.backward(out_grads=out_grads)
+
+    def init_params(self, initializer=mx.init.Uniform(0.01), **kwargs):
+        """Initializes the parameters and auxiliary states.
+        """
+        self._module.init_params(initializer=initializer, **kwargs)
+
+    def init_optimizer(self, **kwargs):
+        """Installs and initializes optimizers, as well as initialize kvstore for
+        distributed training.
+        """
+        self._module.init_optimizer(**kwargs)
+
+    def bind(self, data_shapes, **kwargs):
+        """Binds the symbols to construct executors. This is necessary before one
+        can perform computation with the module.
+        """
+        self._module.bind(data_shapes, **kwargs)
+
+    def forward(self, data_batch, is_train=None, carry_state=True):
+        """Forward computation. States from previous forward computation are carried
+        to the current iteration if `carry_state` is set to `True`.
+        """
+        # propagate states from the previous iteration
+        if carry_state:
+            if isinstance(self._next_states, (int, float)):
+                self._module.set_states(value=self._next_states)
+            else:
+                self._module.set_states(states=self._next_states)
+        self._module.forward(data_batch, is_train=is_train)
+        outputs = self._module.get_outputs(merge_multi_context=False)
+        self._next_states = outputs[:-1]
+
+    def update(self, max_norm=None):
+        """Updates parameters according to the installed optimizer and the gradients computed
+        in the previous forward-backward batch. Gradients are clipped by their global norm
+        if `max_norm` is set.
+
+        Parameters
+        ----------
+        max_norm: float, optional
+            If set, rescale all gradients so that their global norm does not exceed this value.
+        """
+        if max_norm is not None:
+            self._clip_by_global_norm(max_norm)
+        self._module.update()
+
+    def _clip_by_global_norm(self, max_norm):
+        """Clips gradient norm.
+
+        The norm is computed over all gradients together, as if they were
+        concatenated into a single vector. Gradients are modified in-place.
+        The method is first used in
+        `[ICML2013] On the difficulty of training recurrent neural networks`
+
+        Parameters
+        ----------
+        max_norm : float or int
+            The maximum clipping threshold of the gradient norm.
+
+        Returns
+        -------
+        norm_val : float
+            The computed norm of the gradients.
+        """
+        assert self._module.binded and self._module.params_initialized \
+               and self._module.optimizer_initialized
+        grad_array = []
+        for grad in self._module._exec_group.grad_arrays:
+            grad_array += grad
+        return mx.gluon.utils.clip_global_norm(grad_array, max_norm)
+
+    def get_loss(self):
+        """Gets the output loss of the previous forward computation.
+        """
+        return self._module.get_outputs(merge_multi_context=False)[-1]
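
Note on module.py: `_clip_by_global_norm` treats all gradients as one concatenated vector and rescales them together when the joint norm exceeds `max_norm`. The sketch below shows that idea in pure Python, assuming gradients are plain lists of floats; it is an illustration of the technique, not the MXNet implementation, and the small epsilon in the denominator is an assumption added to avoid division by zero.

```python
import math


def clip_global_norm(grads, max_norm, eps=1e-8):
    """Rescale all gradients in place so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:  # only shrink, never enlarge
        for grad in grads:
            for k in range(len(grad)):
                grad[k] *= scale
    return total_norm  # norm measured before clipping


grads = [[3.0], [4.0]]           # joint norm is 5.0
norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in grads for g in grad))
print(norm_before, norm_after)
```

Scaling every gradient by the same factor preserves the update direction, which is the difference from clipping each value independently.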
diff --git a/example/rnn/word_lm/train.py b/example/rnn/word_lm/train.py
new file mode 100644
index 0000000..53b6bd3
--- /dev/null
+++ b/example/rnn/word_lm/train.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+import argparse, math
+import logging
+from data import Corpus, CorpusIter
+from model import *
+from module import *
+from mxnet.model import BatchEndParam
+
+parser = argparse.ArgumentParser(description='PennTreeBank LSTM Language Model')
+parser.add_argument('--data', type=str, default='./data/ptb.',
+                    help='location of the data corpus')
+parser.add_argument('--emsize', type=int, default=650,
+                    help='size of word embeddings')
+parser.add_argument('--nhid', type=int, default=650,
+                    help='number of hidden units per layer')
+parser.add_argument('--nlayers', type=int, default=2,
+                    help='number of layers')
+parser.add_argument('--lr', type=float, default=1.0,
+                    help='initial learning rate')
+parser.add_argument('--clip', type=float, default=0.2,
+                    help='gradient clipping by global norm')
+parser.add_argument('--epochs', type=int, default=40,
+                    help='upper epoch limit')
+parser.add_argument('--batch_size', type=int, default=32,
+                    help='batch size')
+parser.add_argument('--dropout', type=float, default=0.5,
+                    help='dropout applied to layers (0 = no dropout)')
+parser.add_argument('--tied', action='store_true',
+                    help='tie the word embedding and softmax weights')
+parser.add_argument('--bptt', type=int, default=35,
+                    help='sequence length')
+parser.add_argument('--log-interval', type=int, default=200,
+                    help='report interval')
+parser.add_argument('--seed', type=int, default=3,
+                    help='random seed')
+args = parser.parse_args()
+
+best_loss = 9999
+
+def evaluate(valid_module, data_iter, epoch, mode, bptt, batch_size):
+    total_loss = 0.0
+    nbatch = 0
+    for batch in data_iter:
+        valid_module.forward(batch, is_train=False)
+        outputs = valid_module.get_loss()
+        total_loss += mx.nd.sum(outputs[0]).asscalar()
+        nbatch += 1
+    data_iter.reset()
+    loss = total_loss / bptt / batch_size / nbatch
+    logging.info('Iter[%d] %s loss:\t%.7f, Perplexity: %.7f' % \
+                 (epoch, mode, loss, math.exp(loss)))
+    return loss
+
+if __name__ == '__main__':
+    # args
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.DEBUG, format=head)
+    args = parser.parse_args()
+    logging.info(args)
+    ctx = mx.gpu()
+    batch_size = args.batch_size
+    bptt = args.bptt
+    mx.random.seed(args.seed)
+
+    # data
+    corpus = Corpus(args.data)
+    ntokens = len(corpus.dictionary)
+    train_data = CorpusIter(corpus.train, batch_size, bptt)
+    valid_data = CorpusIter(corpus.valid, batch_size, bptt)
+    test_data = CorpusIter(corpus.test, batch_size, bptt)
+
+    # model
+    pred, states, state_names = rnn(bptt, ntokens, args.emsize, args.nhid,
+                                    args.nlayers, args.dropout, batch_size, args.tied)
+    loss = softmax_ce_loss(pred)
+
+    # module
+    module = CustomStatefulModule(loss, states, state_names=state_names, context=ctx)
+    module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    module.init_params(initializer=mx.init.Xavier())
+    optimizer = mx.optimizer.create('sgd', learning_rate=args.lr, rescale_grad=1.0/batch_size)
+    module.init_optimizer(optimizer=optimizer)
+
+    # metric
+    speedometer = mx.callback.Speedometer(batch_size, args.log_interval)
+
+    # train
+    logging.info("Training started ... ")
+    for epoch in range(args.epochs):
+        # train
+        total_loss = 0.0
+        nbatch = 0
+        for batch in train_data:
+            module.forward(batch)
+            module.backward()
+            module.update(max_norm=args.clip * bptt * batch_size)
+            # update metric
+            outputs = module.get_loss()
+            total_loss += mx.nd.sum(outputs[0]).asscalar()
+            speedometer_param = BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                              eval_metric=None, locals=locals())
+            speedometer(speedometer_param)
+            if nbatch % args.log_interval == 0 and nbatch > 0:
+                cur_loss = total_loss / bptt / batch_size / args.log_interval
+                logging.info('Iter[%d] Batch [%d]\tLoss: %.7f,\tPerplexity:\t%.7f' % \
+                             (epoch, nbatch, cur_loss, math.exp(cur_loss)))
+                total_loss = 0.0
+            nbatch += 1
+        # validation
+        valid_loss = evaluate(module, valid_data, epoch, 'Valid', bptt, batch_size)
+        if valid_loss < best_loss:
+            best_loss = valid_loss
+            # test
+            test_loss = evaluate(module, test_data, epoch, 'Test', bptt, batch_size)
+        else:
+            optimizer.lr *= 0.25
+        train_data.reset()
+    logging.info("Training completed. ")
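
Note on train.py: the reported perplexity is the exponential of the summed loss divided by the number of tokens seen (bptt * batch_size * nbatch), and the learning rate is multiplied by 0.25 whenever the validation loss fails to improve. The sketch below isolates those two rules in pure Python; the function names `perplexity` and `anneal` are hypothetical and are not defined in this commit.

```python
import math


def perplexity(total_loss, bptt, batch_size, nbatch):
    """Exponential of the average per-token loss."""
    return math.exp(total_loss / (bptt * batch_size * nbatch))


def anneal(lr, best_loss, valid_loss, factor=0.25):
    """Keep lr when validation improves, otherwise shrink it by `factor`."""
    if valid_loss < best_loss:
        return lr, valid_loss
    return lr * factor, best_loss


# A summed loss equal to log(75.4) per token corresponds to a perplexity of 75.4.
total = 35 * 32 * 10 * math.log(75.4)
ppl = perplexity(total, bptt=35, batch_size=32, nbatch=10)
worse = anneal(1.0, best_loss=4.5, valid_loss=4.6)
better = anneal(1.0, best_loss=4.5, valid_loss=4.4)
print(ppl, worse, better)
```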