Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/12 21:42:10 UTC

[GitHub] piiswrong closed pull request #9038: Add truncated bptt RNN example using symbol & module API

piiswrong closed pull request #9038: Add truncated bptt RNN example using symbol & module API
URL: https://github.com/apache/incubator-mxnet/pull/9038
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/example/rnn/README.md b/example/rnn/README.md
index 8a6f29d2c0..f0d80c3a61 100644
--- a/example/rnn/README.md
+++ b/example/rnn/README.md
@@ -1,15 +1,14 @@
-RNN Example
+Recurrent Neural Network Examples
 ===========
-This folder contains RNN examples using high level mxnet.rnn interface.
 
-Examples using low level symbol interface have been deprecated and moved to old/
+This directory contains functions for creating recurrent neural network
+models using the high-level mxnet.rnn interface.
 
-## Data
-Run `get_ptb_data.sh` to download PenTreeBank data.
+Here is a short overview of what is in this directory.
 
-## Python
-
-- [lstm_bucketing.py](lstm_bucketing.py) PennTreeBank language model by using LSTM
-
-Performance Note:
-More ```MXNET_GPU_WORKER_NTHREADS``` may lead to better performance. For setting ```MXNET_GPU_WORKER_NTHREADS```, please refer to [Environment Variables](https://mxnet.readthedocs.org/en/latest/how_to/env_var.html).
+Directory | What's in it?
+--- | ---
+`word_lm/` | Language model trained on the PTB dataset, achieving state-of-the-art performance
+`bucketing/` | Language model using the bucketing API in Python
+`bucket_R/` | Language model using the bucketing API in R
+`old/` | Language model trained with the low-level symbol interface (deprecated)
diff --git a/example/rnn/bucketing/README.md b/example/rnn/bucketing/README.md
new file mode 100644
index 0000000000..6baf1ecae9
--- /dev/null
+++ b/example/rnn/bucketing/README.md
@@ -0,0 +1,13 @@
+RNN Example
+===========
+This folder contains RNN examples using the high-level mxnet.rnn interface.
+
+## Data
+Run `get_ptb_data.sh` to download the Penn Treebank data.
+
+## Python
+
+- [lstm_bucketing.py](lstm_bucketing.py) Penn Treebank language model using an LSTM
+
+Performance Note:
+A larger ```MXNET_GPU_WORKER_NTHREADS``` value may lead to better performance. To set ```MXNET_GPU_WORKER_NTHREADS```, please refer to [Environment Variables](https://mxnet.readthedocs.org/en/latest/how_to/env_var.html).
diff --git a/example/rnn/cudnn_lstm_bucketing.py b/example/rnn/bucketing/cudnn_lstm_bucketing.py
similarity index 100%
rename from example/rnn/cudnn_lstm_bucketing.py
rename to example/rnn/bucketing/cudnn_lstm_bucketing.py
diff --git a/example/rnn/get_ptb_data.sh b/example/rnn/bucketing/get_ptb_data.sh
similarity index 100%
rename from example/rnn/get_ptb_data.sh
rename to example/rnn/bucketing/get_ptb_data.sh
diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/bucketing/lstm_bucketing.py
similarity index 100%
rename from example/rnn/lstm_bucketing.py
rename to example/rnn/bucketing/lstm_bucketing.py
diff --git a/example/rnn/word_lm/README.md b/example/rnn/word_lm/README.md
new file mode 100644
index 0000000000..c4980326e4
--- /dev/null
+++ b/example/rnn/word_lm/README.md
@@ -0,0 +1,49 @@
+Word Level Language Modeling
+===========
+This example trains a multi-layer LSTM on the Penn Treebank (PTB) language modeling benchmark.
+
+The following techniques have been adopted for SOTA results:
+- [LSTM for LM](https://arxiv.org/pdf/1409.2329.pdf)
+- [Weight tying](https://arxiv.org/abs/1608.05859) between word vectors and softmax output embeddings
+
+## Prerequisite
+The example requires MXNet built with CUDA.
+
+## Data
+The PTB data is the processed version from [(Mikolov et al., 2010)](http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf).
+
+## Usage
+Example runs and the results:
+
+```
+python train.py --tied --nhid 650 --emsize 650 --dropout 0.5        # Test ppl of 75.4
+```
+
+```
+usage: train.py [-h] [--data DATA] [--emsize EMSIZE] [--nhid NHID]
+                [--nlayers NLAYERS] [--lr LR] [--clip CLIP] [--epochs EPOCHS]
+                [--batch_size BATCH_SIZE] [--dropout DROPOUT] [--tied]
+                [--bptt BPTT] [--log-interval LOG_INTERVAL] [--seed SEED]
+
+PennTreeBank LSTM Language Model
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --data DATA           location of the data corpus
+  --emsize EMSIZE       size of word embeddings
+  --nhid NHID           number of hidden units per layer
+  --nlayers NLAYERS     number of layers
+  --lr LR               initial learning rate
+  --clip CLIP           gradient clipping by global norm
+  --epochs EPOCHS       upper epoch limit
+  --batch_size BATCH_SIZE
+                        batch size
+  --dropout DROPOUT     dropout applied to layers (0 = no dropout)
+  --tied                tie the word embedding and softmax weights
+  --bptt BPTT           sequence length
+  --log-interval LOG_INTERVAL
+                        report interval
+  --seed SEED           random seed
+```
+
+
diff --git a/example/rnn/word_lm/data.py b/example/rnn/word_lm/data.py
new file mode 100644
index 0000000000..ff67088c78
--- /dev/null
+++ b/example/rnn/word_lm/data.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os, gzip
+import sys
+import mxnet as mx
+import numpy as np
+
+class Dictionary(object):
+    def __init__(self):
+        self.word2idx = {}
+        self.idx2word = []
+        self.word_count = []
+
+    def add_word(self, word):
+        if word not in self.word2idx:
+            self.idx2word.append(word)
+            self.word2idx[word] = len(self.idx2word) - 1
+            self.word_count.append(0)
+        index = self.word2idx[word]
+        self.word_count[index] += 1
+        return index
+
+    def __len__(self):
+        return len(self.idx2word)
+
+class Corpus(object):
+    def __init__(self, path):
+        self.dictionary = Dictionary()
+        self.train = self.tokenize(path + 'train.txt')
+        self.valid = self.tokenize(path + 'valid.txt')
+        self.test = self.tokenize(path + 'test.txt')
+
+    def tokenize(self, path):
+        """Tokenizes a text file."""
+        assert os.path.exists(path)
+        # Add words to the dictionary
+        with open(path, 'r') as f:
+            tokens = 0
+            for line in f:
+                words = line.split() + ['<eos>']
+                tokens += len(words)
+                for word in words:
+                    self.dictionary.add_word(word)
+
+        # Tokenize file content
+        with open(path, 'r') as f:
+            ids = np.zeros((tokens,), dtype='int32')
+            token = 0
+            for line in f:
+                words = line.split() + ['<eos>']
+                for word in words:
+                    ids[token] = self.dictionary.word2idx[word]
+                    token += 1
+
+        return mx.nd.array(ids, dtype='int32')
+
+def batchify(data, batch_size):
+    """Reshape data into (num_example, batch_size)"""
+    nbatch = data.shape[0] // batch_size
+    data = data[:nbatch * batch_size]
+    data = data.reshape((batch_size, nbatch)).T
+    return data
+
+class CorpusIter(mx.io.DataIter):
+    "An iterator that returns a batch of sequences each time"
+    def __init__(self, source, batch_size, bptt):
+        super(CorpusIter, self).__init__()
+        self.batch_size = batch_size
+        self.provide_data = [('data', (bptt, batch_size), np.int32)]
+        self.provide_label = [('label', (bptt, batch_size))]
+        self._index = 0
+        self._bptt = bptt
+        self._source = batchify(source, batch_size)
+
+    def iter_next(self):
+        i = self._index
+        if i+self._bptt > self._source.shape[0] - 1:
+            return False
+        self._next_data = self._source[i:i+self._bptt]
+        self._next_label = self._source[i+1:i+1+self._bptt].astype(np.float32)
+        self._index += self._bptt
+        return True
+
+    def next(self):
+        if self.iter_next():
+            return mx.io.DataBatch(data=self.getdata(), label=self.getlabel())
+        else:
+            raise StopIteration
+
+    def reset(self):
+        self._index = 0
+        self._next_data = None
+        self._next_label = None
+
+    def getdata(self):
+        return [self._next_data]
+
+    def getlabel(self):
+        return [self._next_label]
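
Note on data.py: the layout produced by `batchify` and the windowing done in `CorpusIter.iter_next` may be easier to follow on a toy input. The sketch below re-creates both steps with plain Python lists instead of NDArrays. The helper name `bptt_windows` is hypothetical and not part of this PR; the stopping condition mirrors the one in `iter_next`.

```python
def batchify(tokens, batch_size):
    """Lay a flat token list out as rows of shape (nbatch, batch_size).

    Column j holds the j-th contiguous chunk of the corpus, which matches
    reshape((batch_size, nbatch)).T in data.py.
    """
    nbatch = len(tokens) // batch_size
    return [[tokens[col * nbatch + row] for col in range(batch_size)]
            for row in range(nbatch)]


def bptt_windows(rows, bptt):
    """Yield (data, label) windows; label is data shifted by one time step."""
    i = 0
    # Same condition as CorpusIter.iter_next: stop when no full window
    # plus one look-ahead row remains.
    while i + bptt <= len(rows) - 1:
        yield rows[i:i + bptt], rows[i + 1:i + 1 + bptt]
        i += bptt


tokens = list(range(10))          # stand-in for word ids
rows = batchify(tokens, batch_size=2)
windows = list(bptt_windows(rows, bptt=2))
for data, label in windows:
    print(data, label)
```

Because consecutive windows continue where the previous one stopped within each column, hidden states carried across batches line up with the same text stream.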
diff --git a/example/rnn/word_lm/get_ptb_data.sh b/example/rnn/word_lm/get_ptb_data.sh
new file mode 100755
index 0000000000..0a0c7051b0
--- /dev/null
+++ b/example/rnn/word_lm/get_ptb_data.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+echo ""
+echo "NOTE: Please review the licensing of the datasets in this script before proceeding"
+echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
+echo "Once that is done, please uncomment the wget commands in this script"
+echo ""
+
+RNN_DIR=$(cd `dirname $0`; pwd)
+DATA_DIR="${RNN_DIR}/data/"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+  echo "${DATA_DIR} doesn't exist, will create one";
+  mkdir -p ${DATA_DIR}
+fi
+
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
+#wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
diff --git a/example/rnn/word_lm/model.py b/example/rnn/word_lm/model.py
new file mode 100644
index 0000000000..aa3710a3b0
--- /dev/null
+++ b/example/rnn/word_lm/model.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+def rnn(bptt, vocab_size, num_embed, nhid,
+        num_layers, dropout, batch_size, tied):
+    # encoder
+    data = mx.sym.Variable('data')
+    weight = mx.sym.var("encoder_weight", init=mx.init.Uniform(0.1))
+    embed = mx.sym.Embedding(data=data, weight=weight, input_dim=vocab_size,
+                             output_dim=num_embed, name='embed')
+
+    # stacked rnn layers
+    states = []
+    state_names = []
+    outputs = mx.sym.Dropout(embed, p=dropout)
+    for i in range(num_layers):
+        prefix = 'lstm_l%d_' % i
+        cell = mx.rnn.FusedRNNCell(num_hidden=nhid, prefix=prefix, get_next_state=True,
+                                   forget_bias=0.0, dropout=dropout)
+        state_shape = (1, batch_size, nhid)
+        begin_cell_state_name = prefix + 'cell'
+        begin_hidden_state_name = prefix + 'hidden'
+        begin_cell_state = mx.sym.var(begin_cell_state_name, shape=state_shape)
+        begin_hidden_state = mx.sym.var(begin_hidden_state_name, shape=state_shape)
+        state_names += [begin_cell_state_name, begin_hidden_state_name]
+        outputs, next_states = cell.unroll(bptt, inputs=outputs,
+                                           begin_state=[begin_cell_state, begin_hidden_state],
+                                           merge_outputs=True, layout='TNC')
+        outputs = mx.sym.Dropout(outputs, p=dropout)
+        states += next_states
+
+    # decoder
+    pred = mx.sym.Reshape(outputs, shape=(-1, nhid))
+    if tied:
+        assert(nhid == num_embed), \
+               "the number of hidden units and the embedding size must match for weight tying"
+        pred = mx.sym.FullyConnected(data=pred, weight=weight,
+                                     num_hidden=vocab_size, name='pred')
+    else:
+        pred = mx.sym.FullyConnected(data=pred, num_hidden=vocab_size, name='pred')
+    pred = mx.sym.Reshape(pred, shape=(-1, vocab_size))
+    return pred, [mx.sym.stop_gradient(s) for s in states], state_names
+
+def softmax_ce_loss(pred):
+    # softmax cross-entropy loss
+    label = mx.sym.Variable('label')
+    label = mx.sym.Reshape(label, shape=(-1,))
+    logits = mx.sym.log_softmax(pred, axis=-1)
+    loss = -mx.sym.pick(logits, label, axis=-1, keepdims=True)
+    loss = mx.sym.mean(loss, axis=0, exclude=True)
+    return mx.sym.make_loss(loss, name='nll')
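
Note on model.py: `softmax_ce_loss` builds a per-token negative log-likelihood from a log-softmax followed by picking the entry at the label index. The pure-Python sketch below walks through that arithmetic on a toy input. The function names `log_softmax`, `token_nll`, and `perplexity` are illustrative and are not MXNet calls; perplexity is taken as the exponential of the mean per-token loss, which is how train.py reports it.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_nll(logits, label):
    """Negative log-likelihood of the correct class for one token."""
    return -log_softmax(logits)[label]


def perplexity(losses):
    """Exponential of the mean per-token loss."""
    return math.exp(sum(losses) / len(losses))


# With uniform logits over 4 classes, every token costs log(4),
# so the perplexity should come out close to 4.
uniform = [0.0, 0.0, 0.0, 0.0]
losses = [token_nll(uniform, 2), token_nll(uniform, 0)]
print(perplexity(losses))
```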
diff --git a/example/rnn/word_lm/module.py b/example/rnn/word_lm/module.py
new file mode 100644
index 0000000000..864700c104
--- /dev/null
+++ b/example/rnn/word_lm/module.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import logging
+
+class CustomStatefulModule():
+    """CustomStatefulModule is a module that takes a custom loss symbol and state symbols.
+    The custom loss is typically composed by `mx.sym.make_loss` or `mx.sym.MakeLoss`.
+    The states listed in `state_names` will be carried between iterations.
+
+    Parameters
+    ----------
+    loss : Symbol
+        The custom loss symbol
+    states: list of Symbol
+        The symbols of next states
+    state_names : list of str
+        states are similar to data and label, but not provided by data iterator.
+        Instead they are initialized to `initial_states` and can be carried between iterations.
+    data_names : list of str
+        Defaults to `('data',)`.
+    label_names : list of str
+        Defaults to `('label',)`, the name of the label variable used by
+        the loss symbol.
+    logger : Logger
+        Defaults to `logging`.
+    context : Context or list of Context
+        Defaults to ``mx.cpu()``.
+    initial_states: float or list of NDArray
+        Defaults to 0.0.
+    """
+    def __init__(self, loss, states, state_names, data_names=('data',), label_names=('label',),
+                 context=mx.cpu(), initial_states=0.0, **kwargs):
+        if isinstance(states, mx.symbol.Symbol):
+            states = [states]
+        self._net = mx.sym.Group(states + [loss])
+        self._next_states = initial_states
+        self._module = mx.module.Module(self._net, data_names=data_names, label_names=label_names,
+                                        context=context, state_names=state_names, **kwargs)
+
+    def backward(self, out_grads=None):
+        """Backward computation.
+        """
+        self._module.backward(out_grads=out_grads)
+
+    def init_params(self, initializer=mx.init.Uniform(0.01), **kwargs):
+        """Initializes the parameters and auxiliary states.
+        """
+        self._module.init_params(initializer=initializer, **kwargs)
+
+    def init_optimizer(self, **kwargs):
+        """Installs and initializes optimizers, and initializes the kvstore for
+           distributed training.
+        """
+        self._module.init_optimizer(**kwargs)
+
+    def bind(self, data_shapes, **kwargs):
+        """Binds the symbols to construct executors. This is necessary before one
+        can perform computation with the module.
+        """
+        self._module.bind(data_shapes, **kwargs)
+
+    def forward(self, data_batch, is_train=None, carry_state=True):
+        """Forward computation. States from previous forward computation are carried
+        to the current iteration if `carry_state` is set to `True`.
+        """
+        # propagate states from the previous iteration
+        if carry_state:
+            if isinstance(self._next_states, (int, float)):
+                self._module.set_states(value=self._next_states)
+            else:
+                self._module.set_states(states=self._next_states)
+        self._module.forward(data_batch, is_train=is_train)
+        outputs = self._module.get_outputs(merge_multi_context=False)
+        self._next_states = outputs[:-1]
+
+    def update(self, max_norm=None):
+        """Updates parameters according to the installed optimizer and the gradients computed
+        in the previous forward-backward batch. Gradients are clipped by their global norm
+        if `max_norm` is set.
+
+        Parameters
+        ----------
+        max_norm: float, optional
+            If set, rescale all gradients so that their global norm does not exceed this value.
+        """
+        if max_norm is not None:
+            self._clip_by_global_norm(max_norm)
+        self._module.update()
+
+    def _clip_by_global_norm(self, max_norm):
+        """Clips gradient norm.
+
+        The norm is computed over all gradients together, as if they were
+        concatenated into a single vector. Gradients are modified in-place.
+        The method is first used in
+         `[ICML2013] On the difficulty of training recurrent neural networks`
+
+        Parameters
+        ----------
+        max_norm : float or int
+            The maximum clipping threshold of the gradient norm.
+
+        Returns
+        -------
+        norm_val : float
+            The computed norm of the gradients.
+        """
+        assert self._module.binded and self._module.params_initialized \
+               and self._module.optimizer_initialized
+        grad_array = []
+        for grad in self._module._exec_group.grad_arrays:
+            grad_array += grad
+        return mx.gluon.utils.clip_global_norm(grad_array, max_norm)
+
+    def get_loss(self):
+        """Gets the output loss of the previous forward computation.
+        """
+        return self._module.get_outputs(merge_multi_context=False)[-1]
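
Note on module.py: `_clip_by_global_norm` delegates to `mx.gluon.utils.clip_global_norm`. For readers unfamiliar with the technique, here is a minimal sketch of global-norm clipping on plain Python lists. It is an illustration under stated assumptions, not the library's implementation: the function name `clip_by_global_norm` and the `eps` guard value are choices made for this example.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-8):
    """Rescale all gradients in place so their joint L2 norm is at most max_norm.

    grads is a list of flat gradient lists; the norm is computed as if they
    were concatenated into one vector. Returns the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = max_norm / (total_norm + eps)  # eps (assumed) avoids division by zero
    if scale < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total_norm


grads = [[3.0], [4.0]]                     # joint norm is 5
norm = clip_by_global_norm(grads, max_norm=2.5)
print(norm, grads)
```

In train.py the threshold passed in is `args.clip * bptt * batch_size`, so the command-line value is scaled by the number of tokens per batch before clipping.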
diff --git a/example/rnn/word_lm/train.py b/example/rnn/word_lm/train.py
new file mode 100644
index 0000000000..53b6bd35f2
--- /dev/null
+++ b/example/rnn/word_lm/train.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+import argparse, math
+import logging
+from data import Corpus, CorpusIter
+from model import *
+from module import *
+from mxnet.model import BatchEndParam
+
+parser = argparse.ArgumentParser(description='PennTreeBank LSTM Language Model')
+parser.add_argument('--data', type=str, default='./data/ptb.',
+                    help='location of the data corpus')
+parser.add_argument('--emsize', type=int, default=650,
+                    help='size of word embeddings')
+parser.add_argument('--nhid', type=int, default=650,
+                    help='number of hidden units per layer')
+parser.add_argument('--nlayers', type=int, default=2,
+                    help='number of layers')
+parser.add_argument('--lr', type=float, default=1.0,
+                    help='initial learning rate')
+parser.add_argument('--clip', type=float, default=0.2,
+                    help='gradient clipping by global norm')
+parser.add_argument('--epochs', type=int, default=40,
+                    help='upper epoch limit')
+parser.add_argument('--batch_size', type=int, default=32,
+                    help='batch size')
+parser.add_argument('--dropout', type=float, default=0.5,
+                    help='dropout applied to layers (0 = no dropout)')
+parser.add_argument('--tied', action='store_true',
+                    help='tie the word embedding and softmax weights')
+parser.add_argument('--bptt', type=int, default=35,
+                    help='sequence length')
+parser.add_argument('--log-interval', type=int, default=200,
+                    help='report interval')
+parser.add_argument('--seed', type=int, default=3,
+                    help='random seed')
+args = parser.parse_args()
+
+best_loss = 9999
+
+def evaluate(valid_module, data_iter, epoch, mode, bptt, batch_size):
+    total_loss = 0.0
+    nbatch = 0
+    for batch in data_iter:
+        valid_module.forward(batch, is_train=False)
+        outputs = valid_module.get_loss()
+        total_loss += mx.nd.sum(outputs[0]).asscalar()
+        nbatch += 1
+    data_iter.reset()
+    loss = total_loss / bptt / batch_size / nbatch
+    logging.info('Iter[%d] %s loss:\t%.7f, Perplexity: %.7f' % \
+                 (epoch, mode, loss, math.exp(loss)))
+    return loss
+
+if __name__ == '__main__':
+    # args
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.DEBUG, format=head)
+    args = parser.parse_args()
+    logging.info(args)
+    ctx = mx.gpu()
+    batch_size = args.batch_size
+    bptt = args.bptt
+    mx.random.seed(args.seed)
+
+    # data
+    corpus = Corpus(args.data)
+    ntokens = len(corpus.dictionary)
+    train_data = CorpusIter(corpus.train, batch_size, bptt)
+    valid_data = CorpusIter(corpus.valid, batch_size, bptt)
+    test_data = CorpusIter(corpus.test, batch_size, bptt)
+
+    # model
+    pred, states, state_names = rnn(bptt, ntokens, args.emsize, args.nhid,
+                                    args.nlayers, args.dropout, batch_size, args.tied)
+    loss = softmax_ce_loss(pred)
+
+    # module
+    module = CustomStatefulModule(loss, states, state_names=state_names, context=ctx)
+    module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    module.init_params(initializer=mx.init.Xavier())
+    optimizer = mx.optimizer.create('sgd', learning_rate=args.lr, rescale_grad=1.0/batch_size)
+    module.init_optimizer(optimizer=optimizer)
+
+    # metric
+    speedometer = mx.callback.Speedometer(batch_size, args.log_interval)
+
+    # train
+    logging.info("Training started ... ")
+    for epoch in range(args.epochs):
+        # train
+        total_loss = 0.0
+        nbatch = 0
+        for batch in train_data:
+            module.forward(batch)
+            module.backward()
+            module.update(max_norm=args.clip * bptt * batch_size)
+            # update metric
+            outputs = module.get_loss()
+            total_loss += mx.nd.sum(outputs[0]).asscalar()
+            speedometer_param = BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                              eval_metric=None, locals=locals())
+            speedometer(speedometer_param)
+            if nbatch % args.log_interval == 0 and nbatch > 0:
+                cur_loss = total_loss / bptt / batch_size / args.log_interval
+                logging.info('Iter[%d] Batch [%d]\tLoss:  %.7f,\tPerplexity:\t%.7f' % \
+                             (epoch, nbatch, cur_loss, math.exp(cur_loss)))
+                total_loss = 0.0
+            nbatch += 1
+        # validation
+        valid_loss = evaluate(module, valid_data, epoch, 'Valid', bptt, batch_size)
+        if valid_loss < best_loss:
+            best_loss = valid_loss
+            # test
+            test_loss = evaluate(module, test_data, epoch, 'Test', bptt, batch_size)
+        else:
+            optimizer.lr *= 0.25
+        train_data.reset()
+    logging.info("Training completed. ")
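
Note on train.py: the training loop keeps the learning rate while the validation loss improves and multiplies it by 0.25 otherwise. The sketch below isolates that schedule so it can be checked on a made-up loss sequence. The function name `anneal` and the sample losses are hypothetical; the sketch also starts from infinity rather than the script's `best_loss = 9999`.

```python
def anneal(valid_losses, lr=1.0, factor=0.25):
    """Return the learning rate in effect after each epoch's validation step."""
    best = float('inf')
    history = []
    for loss in valid_losses:
        if loss < best:
            best = loss          # improvement: keep the current learning rate
        else:
            lr *= factor         # no improvement: decay the learning rate
        history.append(lr)
    return history


# Hypothetical validation losses over five epochs.
history = anneal([5.0, 4.5, 4.6, 4.4, 4.5])
print(history)
```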


 
