Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/12 21:42:10 UTC

[incubator-mxnet] branch master updated: Add truncated bptt RNN example using symbol & module API (#9038)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 06ed4dc  Add truncated bptt RNN example using symbol & module API (#9038)
06ed4dc is described below

commit 06ed4dcb1ab6aa1c506dbd3d7bb052b123f63589
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue Dec 12 13:42:05 2017 -0800

    Add truncated bptt RNN example using symbol & module API (#9038)
    
    * restructure folder
    
    * initial commit
    
    * update doc
    
    * one more fix
    
    * update default args and readme
    
    * update data
---
 example/rnn/README.md                              |  21 ++--
 example/rnn/{ => bucketing}/README.md              |   2 -
 .../rnn/{ => bucketing}/cudnn_lstm_bucketing.py    |   0
 example/rnn/{ => bucketing}/get_ptb_data.sh        |   0
 example/rnn/{ => bucketing}/lstm_bucketing.py      |   0
 example/rnn/word_lm/README.md                      |  49 ++++++++
 example/rnn/word_lm/data.py                        | 114 +++++++++++++++++
 example/rnn/{ => word_lm}/get_ptb_data.sh          |  13 +-
 example/rnn/word_lm/model.py                       |  67 ++++++++++
 example/rnn/word_lm/module.py                      | 134 ++++++++++++++++++++
 example/rnn/word_lm/train.py                       | 136 +++++++++++++++++++++
 11 files changed, 519 insertions(+), 17 deletions(-)

diff --git a/example/rnn/README.md b/example/rnn/README.md
index 8a6f29d..f0d80c3 100644
--- a/example/rnn/README.md
+++ b/example/rnn/README.md
@@ -1,15 +1,14 @@
-RNN Example
+Recurrent Neural Network Examples
 ===========
-This folder contains RNN examples using high level mxnet.rnn interface.
 
-Examples using low level symbol interface have been deprecated and moved to old/
+This directory contains functions for creating recurrent neural network
+models using the high-level mxnet.rnn interface.
 
-## Data
-Run `get_ptb_data.sh` to download PenTreeBank data.
+Here is a short overview of what is in this directory.
 
-## Python
-
-- [lstm_bucketing.py](lstm_bucketing.py) PennTreeBank language model by using LSTM
-
-Performance Note:
-More ```MXNET_GPU_WORKER_NTHREADS``` may lead to better performance. For setting ```MXNET_GPU_WORKER_NTHREADS```, please refer to [Environment Variables](https://mxnet.readthedocs.org/en/latest/how_to/env_var.html).
+Directory | What's in it?
+--- | ---
+`word_lm/` | Language model trained on the PTB dataset, achieving state-of-the-art performance
+`bucketing/` | Language model using the bucketing API in Python
+`bucket_R/` | Language model using the bucketing API in R
+`old/` | Language model trained with the low level symbol interface (deprecated)
diff --git a/example/rnn/README.md b/example/rnn/bucketing/README.md
similarity index 85%
copy from example/rnn/README.md
copy to example/rnn/bucketing/README.md
index 8a6f29d..6baf1ec 100644
--- a/example/rnn/README.md
+++ b/example/rnn/bucketing/README.md
@@ -2,8 +2,6 @@ RNN Example
 ===========
 This folder contains RNN examples using high level mxnet.rnn interface.
 
-Examples using low level symbol interface have been deprecated and moved to old/
-
 ## Data
 Run `get_ptb_data.sh` to download the PennTreeBank data.
 
diff --git a/example/rnn/cudnn_lstm_bucketing.py b/example/rnn/bucketing/cudnn_lstm_bucketing.py
similarity index 100%
rename from example/rnn/cudnn_lstm_bucketing.py
rename to example/rnn/bucketing/cudnn_lstm_bucketing.py
diff --git a/example/rnn/get_ptb_data.sh b/example/rnn/bucketing/get_ptb_data.sh
similarity index 100%
copy from example/rnn/get_ptb_data.sh
copy to example/rnn/bucketing/get_ptb_data.sh
diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/bucketing/lstm_bucketing.py
similarity index 100%
rename from example/rnn/lstm_bucketing.py
rename to example/rnn/bucketing/lstm_bucketing.py
diff --git a/example/rnn/word_lm/README.md b/example/rnn/word_lm/README.md
new file mode 100644
index 0000000..c498032
--- /dev/null
+++ b/example/rnn/word_lm/README.md
@@ -0,0 +1,49 @@
+Word Level Language Modeling
+===========
+This example trains a multi-layer LSTM on the Penn Treebank (PTB) language modeling benchmark.
+
+The following techniques have been adopted for SOTA results:
+- [LSTM for LM](https://arxiv.org/pdf/1409.2329.pdf)
+- [Weight tying](https://arxiv.org/abs/1608.05859) between word vectors and softmax output embeddings (see the sketch below)
+
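+A minimal sketch of how weight tying can be expressed with the symbol API (this mirrors
+`model.py`; `vocab_size` and `num_embed` below are placeholder values, and tying requires
+the LSTM hidden size to equal `num_embed`):
+
+```
+import mxnet as mx
+
+vocab_size, num_embed = 10000, 650   # placeholder sizes for illustration
+
+data = mx.sym.Variable('data')
+# create the embedding weight explicitly so it can be shared with the decoder
+weight = mx.sym.var('encoder_weight', init=mx.init.Uniform(0.1))
+embed = mx.sym.Embedding(data=data, weight=weight, input_dim=vocab_size,
+                         output_dim=num_embed, name='embed')
+outputs = embed  # stand-in for the stacked LSTM outputs in this sketch
+# the decoder reuses the same weight matrix, tying input and output embeddings
+pred = mx.sym.Reshape(outputs, shape=(-1, num_embed))
+pred = mx.sym.FullyConnected(data=pred, weight=weight, num_hidden=vocab_size, name='pred')
+```
+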
+## Prerequisites
+This example requires MXNet built with CUDA support (training runs on `mx.gpu()`).
+
+## Data
+The PTB data is the processed version from [(Mikolov et al, 2010)](http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf).
+Run `get_ptb_data.sh` to download it into `./data/` (review the dataset license and uncomment the `wget` commands in the script first).
+
+## Usage
+Example runs and the results:
+
+```
+python train.py --tied --nhid 650 --emsize 650 --dropout 0.5        # Test ppl of 75.4
+```
+
+```
+usage: train.py [-h] [--data DATA] [--emsize EMSIZE] [--nhid NHID]
+                [--nlayers NLAYERS] [--lr LR] [--clip CLIP] [--epochs EPOCHS]
+                [--batch_size BATCH_SIZE] [--dropout DROPOUT] [--tied]
+                [--bptt BPTT] [--log-interval LOG_INTERVAL] [--seed SEED]
+
+PennTreeBank LSTM Language Model
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --data DATA           location of the data corpus
+  --emsize EMSIZE       size of word embeddings
+  --nhid NHID           number of hidden units per layer
+  --nlayers NLAYERS     number of layers
+  --lr LR               initial learning rate
+  --clip CLIP           gradient clipping by global norm
+  --epochs EPOCHS       upper epoch limit
+  --batch_size BATCH_SIZE
+                        batch size
+  --dropout DROPOUT     dropout applied to layers (0 = no dropout)
+  --tied                tie the word embedding and softmax weights
+  --bptt BPTT           sequence length
+  --log-interval LOG_INTERVAL
+                        report interval
+  --seed SEED           random seed
+```
+
+
diff --git a/example/rnn/word_lm/data.py b/example/rnn/word_lm/data.py
new file mode 100644
index 0000000..ff67088
--- /dev/null
+++ b/example/rnn/word_lm/data.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import mxnet as mx
+import numpy as np
+
+class Dictionary(object):
+    def __init__(self):
+        self.word2idx = {}
+        self.idx2word = []
+        self.word_count = []
+
+    def add_word(self, word):
+        if word not in self.word2idx:
+            self.idx2word.append(word)
+            self.word2idx[word] = len(self.idx2word) - 1
+            self.word_count.append(0)
+        index = self.word2idx[word]
+        self.word_count[index] += 1
+        return index
+
+    def __len__(self):
+        return len(self.idx2word)
+
+class Corpus(object):
+    def __init__(self, path):
+        self.dictionary = Dictionary()
+        self.train = self.tokenize(path + 'train.txt')
+        self.valid = self.tokenize(path + 'valid.txt')
+        self.test = self.tokenize(path + 'test.txt')
+
+    def tokenize(self, path):
+        """Tokenizes a text file."""
+        assert os.path.exists(path)
+        # Add words to the dictionary
+        with open(path, 'r') as f:
+            tokens = 0
+            for line in f:
+                words = line.split() + ['<eos>']
+                tokens += len(words)
+                for word in words:
+                    self.dictionary.add_word(word)
+
+        # Tokenize file content
+        with open(path, 'r') as f:
+            ids = np.zeros((tokens,), dtype='int32')
+            token = 0
+            for line in f:
+                words = line.split() + ['<eos>']
+                for word in words:
+                    ids[token] = self.dictionary.word2idx[word]
+                    token += 1
+
+        return mx.nd.array(ids, dtype='int32')
+
+def batchify(data, batch_size):
+    """Reshape data into (num_example, batch_size)"""
+    nbatch = data.shape[0] // batch_size
+    data = data[:nbatch * batch_size]
+    data = data.reshape((batch_size, nbatch)).T
+    return data
+
+class CorpusIter(mx.io.DataIter):
+    "An iterator that returns the a batch of sequence each time"
+    def __init__(self, source, batch_size, bptt):
+        super(CorpusIter, self).__init__()
+        self.batch_size = batch_size
+        self.provide_data = [('data', (bptt, batch_size), np.int32)]
+        self.provide_label = [('label', (bptt, batch_size))]
+        self._index = 0
+        self._bptt = bptt
+        self._source = batchify(source, batch_size)
+
+    def iter_next(self):
+        i = self._index
+        if i+self._bptt > self._source.shape[0] - 1:
+            return False
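+        # the label is the data window shifted forward by one token (next-word prediction)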
+        self._next_data = self._source[i:i+self._bptt]
+        self._next_label = self._source[i+1:i+1+self._bptt].astype(np.float32)
+        self._index += self._bptt
+        return True
+
+    def next(self):
+        if self.iter_next():
+            return mx.io.DataBatch(data=self.getdata(), label=self.getlabel())
+        else:
+            raise StopIteration
+
+    def reset(self):
+        self._index = 0
+        self._next_data = None
+        self._next_label = None
+
+    def getdata(self):
+        return [self._next_data]
+
+    def getlabel(self):
+        return [self._next_label]
diff --git a/example/rnn/get_ptb_data.sh b/example/rnn/word_lm/get_ptb_data.sh
similarity index 59%
rename from example/rnn/get_ptb_data.sh
rename to example/rnn/word_lm/get_ptb_data.sh
index d2641cb..0a0c705 100755
--- a/example/rnn/get_ptb_data.sh
+++ b/example/rnn/word_lm/get_ptb_data.sh
@@ -17,6 +17,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
+echo ""
+echo "NOTE: Please review the licensing of the datasets in this script before proceeding"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+echo "Once that is done, please uncomment the wget commands in this script"
+echo ""
 
 RNN_DIR=$(cd `dirname $0`; pwd)
 DATA_DIR="${RNN_DIR}/data/"
@@ -26,7 +31,7 @@ if [[ ! -d "${DATA_DIR}" ]]; then
   mkdir -p ${DATA_DIR}
 fi
 
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
diff --git a/example/rnn/word_lm/model.py b/example/rnn/word_lm/model.py
new file mode 100644
index 0000000..aa3710a
--- /dev/null
+++ b/example/rnn/word_lm/model.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+def rnn(bptt, vocab_size, num_embed, nhid,
+        num_layers, dropout, batch_size, tied):
+    # encoder
+    data = mx.sym.Variable('data')
+    weight = mx.sym.var("encoder_weight", init=mx.init.Uniform(0.1))
+    embed = mx.sym.Embedding(data=data, weight=weight, input_dim=vocab_size,
+                             output_dim=num_embed, name='embed')
+
+    # stacked rnn layers
+    states = []
+    state_names = []
+    outputs = mx.sym.Dropout(embed, p=dropout)
+    for i in range(num_layers):
+        prefix = 'lstm_l%d_' % i
+        cell = mx.rnn.FusedRNNCell(num_hidden=nhid, prefix=prefix, get_next_state=True,
+                                   forget_bias=0.0, dropout=dropout)
+        state_shape = (1, batch_size, nhid)
+        begin_cell_state_name = prefix + 'cell'
+        begin_hidden_state_name = prefix + 'hidden'
+        begin_cell_state = mx.sym.var(begin_cell_state_name, shape=state_shape)
+        begin_hidden_state = mx.sym.var(begin_hidden_state_name, shape=state_shape)
+        state_names += [begin_cell_state_name, begin_hidden_state_name]
+        outputs, next_states = cell.unroll(bptt, inputs=outputs,
+                                           begin_state=[begin_cell_state, begin_hidden_state],
+                                           merge_outputs=True, layout='TNC')
+        outputs = mx.sym.Dropout(outputs, p=dropout)
+        states += next_states
+
+    # decoder
+    pred = mx.sym.Reshape(outputs, shape=(-1, nhid))
+    if tied:
+        assert(nhid == num_embed), \
+               "the number of hidden units and the embedding size must batch for weight tying"
+        pred = mx.sym.FullyConnected(data=pred, weight=weight,
+                                     num_hidden=vocab_size, name='pred')
+    else:
+        pred = mx.sym.FullyConnected(data=pred, num_hidden=vocab_size, name='pred')
+    pred = mx.sym.Reshape(pred, shape=(-1, vocab_size))
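+    # detach the carried states from the graph so gradients are not propagated
+    # across truncated BPTT segments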
+    return pred, [mx.sym.stop_gradient(s) for s in states], state_names
+
+def softmax_ce_loss(pred):
+    # softmax cross-entropy loss
+    label = mx.sym.Variable('label')
+    label = mx.sym.Reshape(label, shape=(-1,))
+    logits = mx.sym.log_softmax(pred, axis=-1)
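+    # pick the log-probability of the target word: per-token negative log-likelihood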
+    loss = -mx.sym.pick(logits, label, axis=-1, keepdims=True)
+    loss = mx.sym.mean(loss, axis=0, exclude=True)
+    return mx.sym.make_loss(loss, name='nll')
diff --git a/example/rnn/word_lm/module.py b/example/rnn/word_lm/module.py
new file mode 100644
index 0000000..864700c
--- /dev/null
+++ b/example/rnn/word_lm/module.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import logging
+
+class CustomStatefulModule(object):
+    """CustomStatefulModule is a module that takes a custom loss symbol and state symbols.
+    The custom loss is typically composed by `mx.sym.make_loss` or `mx.sym.MakeLoss`.
+    The states listed in `state_names` will be carried between iterations.
+
+    Parameters
+    ----------
+    loss : Symbol
+        The custom loss symbol
+    states: list of Symbol
+        The symbols of next states
+    state_names : list of str
+        States are similar to data and label, but are not provided by the data iterator.
+        Instead they are initialized to `initial_states` and can be carried between iterations.
+    data_names : list of str
+        Defaults to `('data',)`.
+    label_names : list of str
+        Defaults to `('label',)`.
+    logger : Logger
+        Defaults to `logging`.
+    context : Context or list of Context
+        Defaults to ``mx.cpu()``.
+    initial_states: float or list of NDArray
+        Defaults to 0.0.
+    """
+    def __init__(self, loss, states, state_names, data_names=('data',), label_names=('label',),
+                 context=mx.cpu(), initial_states=0.0, **kwargs):
+        if isinstance(states, mx.symbol.Symbol):
+            states = [states]
+        self._net = mx.sym.Group(states + [loss])
+        self._next_states = initial_states
+        self._module = mx.module.Module(self._net, data_names=data_names, label_names=label_names,
+                                        context=context, state_names=state_names, **kwargs)
+
+    def backward(self, out_grads=None):
+        """Backward computation.
+        """
+        self._module.backward(out_grads=out_grads)
+
+    def init_params(self, initializer=mx.init.Uniform(0.01), **kwargs):
+        """Initializes the parameters and auxiliary states.
+        """
+        self._module.init_params(initializer=initializer, **kwargs)
+
+    def init_optimizer(self, **kwargs):
+        """Installs and initializes optimizers, as well as initialize kvstore for
+           distributed training.
+        """
+        self._module.init_optimizer(**kwargs)
+
+    def bind(self, data_shapes, **kwargs):
+        """Binds the symbols to construct executors. This is necessary before one
+        can perform computation with the module.
+        """
+        self._module.bind(data_shapes, **kwargs)
+
+    def forward(self, data_batch, is_train=None, carry_state=True):
+        """Forward computation. States from previous forward computation are carried
+        to the current iteration if `carry_state` is set to `True`.
+        """
+        # propagate states from the previous iteration
+        if carry_state:
+            if isinstance(self._next_states, (int, float)):
+                self._module.set_states(value=self._next_states)
+            else:
+                self._module.set_states(states=self._next_states)
+        self._module.forward(data_batch, is_train=is_train)
+        outputs = self._module.get_outputs(merge_multi_context=False)
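+        # the last output is the loss; the preceding outputs are the RNN states
+        # carried into the next forward call (self._net groups states + [loss])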
+        self._next_states = outputs[:-1]
+
+    def update(self, max_norm=None):
+        """Updates parameters according to the installed optimizer and the gradients computed
+        in the previous forward-backward batch. Gradients are clipped by their global norm
+        if `max_norm` is set.
+
+        Parameters
+        ----------
+        max_norm: float, optional
+            If set, gradients are rescaled so that their global norm does not exceed `max_norm`.
+        """
+        if max_norm is not None:
+            self._clip_by_global_norm(max_norm)
+        self._module.update()
+
+    def _clip_by_global_norm(self, max_norm):
+        """Clips gradient norm.
+
+        The norm is computed over all gradients together, as if they were
+        concatenated into a single vector. Gradients are modified in-place.
+        The method was first introduced in
+        `[ICML2013] On the difficulty of training recurrent neural networks`.
+
+        Parameters
+        ----------
+        max_norm : float or int
+            The maximum clipping threshold of the gradient norm.
+
+        Returns
+        -------
+        norm_val : float
+            The computed norm of the gradients.
+        """
+        assert self._module.binded and self._module.params_initialized \
+               and self._module.optimizer_initialized
+        grad_array = []
+        for grad in self._module._exec_group.grad_arrays:
+            grad_array += grad
+        return mx.gluon.utils.clip_global_norm(grad_array, max_norm)
+
+    def get_loss(self):
+        """Gets the output loss of the previous forward computation.
+        """
+        return self._module.get_outputs(merge_multi_context=False)[-1]
diff --git a/example/rnn/word_lm/train.py b/example/rnn/word_lm/train.py
new file mode 100644
index 0000000..53b6bd3
--- /dev/null
+++ b/example/rnn/word_lm/train.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import math
+import argparse
+import mxnet as mx
+import logging
+from data import Corpus, CorpusIter
+from model import *
+from module import *
+from mxnet.model import BatchEndParam
+
+parser = argparse.ArgumentParser(description='PennTreeBank LSTM Language Model')
+parser.add_argument('--data', type=str, default='./data/ptb.',
+                    help='location of the data corpus')
+parser.add_argument('--emsize', type=int, default=650,
+                    help='size of word embeddings')
+parser.add_argument('--nhid', type=int, default=650,
+                    help='number of hidden units per layer')
+parser.add_argument('--nlayers', type=int, default=2,
+                    help='number of layers')
+parser.add_argument('--lr', type=float, default=1.0,
+                    help='initial learning rate')
+parser.add_argument('--clip', type=float, default=0.2,
+                    help='gradient clipping by global norm')
+parser.add_argument('--epochs', type=int, default=40,
+                    help='upper epoch limit')
+parser.add_argument('--batch_size', type=int, default=32,
+                    help='batch size')
+parser.add_argument('--dropout', type=float, default=0.5,
+                    help='dropout applied to layers (0 = no dropout)')
+parser.add_argument('--tied', action='store_true',
+                    help='tie the word embedding and softmax weights')
+parser.add_argument('--bptt', type=int, default=35,
+                    help='sequence length')
+parser.add_argument('--log-interval', type=int, default=200,
+                    help='report interval')
+parser.add_argument('--seed', type=int, default=3,
+                    help='random seed')
+args = parser.parse_args()
+
+best_loss = 9999
+
+def evaluate(valid_module, data_iter, epoch, mode, bptt, batch_size):
+    total_loss = 0.0
+    nbatch = 0
+    for batch in data_iter:
+        valid_module.forward(batch, is_train=False)
+        outputs = valid_module.get_loss()
+        total_loss += mx.nd.sum(outputs[0]).asscalar()
+        nbatch += 1
+    data_iter.reset()
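+    # average negative log-likelihood per token; perplexity is its exponential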
+    loss = total_loss / bptt / batch_size / nbatch
+    logging.info('Iter[%d] %s loss:\t%.7f, Perplexity: %.7f' % \
+                 (epoch, mode, loss, math.exp(loss)))
+    return loss
+
+if __name__ == '__main__':
+    # args
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.DEBUG, format=head)
+    logging.info(args)
+    ctx = mx.gpu()
+    batch_size = args.batch_size
+    bptt = args.bptt
+    mx.random.seed(args.seed)
+
+    # data
+    corpus = Corpus(args.data)
+    ntokens = len(corpus.dictionary)
+    train_data = CorpusIter(corpus.train, batch_size, bptt)
+    valid_data = CorpusIter(corpus.valid, batch_size, bptt)
+    test_data = CorpusIter(corpus.test, batch_size, bptt)
+
+    # model
+    pred, states, state_names = rnn(bptt, ntokens, args.emsize, args.nhid,
+                                    args.nlayers, args.dropout, batch_size, args.tied)
+    loss = softmax_ce_loss(pred)
+
+    # module
+    module = CustomStatefulModule(loss, states, state_names=state_names, context=ctx)
+    module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    module.init_params(initializer=mx.init.Xavier())
+    optimizer = mx.optimizer.create('sgd', learning_rate=args.lr, rescale_grad=1.0/batch_size)
+    module.init_optimizer(optimizer=optimizer)
+
+    # metric
+    speedometer = mx.callback.Speedometer(batch_size, args.log_interval)
+
+    # train
+    logging.info("Training started ... ")
+    for epoch in range(args.epochs):
+        # train
+        total_loss = 0.0
+        nbatch = 0
+        for batch in train_data:
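+            # each batch covers one bptt window; hidden states are carried over from
+            # the previous window (carry_state=True by default) while gradients are
+            # truncated at the window boundary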
+            module.forward(batch)
+            module.backward()
+            module.update(max_norm=args.clip * bptt * batch_size)
+            # update metric
+            outputs = module.get_loss()
+            total_loss += mx.nd.sum(outputs[0]).asscalar()
+            speedometer_param = BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                              eval_metric=None, locals=locals())
+            speedometer(speedometer_param)
+            if nbatch % args.log_interval == 0 and nbatch > 0:
+                cur_loss = total_loss / bptt / batch_size / args.log_interval
+                logging.info('Iter[%d] Batch [%d]\tLoss:  %.7f,\tPerplexity:\t%.7f' % \
+                             (epoch, nbatch, cur_loss, math.exp(cur_loss)))
+                total_loss = 0.0
+            nbatch += 1
+        # validation
+        valid_loss = evaluate(module, valid_data, epoch, 'Valid', bptt, batch_size)
+        if valid_loss < best_loss:
+            best_loss = valid_loss
+            # test
+            test_loss = evaluate(module, test_data, epoch, 'Test', bptt, batch_size)
+        else:
+            optimizer.lr *= 0.25
+        train_data.reset()
+    logging.info("Training completed. ")

-- 
To stop receiving notification emails like this one, please contact commits@mxnet.apache.org.