You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/20 18:37:35 UTC

[incubator-mxnet] branch nlp_toolkit updated: word language model end-to-end example (AWD/RNNModel, with fix) (#11)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch nlp_toolkit
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/nlp_toolkit by this push:
     new 6cf1560  word language model end-to-end example (AWD/RNNModel, with fix) (#11)
6cf1560 is described below

commit 6cf15606aabfa186ee5d5b6118730e69d7e522da
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Tue Mar 20 11:36:52 2018 -0700

    word language model end-to-end example (AWD/RNNModel, with fix) (#11)
    
    * word language model zoo
    
    * update names
    
    * fix
    
    * fix weight drop
    
    * fix one-off error in dataset
---
 .../train.py => word_language_model.py}            | 134 +++++++-----
 example/gluon/word_language_model/README.md        |  67 ------
 example/gluon/word_language_model/get_ptb_data.sh  |  43 ----
 example/gluon/word_language_model/model.py         |  64 ------
 python/mxnet/gluon/contrib/nn/basic_layers.py      |   4 +-
 python/mxnet/gluon/data/text/base.py               |   8 +-
 python/mxnet/gluon/data/text/utils.py              |  25 ++-
 python/mxnet/gluon/model_zoo/__init__.py           |   2 +
 python/mxnet/gluon/model_zoo/text/__init__.py      |  76 +++++++
 python/mxnet/gluon/model_zoo/text/base.py          | 238 +++++++++++++++++++++
 python/mxnet/gluon/model_zoo/text/lm.py            | 124 +++++++++++
 tests/python/unittest/test_gluon_data_text.py      |  38 +++-
 tests/python/unittest/test_gluon_model_zoo.py      |  24 +--
 13 files changed, 592 insertions(+), 255 deletions(-)

diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model.py
similarity index 59%
rename from example/gluon/word_language_model/train.py
rename to example/gluon/word_language_model.py
index c732393..eb33f2a 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model.py
@@ -20,35 +20,41 @@ import collections
 import time
 import math
 import mxnet as mx
-from mxnet import gluon, autograd, contrib
-from mxnet.gluon import data
-import model
+from mxnet import gluon, autograd
+from mxnet.gluon import data, text
+from mxnet.gluon.model_zoo.text.lm import RNNModel, AWDLSTM
 
 parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
                     help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
-parser.add_argument('--emsize', type=int, default=200,
+parser.add_argument('--emsize', type=int, default=400,
                     help='size of word embeddings')
-parser.add_argument('--nhid', type=int, default=200,
+parser.add_argument('--nhid', type=int, default=1150,
                     help='number of hidden units per layer')
-parser.add_argument('--nlayers', type=int, default=2,
+parser.add_argument('--nlayers', type=int, default=3,
                     help='number of layers')
-parser.add_argument('--lr', type=float, default=1.0,
+parser.add_argument('--lr', type=float, default=30,
                     help='initial learning rate')
-parser.add_argument('--clip', type=float, default=0.2,
+parser.add_argument('--clip', type=float, default=0.25,
                     help='gradient clipping')
-parser.add_argument('--epochs', type=int, default=40,
+parser.add_argument('--epochs', type=int, default=750,
                     help='upper epoch limit')
-parser.add_argument('--batch_size', type=int, default=32, metavar='N',
+parser.add_argument('--batch_size', type=int, default=80, metavar='N',
                     help='batch size')
 parser.add_argument('--bptt', type=int, default=35,
                     help='sequence length')
-parser.add_argument('--dropout', type=float, default=0.2,
+parser.add_argument('--dropout', type=float, default=0.4,
                     help='dropout applied to layers (0 = no dropout)')
+parser.add_argument('--dropout_h', type=float, default=0.3,
+                    help='dropout applied to hidden layer (0 = no dropout)')
+parser.add_argument('--dropout_i', type=float, default=0.4,
+                    help='dropout applied to input layer (0 = no dropout)')
+parser.add_argument('--dropout_e', type=float, default=0.1,
+                    help='dropout applied to embedding layer (0 = no dropout)')
+parser.add_argument('--weight_dropout', type=float, default=0.65,
+                    help='weight dropout applied to h2h weight matrix (0 = no weight dropout)')
 parser.add_argument('--tied', action='store_true',
                     help='tie the word embedding and softmax weights')
-parser.add_argument('--cuda', action='store_true',
-                    help='Whether to use gpu')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.params',
@@ -58,6 +64,10 @@ parser.add_argument('--gctype', type=str, default='none',
                           takes `2bit` or `none` for now.')
 parser.add_argument('--gcthreshold', type=float, default=0.5,
                     help='threshold for 2bit gradient compression')
+parser.add_argument('--eval_only', action='store_true',
+                    help='Whether to only evaluate the trained model')
+parser.add_argument('--gpus', type=str,
+                    help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. (the result of multi-gpu training might be slightly different compared to single-gpu training, still need to be finalized)')
 args = parser.parse_args()
 
 
@@ -65,23 +75,20 @@ args = parser.parse_args()
 # Load data
 ###############################################################################
 
+context = [mx.cpu()] if args.gpus is None or args.gpus == "" else \
+          [mx.gpu(int(i)) for i in args.gpus.split(',')]
 
-if args.cuda:
-    context = mx.gpu(0)
-else:
-    context = mx.cpu(0)
-
-train_dataset = data.text.lm.WikiText2('./data', 'train', seq_len=args.bptt,
+train_dataset = data.text.lm.WikiText2(segment='train', seq_len=args.bptt,
                                        eos='<eos>')
 
 def get_frequencies(dataset):
     return collections.Counter(x for tup in dataset for x in tup[0] if x)
 
-vocab = contrib.text.vocab.Vocabulary(get_frequencies(train_dataset))
+vocab = text.vocab.Vocabulary(get_frequencies(train_dataset))
 def index_tokens(data, label):
-    return vocab.to_indices(data), vocab.to_indices(label)
+    return vocab[data], vocab[label]
 
-val_dataset, test_dataset = [data.text.lm.WikiText2('./data', segment,
+val_dataset, test_dataset = [data.text.lm.WikiText2(segment=segment,
                                                     seq_len=args.bptt,
                                                     eos='<eos>')
                              for segment in ['val', 'test']]
@@ -114,9 +121,17 @@ test_data = gluon.data.DataLoader(test_dataset.transform(index_tokens),
 
 
 ntokens = len(vocab)
-model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
-                       args.nlayers, args.dropout, args.tied)
-model.collect_params().initialize(mx.init.Xavier(), ctx=context)
+
+if args.weight_dropout:
+    model = AWDLSTM(args.model, vocab, args.emsize, args.nhid, args.nlayers,
+                    args.dropout, args.dropout_h, args.dropout_i, args.dropout_e, args.weight_dropout,
+                    args.tied)
+else:
+    model = RNNModel(args.model, vocab, args.emsize, args.nhid,
+                     args.nlayers, args.dropout, args.tied)
+
+model.initialize(mx.init.Xavier(), ctx=context)
+
 
 compression_params = None if args.gctype == 'none' else {'type': args.gctype, 'threshold': args.gcthreshold}
 trainer = gluon.Trainer(model.collect_params(), 'sgd',
@@ -140,38 +155,46 @@ def detach(hidden):
 def eval(data_source):
     total_L = 0.0
     ntotal = 0
-    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
+    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context[0])
     for i, (data, target) in enumerate(data_source):
-        data = data.as_in_context(context).T
-        target = target.as_in_context(context).T.reshape((-1, 1))
+        data = data.as_in_context(context[0]).T
+        target= target.as_in_context(context[0]).T
         output, hidden = model(data, hidden)
-        L = loss(output, target)
+        L = loss(mx.nd.reshape(output, (-3, -1)),
+                 mx.nd.reshape(target, (-1,)))
         total_L += mx.nd.sum(L).asscalar()
         ntotal += L.size
     return total_L / ntotal
 
 def train():
     best_val = float("Inf")
+    start_train_time = time.time()
     for epoch in range(args.epochs):
         total_L = 0.0
-        start_time = time.time()
-        hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
+        start_epoch_time = time.time()
+        hiddens = [model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=ctx) for ctx in context]
         for i, (data, target) in enumerate(train_data):
-            data = data.as_in_context(context).T
-            target = target.as_in_context(context).T.reshape((-1, 1))
-            hidden = detach(hidden)
+            start_batch_time = time.time()
+            data = data.T
+            target= target.T
+            data_list = gluon.utils.split_and_load(data, context, even_split=False)
+            target_list = gluon.utils.split_and_load(target, context, even_split=False)
+            hiddens = [detach(hidden) for hidden in hiddens]
+            Ls = []
             with autograd.record():
-                output, hidden = model(data, hidden)
-                L = loss(output, target)
+                for j, (X, y, h) in enumerate(zip(data_list, target_list, hiddens)):
+                    output, h = model(X, h)
+                    Ls.append(loss(mx.nd.reshape(output, (-3, -1)), mx.nd.reshape(y, (-1,))))
+                    hiddens[j] = h
+            for L in Ls:
                 L.backward()
-
-            grads = [p.grad(context) for p in model.collect_params().values()]
-            # Here gradient is for the whole batch.
-            # So we multiply max_norm by batch_size and bptt size to balance it.
-            gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)
+            for ctx in context:
+                grads = [p.grad(ctx) for p in model.collect_params().values()]
+                gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)
 
             trainer.step(args.batch_size)
-            total_L += mx.nd.sum(L).asscalar()
+
+            total_L += sum([mx.nd.sum(L).asscalar() for L in Ls])
 
             if i % args.log_interval == 0 and i > 0:
                 cur_L = total_L / args.bptt / args.batch_size / args.log_interval
@@ -179,26 +202,33 @@ def train():
                     epoch, i, cur_L, math.exp(cur_L)))
                 total_L = 0.0
 
-        val_L = eval(val_data)
+            print('[Epoch %d Batch %d] throughput %.2f samples/s'%(
+                    epoch, i, args.batch_size / (time.time() - start_batch_time)))
 
+        mx.nd.waitall()
+
+        print('[Epoch %d] throughput %.2f samples/s'%(
+                    epoch, (args.batch_size * nbatch_train) / (time.time() - start_epoch_time)))
+        val_L = eval(val_data)
         print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
-            epoch, time.time()-start_time, val_L, math.exp(val_L)))
+            epoch, time.time()-start_epoch_time, val_L, math.exp(val_L)))
 
         if val_L < best_val:
             best_val = val_L
             test_L = eval(test_data)
             model.collect_params().save(args.save)
             print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
-        else:
-            args.lr = args.lr*0.25
-            trainer._init_optimizer('sgd',
-                                    {'learning_rate': args.lr,
-                                     'momentum': 0,
-                                     'wd': 0})
-            model.collect_params().load(args.save, context)
+
+    print('Total training throughput %.2f samples/s'%(
+                            (args.batch_size * nbatch_train * args.epochs) / (time.time() - start_train_time)))
 
 if __name__ == '__main__':
-    train()
+    start_pipeline_time = time.time()
+    if not args.eval_only:
+        train()
     model.collect_params().load(args.save, context)
+    val_L = eval(val_data)
     test_L = eval(test_data)
+    print('Best validation loss %.2f, test ppl %.2f'%(val_L, math.exp(val_L)))
     print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
+    print('Total time cost %.2fs'%(time.time()-start_pipeline_time))
diff --git a/example/gluon/word_language_model/README.md b/example/gluon/word_language_model/README.md
deleted file mode 100644
index ff8ea56..0000000
--- a/example/gluon/word_language_model/README.md
+++ /dev/null
@@ -1,67 +0,0 @@
-# Word-level language modeling RNN
-
-This example trains a multi-layer RNN (Elman, GRU, or LSTM) on Penn Treebank (PTB) language modeling benchmark.
-
-The model obtains the state-of-the-art result on PTB using LSTM, getting a test perplexity of ~72.
-And ~97 ppl in WikiText-2, outperform than basic LSTM(99.3) and reach Variational LSTM(96.3).
-
-The following techniques have been adopted for SOTA results: 
-- [LSTM for LM](https://arxiv.org/pdf/1409.2329.pdf)
-- [Weight tying](https://arxiv.org/abs/1608.05859) between word vectors and softmax output embeddings
-
-## Data
-
-### PTB
-
-The PTB data is the processed version from [(Mikolov et al, 2010)](http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf):
-
-```bash
-bash get_ptb_data.sh
-python data.py
-```
-
-### Wiki Text
-
-The wikitext-2 data is downloaded from [(The wikitext long term dependency language modeling dataset)](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):
-
-```bash
-bash get_wikitext2_data.sh
-```
-
-
-## Usage
-
-Example runs and the results:
-
-```
-python train.py -data ./data/ptb. --cuda --tied --nhid 650 --emsize 650 --dropout 0.5        # Test ppl of 75.3 in ptb
-python train.py -data ./data/ptb. --cuda --tied --nhid 1500 --emsize 1500 --dropout 0.65      # Test ppl of 72.0 in ptb
-```
-
-```
-python train.py -data ./data/wikitext-2/wiki. --cuda --tied --nhid 256 --emsize 256          # Test ppl of 97.07 in wikitext-2 
-```
-
-
-<br>
-
-`python train.py --help` gives the following arguments:
-```
-Optional arguments:
-  -h, --help         show this help message and exit
-  --data DATA        location of the data corpus
-  --model MODEL      type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)
-  --emsize EMSIZE    size of word embeddings
-  --nhid NHID        number of hidden units per layer
-  --nlayers NLAYERS  number of layers
-  --lr LR            initial learning rate
-  --clip CLIP        gradient clipping
-  --epochs EPOCHS    upper epoch limit
-  --batch_size N     batch size
-  --bptt BPTT        sequence length
-  --dropout DROPOUT  dropout applied to layers (0 = no dropout)
-  --tied             tie the word embedding and softmax weights
-  --cuda             Whether to use gpu
-  --log-interval N   report interval
-  --save SAVE        path to save the final model
-```
diff --git a/example/gluon/word_language_model/get_ptb_data.sh b/example/gluon/word_language_model/get_ptb_data.sh
deleted file mode 100755
index 2dc4034..0000000
--- a/example/gluon/word_language_model/get_ptb_data.sh
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/usr/bin/env bash
-
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-echo
-echo "NOTE: To continue, you need to review the licensing of the data sets used by this script"
-echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing"
-read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r
-echo
-
-if [ $REPLY != "Y" ]
-then
-    echo "License was not reviewed, aborting script."
-    exit 1
-fi
-
-RNN_DIR=$(cd `dirname $0`; pwd)
-DATA_DIR="${RNN_DIR}/data/"
-
-if [[ ! -d "${DATA_DIR}" ]]; then
-  echo "${DATA_DIR} doesn't exist, will create one";
-  mkdir -p ${DATA_DIR}
-fi
-
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
-wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
diff --git a/example/gluon/word_language_model/model.py b/example/gluon/word_language_model/model.py
deleted file mode 100644
index 40e7926..0000000
--- a/example/gluon/word_language_model/model.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import mxnet as mx
-from mxnet import gluon
-from mxnet.gluon import nn, rnn
-
-class RNNModel(gluon.Block):
-    """A model with an encoder, recurrent layer, and a decoder."""
-
-    def __init__(self, mode, vocab_size, num_embed, num_hidden,
-                 num_layers, dropout=0.5, tie_weights=False, **kwargs):
-        super(RNNModel, self).__init__(**kwargs)
-        with self.name_scope():
-            self.drop = nn.Dropout(dropout)
-            self.encoder = nn.Embedding(vocab_size, num_embed,
-                                        weight_initializer=mx.init.Uniform(0.1))
-            if mode == 'rnn_relu':
-                self.rnn = rnn.RNN(num_hidden, 'relu', num_layers, dropout=dropout,
-                                   input_size=num_embed)
-            elif mode == 'rnn_tanh':
-                self.rnn = rnn.RNN(num_hidden, num_layers, dropout=dropout,
-                                   input_size=num_embed)
-            elif mode == 'lstm':
-                self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout,
-                                    input_size=num_embed)
-            elif mode == 'gru':
-                self.rnn = rnn.GRU(num_hidden, num_layers, dropout=dropout,
-                                   input_size=num_embed)
-            else:
-                raise ValueError("Invalid mode %s. Options are rnn_relu, "
-                                 "rnn_tanh, lstm, and gru"%mode)
-
-            if tie_weights:
-                self.decoder = nn.Dense(vocab_size, in_units=num_hidden,
-                                        params=self.encoder.params)
-            else:
-                self.decoder = nn.Dense(vocab_size, in_units=num_hidden)
-
-            self.num_hidden = num_hidden
-
-    def forward(self, inputs, hidden):
-        emb = self.drop(self.encoder(inputs))
-        output, hidden = self.rnn(emb, hidden)
-        output = self.drop(output)
-        decoded = self.decoder(output.reshape((-1, self.num_hidden)))
-        return decoded, hidden
-
-    def begin_state(self, *args, **kwargs):
-        return self.rnn.begin_state(*args, **kwargs)
diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index 8870888..082a148 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -108,5 +108,7 @@ class Identity(HybridBlock):
     def __init__(self, prefix=None, params=None):
         super(Identity, self).__init__(prefix=prefix, params=params)
 
-    def hybrid_forward(self, F, x):
+    def hybrid_forward(self, F, *x):
+        if x and len(x) == 1:
+            return x[0]
         return x
diff --git a/python/mxnet/gluon/data/text/base.py b/python/mxnet/gluon/data/text/base.py
index 3c02424..f67a18c 100644
--- a/python/mxnet/gluon/data/text/base.py
+++ b/python/mxnet/gluon/data/text/base.py
@@ -27,7 +27,7 @@ import os
 
 from ..dataset import SimpleDataset
 from ..datareader import DataReader
-from .utils import flatten_samples, collate
+from .utils import flatten_samples, collate, collate_pad_length
 
 class CorpusReader(DataReader):
     """Text reader that reads a whole corpus and produces a dataset based on provided
@@ -124,9 +124,9 @@ class WordLanguageReader(CorpusReader):
         samples = [self._process(s) for s in samples]
         if self._seq_len:
             samples = flatten_samples(samples)
-            if self._pad and len(samples) % self._seq_len:
-                pad_len = self._seq_len - len(samples) % self._seq_len
+            pad_len = collate_pad_length(len(samples), self._seq_len, 1)
+            if self._pad:
                 samples.extend([self._pad] * pad_len)
-            samples = collate(samples, self._seq_len, 1)
+            samples = collate(samples, self._seq_len+1, 1)
 
         return SimpleDataset(samples).transform(lambda x: (x[:-1], x[1:]))
diff --git a/python/mxnet/gluon/data/text/utils.py b/python/mxnet/gluon/data/text/utils.py
index 057e843..b923f74 100644
--- a/python/mxnet/gluon/data/text/utils.py
+++ b/python/mxnet/gluon/data/text/utils.py
@@ -52,5 +52,26 @@ def collate(flat_sample, seq_len, overlap=0):
     -------
     List of samples, each of which has length equal to `seq_len`.
     """
-    num_samples = len(flat_sample) // seq_len
-    return [flat_sample[i*seq_len:((i+1)*seq_len+overlap)] for i in range(num_samples)]
+    num_samples = (len(flat_sample)-seq_len) // (seq_len-overlap) + 1
+    return [flat_sample[i*(seq_len-overlap):((i+1)*seq_len-i*overlap)] for i in range(num_samples)]
+
+def collate_pad_length(num_items, seq_len, overlap=0):
+    """Calculate the padding length needed for collated samples in order not to discard data.
+
+    Parameters
+    ----------
+    num_items : int
+        Number of items in dataset before collating.
+    seq_len : int
+        The length of each of the samples.
+    overlap : int, default 0
+        The extra number of items in current sample that should overlap with the
+        next sample.
+
+    Returns
+    -------
+    Number of padding items to append.
+    """
+    step = seq_len-overlap
+    span = num_items-seq_len
+    return (span // step + 1) * step - span
diff --git a/python/mxnet/gluon/model_zoo/__init__.py b/python/mxnet/gluon/model_zoo/__init__.py
index b8c32af..bde69ad 100644
--- a/python/mxnet/gluon/model_zoo/__init__.py
+++ b/python/mxnet/gluon/model_zoo/__init__.py
@@ -21,3 +21,5 @@
 from . import model_store
 
 from . import vision
+
+from . import text
diff --git a/python/mxnet/gluon/model_zoo/text/__init__.py b/python/mxnet/gluon/model_zoo/text/__init__.py
new file mode 100644
index 0000000..9aabd8a
--- /dev/null
+++ b/python/mxnet/gluon/model_zoo/text/__init__.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, arguments-differ
+r"""Module for pre-defined NLP models.
+
+This module contains definitions for the following model architectures:
+-  `AWD`_
+
+You can construct a model with random weights by calling its constructor:
+
+.. code::
+
+    from mxnet.gluon.model_zoo import text
+    # TODO
+    awd = text.awd_variant()
+
+We provide pre-trained models for all the listed models.
+These models can be constructed by passing ``pretrained=True``:
+
+.. code::
+
+    from mxnet.gluon.model_zoo import text
+    # TODO
+    awd = text.awd_variant(pretrained=True)
+
+.. _AWD: https://arxiv.org/abs/1708.02182
+"""
+
+from .base import *
+
+from . import lm
+
+def get_model(name, **kwargs):
+    """Returns a pre-defined model by name
+
+    Parameters
+    ----------
+    name : str
+        Name of the model.
+    pretrained : bool
+        Whether to load the pretrained weights for model.
+    classes : int
+        Number of classes for the output layer.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '~/.mxnet/models'
+        Location for keeping the model parameters.
+
+    Returns
+    -------
+    HybridBlock
+        The model.
+    """
+    #models = {'awd_variant': awd_variant}
+    name = name.lower()
+    if name not in models:
+        raise ValueError(
+            'Model %s is not supported. Available options are\n\t%s'%(
+                name, '\n\t'.join(sorted(models.keys()))))
+    return models[name](**kwargs)
diff --git a/python/mxnet/gluon/model_zoo/text/base.py b/python/mxnet/gluon/model_zoo/text/base.py
new file mode 100644
index 0000000..7967eb9
--- /dev/null
+++ b/python/mxnet/gluon/model_zoo/text/base.py
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Building blocks and utility for models."""
+
+from ... import Block, HybridBlock, Parameter, contrib, nn, rnn
+from .... import nd
+
+
+class _TextSeq2SeqModel(Block):
+    """Base class wiring an embedding, an encoder and a decoder together.
+
+    This class only stores the two vocabularies. ``forward`` reads the
+    ``embedding``, ``encoder`` and ``decoder`` attributes, which are not
+    assigned here and therefore have to be provided by the subclass.
+
+    Parameters
+    ----------
+    src_vocab : object
+        Vocabulary of the input tokens; stored as ``_src_vocab``.
+    tgt_vocab : object
+        Vocabulary of the output tokens; stored as ``_tgt_vocab``.
+    """
+    def __init__(self, src_vocab, tgt_vocab, **kwargs):
+        super(_TextSeq2SeqModel, self).__init__(**kwargs)
+        self._src_vocab = src_vocab
+        self._tgt_vocab = tgt_vocab
+
+    def begin_state(self, *args, **kwargs):
+        """Returns ``self.encoder.begin_state(*args, **kwargs)``."""
+        return self.encoder.begin_state(*args, **kwargs)
+
+    def forward(self, inputs, begin_state=None): # pylint: disable=arguments-differ
+        """Embeds `inputs`, encodes them and decodes the encoder output.
+
+        Parameters
+        ----------
+        inputs : NDArray
+            Passed to ``self.embedding``.
+        begin_state : list of NDArray or None
+            Initial encoder state. ``self.begin_state()`` is used when this is
+            None or an empty sequence.
+
+        Returns
+        -------
+        out
+            Output of ``self.decoder``.
+        state
+            State returned by ``self.encoder``.
+        """
+        embedded_inputs = self.embedding(inputs)
+        # Compare against None explicitly: taking the truth value of a bare
+        # NDArray state raises for arrays with more than one element.
+        no_state = begin_state is None or (
+            isinstance(begin_state, (list, tuple)) and not begin_state)
+        if no_state:
+            begin_state = self.begin_state()
+        encoded, state = self.encoder(embedded_inputs, begin_state)
+        out = self.decoder(encoded)
+        return out, state
+
+
+def apply_weight_drop(block, local_param_name, rate, axes=(),
+                      weight_dropout_mode='training'):
+    """Replaces matching parameters of `block` with `WeightDropParameter`s.
+
+    Every parameter collected from `block` with the pattern
+    ``'.*' + local_param_name`` is wrapped in a `WeightDropParameter`. The
+    wrapper is written into the dictionaries returned by `_find_param` and,
+    when `block` has an attribute called `local_param_name`, into that
+    attribute (directly, or inside the list/tuple/dict it holds).
+
+    Parameters
+    ----------
+    block : Block
+        Block whose parameters are replaced in place.
+    local_param_name : str
+        Local (un-prefixed) name of the parameter, e.g. 'h2h_weight'.
+    rate : float
+        Dropout rate. Nothing is done when this is falsy.
+    axes : tuple of int, default ()
+        Forwarded to `WeightDropParameter`.
+    weight_dropout_mode : str, default 'training'
+        Forwarded to `WeightDropParameter`.
+    """
+    if not rate:
+        return
+
+    params = block.collect_params('.*{}'.format(local_param_name))
+    for full_param_name, param in params.items():
+        dropped_param = WeightDropParameter(param, rate, weight_dropout_mode, axes)
+        param_dicts, reg_param_dicts = _find_param(block, full_param_name, local_param_name)
+        for param_dict in param_dicts:
+            param_dict[full_param_name] = dropped_param
+        for reg_param_dict in reg_param_dicts:
+            reg_param_dict[local_param_name] = dropped_param
+
+        # Containers need not expose the parameter as an attribute of their
+        # own; in that case the dictionaries updated above are all there is.
+        local_attr = getattr(block, local_param_name, None)
+        if local_attr is None:
+            continue
+        if local_attr is param:
+            super(Block, block).__setattr__(local_param_name, dropped_param)
+            continue
+
+        if isinstance(local_attr, (list, tuple)):
+            if isinstance(local_attr, tuple):
+                # Tuples are immutable, so the replacement is done on a list.
+                local_attr = list(local_attr)
+            for i, v in enumerate(local_attr):
+                if v is param:
+                    local_attr[i] = dropped_param
+        elif isinstance(local_attr, dict):
+            # Iterate over a snapshot of the items: iterating the dict itself
+            # yields keys only and cannot be unpacked into (key, value).
+            for k, v in list(local_attr.items()):
+                if v is param:
+                    local_attr[k] = dropped_param
+        else:
+            continue
+        super(Block, block).__setattr__(local_param_name, local_attr)
+
+
+def _find_param(block, full_param_name, local_param_name):
+    """Collects the dictionaries below `block` that hold a given parameter.
+
+    Returns
+    -------
+    list of dict
+        The ``_params`` dictionaries containing `full_param_name`, found by
+        following the ``_shared`` links of each block's parameter dictionary.
+    list of dict
+        The ``_reg_params`` dictionaries of HybridBlocks that contain
+        `local_param_name`.
+
+    Both lists also include the results for all ``_children``, recursively.
+    """
+    found_param_dicts = []
+    found_reg_dicts = []
+    current = block.params
+
+    if full_param_name in block.params._params:
+        is_registered = (isinstance(block, HybridBlock)
+                         and local_param_name in block._reg_params)
+        if is_registered:
+            found_reg_dicts.append(block._reg_params)
+        # Walk the chain of shared parameter dictionaries.
+        while current:
+            if full_param_name in current._params:
+                found_param_dicts.append(current._params)
+            if not current._shared:
+                break
+            current = current._shared
+
+    for child in (block._children or ()):
+        child_param_dicts, child_reg_dicts = _find_param(
+            child, full_param_name, local_param_name)
+        found_param_dicts.extend(child_param_dicts)
+        found_reg_dicts.extend(child_reg_dicts)
+
+    return found_param_dicts, found_reg_dicts
+
+def get_rnn_cell(mode, num_layers, num_embed, num_hidden,
+                 dropout, weight_dropout,
+                 var_drop_in, var_drop_state, var_drop_out):
+    """Creates a stacked rnn cell given specs.
+
+    Parameters
+    ----------
+    mode : str
+        One of 'rnn_relu', 'rnn_tanh', 'lstm' and 'gru'.
+    num_layers : int
+        Number of recurrent cells to stack.
+    num_embed : int
+        Input size of the first cell.
+    num_hidden : int
+        Hidden size of every cell, and input size of every cell but the first.
+    dropout : float
+        Rate of the `DropoutCell` inserted between consecutive layers;
+        no `DropoutCell` is added when this is 0.
+    weight_dropout : float
+        Rate of the weight drop applied to each cell's 'h2h_weight';
+        skipped when falsy.
+    var_drop_in, var_drop_state, var_drop_out : float
+        Forwarded to `contrib.rnn.VariationalDropoutCell`; the wrapper is
+        only added when their sum is non-zero.
+
+    Returns
+    -------
+    SequentialRNNCell
+
+    Raises
+    ------
+    ValueError
+        If `mode` is not one of the supported options.
+    """
+    rnn_cell = rnn.SequentialRNNCell()
+    with rnn_cell.name_scope():
+        for i in range(num_layers):
+            # Only the first layer reads the embeddings; every later layer
+            # consumes the hidden state of the layer below it.
+            input_size = num_embed if i == 0 else num_hidden
+            if mode == 'rnn_relu':
+                cell = rnn.RNNCell(num_hidden, 'relu', input_size=input_size)
+            elif mode == 'rnn_tanh':
+                cell = rnn.RNNCell(num_hidden, 'tanh', input_size=input_size)
+            elif mode == 'lstm':
+                cell = rnn.LSTMCell(num_hidden, input_size=input_size)
+            elif mode == 'gru':
+                cell = rnn.GRUCell(num_hidden, input_size=input_size)
+            else:
+                raise ValueError(
+                    "Invalid mode %s; must be one of "
+                    "['rnn_relu', 'rnn_tanh', 'lstm', 'gru']"%mode)
+
+            # Apply weight drop to the freshly created cell, exactly once.
+            # Applying it to the whole stack on every iteration re-wrapped the
+            # parameters of earlier layers and looked up 'h2h_weight' on the
+            # container, which does not have that attribute.
+            if weight_dropout:
+                apply_weight_drop(cell, 'h2h_weight', rate=weight_dropout)
+
+            if var_drop_in + var_drop_state + var_drop_out != 0:
+                cell = contrib.rnn.VariationalDropoutCell(cell,
+                                                          var_drop_in,
+                                                          var_drop_state,
+                                                          var_drop_out)
+
+            rnn_cell.add(cell)
+            if i != num_layers - 1 and dropout != 0:
+                rnn_cell.add(rnn.DropoutCell(dropout))
+
+    return rnn_cell
+
+
+def get_rnn_layer(mode, num_layers, num_embed, num_hidden, dropout, weight_dropout):
+    """Creates an rnn layer given specs.
+
+    Parameters
+    ----------
+    mode : str
+        One of 'rnn_relu', 'rnn_tanh', 'lstm' and 'gru'.
+    num_layers : int
+        Number of recurrent layers.
+    num_embed : int
+        Passed to the layer as ``input_size``.
+    num_hidden : int
+        Hidden size of the layer.
+    dropout : float
+        Passed to the layer as ``dropout``.
+    weight_dropout : float
+        Rate of the weight drop applied to the 'h2h_weight' parameters;
+        skipped when falsy.
+
+    Returns
+    -------
+    Block
+        The created `rnn.RNN`, `rnn.LSTM` or `rnn.GRU` layer.
+
+    Raises
+    ------
+    ValueError
+        If `mode` is not one of the supported options.
+    """
+    # `rnn.RNN` takes (hidden_size, num_layers, activation, ...), so the
+    # activation is passed by keyword: positionally it landed in the
+    # `num_layers` slot, and the tanh variant silently used the default.
+    if mode == 'rnn_relu':
+        block = rnn.RNN(num_hidden, num_layers, activation='relu',
+                        dropout=dropout, input_size=num_embed)
+    elif mode == 'rnn_tanh':
+        block = rnn.RNN(num_hidden, num_layers, activation='tanh',
+                        dropout=dropout, input_size=num_embed)
+    elif mode == 'lstm':
+        block = rnn.LSTM(num_hidden, num_layers, dropout=dropout,
+                         input_size=num_embed)
+    elif mode == 'gru':
+        block = rnn.GRU(num_hidden, num_layers, dropout=dropout,
+                        input_size=num_embed)
+    else:
+        raise ValueError(
+            "Invalid mode %s; must be one of "
+            "['rnn_relu', 'rnn_tanh', 'lstm', 'gru']"%mode)
+    if weight_dropout:
+        apply_weight_drop(block, 'h2h_weight', rate=weight_dropout)
+
+    return block
+
+
+class RNNCellLayer(Block):
+    """A block that takes an rnn cell and makes it act like rnn layer.
+
+    Parameters
+    ----------
+    rnn_cell : recurrent cell
+        Cell that is unrolled over the time axis in ``forward``.
+    layout : str, default 'TNC'
+        Layout of the inputs; must be 'TNC' or 'NTC'.
+    """
+    def __init__(self, rnn_cell, layout='TNC', **kwargs):
+        # The class used to name a non-existent `RNNCellBlock` here, which
+        # made construction fail with a NameError.
+        super(RNNCellLayer, self).__init__(**kwargs)
+        self.cell = rnn_cell
+        assert layout == 'TNC' or layout == 'NTC', \
+            "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
+        self._layout = layout
+        self._axis = layout.find('T')
+        self._batch_axis = layout.find('N')
+
+    def forward(self, inputs, states=None): # pylint: disable=arguments-differ
+        """Unrolls the cell over the time axis of `inputs`.
+
+        Parameters
+        ----------
+        inputs : NDArray
+            Input sequence in the layout given at construction.
+        states : NDArray, list of NDArray or None
+            Initial states. When None, ``self.cell.begin_state`` is used and
+            only the outputs are returned.
+
+        Returns
+        -------
+        outputs, or (outputs, states) when `states` was given.
+
+        Raises
+        ------
+        ValueError
+            If a state's shape differs from the one in ``cell.state_info``.
+        """
+        batch_size = inputs.shape[self._batch_axis]
+        skip_states = states is None
+        if skip_states:
+            states = self.cell.begin_state(batch_size, ctx=inputs.context)
+        # `nd` is the name imported by this module; `ndarray` is not defined.
+        if isinstance(states, nd.NDArray):
+            states = [states]
+        for state, info in zip(states, self.cell.state_info(batch_size)):
+            if state.shape != info['shape']:
+                raise ValueError(
+                    "Invalid recurrent state shape. Expecting %s, got %s."%(
+                        str(info['shape']), str(state.shape)))
+        # The states are handed to `unroll` as they are. Each one already has
+        # the shape the cell reports in `state_info`; iterating over them to
+        # flatten would split every state along its first axis.
+        outputs, states = self.cell.unroll(
+            inputs.shape[self._axis], inputs, states,
+            layout=self._layout, merge_outputs=True)
+
+        if skip_states:
+            return outputs
+        return outputs, states
+
+class ExtendedSequential(nn.Sequential):
+    """Sequential container whose blocks may take and return several values.
+
+    The values returned by one block are passed as positional arguments to
+    the next one. The result of the last block is returned unchanged.
+    """
+    def forward(self, *x): # pylint: disable=arguments-differ
+        out = x
+        for block in self._children:
+            out = block(*x)
+            # Only unpack real sequences. Star-unpacking a single array would
+            # iterate over its first axis instead of passing it on whole.
+            x = out if isinstance(out, (list, tuple)) else (out,)
+        return out
+
+class TransformerBlock(Block):
+    """Applies one block per input, position by position.
+
+    Parameters
+    ----------
+    *blocks : Block or None
+        Block applied to the input at the same position. A falsy entry
+        (e.g. None) passes the corresponding input through unchanged.
+    """
+    def __init__(self, *blocks, **kwargs):
+        super(TransformerBlock, self).__init__(**kwargs)
+        self._blocks = blocks
+        # A tuple attribute is not picked up as a child automatically, so the
+        # wrapped blocks are registered explicitly.
+        for block in blocks:
+            if isinstance(block, Block):
+                self.register_child(block)
+
+    def forward(self, *inputs):
+        """Returns a list with each input transformed by its block.
+
+        Blocks and inputs are paired with ``zip``, so the result is as long
+        as the shorter of the two.
+        """
+        return [block(data) if block else data for block, data in zip(self._blocks, inputs)]
+
+
+class WeightDropParameter(Parameter):
+    """A Parameter whose data is passed through dropout when it is read.
+
+    Parameters
+    ----------
+    parameter : Parameter
+        The parameter whose settings (name, shape, dtype, ...) are copied.
+    rate : float, default 0.0
+        Dropout rate handed to ``nd.Dropout``.
+        Dropout is not applied if the rate is 0.
+    mode : str, default 'training'
+        Dropout mode handed to ``nd.Dropout``.
+    axes : tuple of int, default ()
+        Axes handed to ``nd.Dropout``.
+    """
+    def __init__(self, parameter, rate=0.0, mode='training', axes=()):
+        source = parameter
+        super(WeightDropParameter, self).__init__(
+            name=source.name,
+            grad_req=source.grad_req,
+            shape=source._shape,
+            dtype=source.dtype,
+            lr_mult=source.lr_mult,
+            wd_mult=source.wd_mult,
+            init=source.init,
+            allow_deferred_init=source._allow_deferred_init,
+            differentiable=source._differentiable)
+        self._rate = rate
+        self._mode = mode
+        self._axes = axes
+
+    def data(self, ctx=None):
+        """Returns the data of this parameter on one context, after dropout.
+
+        Parameters
+        ----------
+        ctx : Context
+            Desired context.
+
+        Returns
+        -------
+        NDArray
+            The array obtained through ``_check_and_get``; it is passed
+            through ``nd.Dropout`` first when the rate is non-zero.
+        """
+        weight = self._check_and_get(self._data, ctx)
+        if not self._rate:
+            return weight
+        return nd.Dropout(weight, self._rate, self._mode, self._axes)
+
+    def __repr__(self):
+        template = 'WeightDropParameter {name} (shape={shape}, dtype={dtype}, rate={rate}, mode={mode})'
+        return template.format(name=self.name,
+                               shape=self.shape,
+                               dtype=self.dtype,
+                               rate=self._rate,
+                               mode=self._mode)
diff --git a/python/mxnet/gluon/model_zoo/text/lm.py b/python/mxnet/gluon/model_zoo/text/lm.py
new file mode 100644
index 0000000..3ada7d0
--- /dev/null
+++ b/python/mxnet/gluon/model_zoo/text/lm.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Language models."""
+
+from .base import _TextSeq2SeqModel, ExtendedSequential, TransformerBlock
+from .base import get_rnn_layer, apply_weight_drop
+from ... import nn
+from .... import init
+
+
+class AWDLSTM(_TextSeq2SeqModel):
+    """AWD language model.
+
+    Parameters
+    ----------
+    mode : str
+        Type of the recurrent layers, forwarded to `get_rnn_layer`.
+    vocab : object
+        Vocabulary used for both input and output; only its length is read.
+    embed_dim : int
+        Size of the embedding vectors.
+    hidden_dim : int
+        Hidden size of the recurrent layers.
+    num_layers : int
+        Number of recurrent layers.
+    dropout : float, default 0.5
+        Stored as ``_dropout``; not otherwise used by this class.
+    drop_h : float, default 0.5
+        Rate of the dropout applied to the output of each recurrent layer.
+    drop_i : float, default 0.5
+        Rate of the dropout applied to the embedding output.
+    drop_e : float, default 0.1
+        Rate of the weight drop applied to the embedding weight.
+    weight_drop : float, default 0
+        Rate of the weight drop applied to the recurrent layers.
+    tie_weights : bool, default False
+        Whether the decoder shares the parameters of the embedding layer.
+        The last recurrent layer then outputs `embed_dim` units.
+    """
+    def __init__(self, mode, vocab, embed_dim, hidden_dim, num_layers,
+                 dropout=0.5, drop_h=0.5, drop_i=0.5, drop_e=0.1, weight_drop=0,
+                 tie_weights=False, **kwargs):
+        super(AWDLSTM, self).__init__(vocab, vocab, **kwargs)
+        self._mode = mode
+        self._embed_dim = embed_dim
+        self._hidden_dim = hidden_dim
+        self._num_layers = num_layers
+        self._dropout = dropout
+        self._drop_h = drop_h
+        self._drop_i = drop_i
+        self._drop_e = drop_e
+        self._weight_drop = weight_drop
+        self._tie_weights = tie_weights
+        self.embedding = self._get_embedding()
+        self.encoder = self._get_encoder()
+        self.decoder = self._get_decoder()
+
+    def _get_embedding(self):
+        """Builds the embedding layer, with weight drop and output dropout."""
+        embedding = nn.HybridSequential()
+        with embedding.name_scope():
+            embedding_block = nn.Embedding(len(self._src_vocab), self._embed_dim,
+                                           weight_initializer=init.Uniform(0.1))
+            if self._drop_e:
+                apply_weight_drop(embedding_block, 'weight', self._drop_e, axes=(1,))
+            embedding.add(embedding_block)
+            if self._drop_i:
+                embedding.add(nn.Dropout(self._drop_i, axes=(0,)))
+        return embedding
+
+    def _get_encoder(self):
+        """Builds the stack of single-layer rnn blocks with dropout between."""
+        encoder = ExtendedSequential()
+        with encoder.name_scope():
+            for l in range(self._num_layers):
+                encoder.add(get_rnn_layer(self._mode, 1, self._embed_dim if l == 0 else
+                                          self._hidden_dim, self._hidden_dim if
+                                          l != self._num_layers - 1 or not self._tie_weights
+                                          else self._embed_dim, 0, self._weight_drop))
+                if self._drop_h:
+                    # Dropout on the layer output only; the state passes through.
+                    encoder.add(TransformerBlock(nn.Dropout(self._drop_h, axes=(0,)), None))
+        return encoder
+
+    def _get_decoder(self):
+        """Builds the output layer, optionally tied to the embedding layer."""
+        vocab_size = len(self._tgt_vocab)
+        if self._tie_weights:
+            # Share the parameters of the Embedding layer itself, as RNNModel
+            # does. The enclosing HybridSequential's own `params` is a
+            # different dictionary from the one the Embedding layer holds.
+            output = nn.Dense(vocab_size, flatten=False, params=self.embedding[0].params)
+        else:
+            output = nn.Dense(vocab_size, flatten=False)
+        return output
+
+    def begin_state(self, *args, **kwargs):
+        """Returns the initial state of the first recurrent layer."""
+        return self.encoder[0].begin_state(*args, **kwargs)
+
+class RNNModel(_TextSeq2SeqModel):
+    """Plain RNN language model: embedding, stacked rnn layers, dense decoder.
+
+    Parameters
+    ----------
+    mode : str
+        Type of the recurrent layers, forwarded to `get_rnn_layer`.
+    vocab : object
+        Vocabulary used for both input and output; only its length is read.
+    embed_dim : int
+        Size of the embedding vectors.
+    hidden_dim : int
+        Hidden size of the recurrent layers.
+    num_layers : int
+        Number of recurrent layers.
+    dropout : float, default 0.5
+        Rate of the dropout applied to the embedding output.
+    tie_weights : bool, default False
+        Whether the decoder shares the parameters of the embedding layer.
+        The last recurrent layer then outputs `embed_dim` units.
+    """
+    def __init__(self, mode, vocab, embed_dim, hidden_dim,
+                 num_layers, dropout=0.5, tie_weights=False, **kwargs):
+        super(RNNModel, self).__init__(vocab, vocab, **kwargs)
+        self._mode = mode
+        self._num_layers = num_layers
+        self._embed_dim = embed_dim
+        self._hidden_dim = hidden_dim
+        self._tie_weights = tie_weights
+        self._dropout = dropout
+        self.embedding = self._get_embedding()
+        self.encoder = self._get_encoder()
+        self.decoder = self._get_decoder()
+
+    def _get_embedding(self):
+        """Builds the embedding lookup followed by optional dropout."""
+        embedding = nn.HybridSequential()
+        with embedding.name_scope():
+            lookup = nn.Embedding(len(self._src_vocab), self._embed_dim,
+                                  weight_initializer=init.Uniform(0.1))
+            embedding.add(lookup)
+            if self._dropout:
+                embedding.add(nn.Dropout(self._dropout))
+        return embedding
+
+    def _get_encoder(self):
+        """Builds the stack of single-layer rnn blocks."""
+        encoder = ExtendedSequential()
+        last = self._num_layers - 1
+        with encoder.name_scope():
+            for l in range(self._num_layers):
+                in_dim = self._embed_dim if l == 0 else self._hidden_dim
+                # With tied weights the last layer must emit `embed_dim`
+                # units so that the decoder can reuse the embedding weight.
+                if l == last and self._tie_weights:
+                    out_dim = self._embed_dim
+                else:
+                    out_dim = self._hidden_dim
+                encoder.add(get_rnn_layer(self._mode, 1, in_dim, out_dim, 0, 0))
+
+        return encoder
+
+    def _get_decoder(self):
+        """Builds the output layer, optionally tied to the embedding layer."""
+        dense_kwargs = {'flatten': False}
+        if self._tie_weights:
+            dense_kwargs['params'] = self.embedding[0].params
+        return nn.Dense(len(self._tgt_vocab), **dense_kwargs)
+
+    def begin_state(self, *args, **kwargs):
+        """Returns the initial state of the first recurrent layer."""
+        first_layer = self.encoder[0]
+        return first_layer.begin_state(*args, **kwargs)
diff --git a/tests/python/unittest/test_gluon_data_text.py b/tests/python/unittest/test_gluon_data_text.py
index 49a5988..8888b75 100644
--- a/tests/python/unittest/test_gluon_data_text.py
+++ b/tests/python/unittest/test_gluon_data_text.py
@@ -18,33 +18,51 @@
 from __future__ import print_function
 import collections
 import mxnet as mx
-from mxnet.gluon import nn, data
+from mxnet.gluon import text, contrib, nn
+from mxnet.gluon import data as d
 from common import setup_module, with_seed
 
 def get_frequencies(dataset):
     return collections.Counter(x for tup in dataset for x in tup[0]+tup[1][-1:])
 
+
 def test_wikitext2():
-    train = data.text.lm.WikiText2(root='data/wikitext-2', segment='train')
-    val = data.text.lm.WikiText2(root='data/wikitext-2', segment='val')
-    test = data.text.lm.WikiText2(root='data/wikitext-2', segment='test')
+    train = d.text.lm.WikiText2(root='data/wikitext-2', segment='train')
+    val = d.text.lm.WikiText2(root='data/wikitext-2', segment='val')
+    test = d.text.lm.WikiText2(root='data/wikitext-2', segment='test')
     train_freq, val_freq, test_freq = [get_frequencies(x) for x in [train, val, test]]
-    assert len(train) == 59306, len(train)
-    assert len(train_freq) == 33279, len(train_freq)
+    assert len(train) == 59305, len(train)
+    assert len(train_freq) == 33278, len(train_freq)
     assert len(val) == 6182, len(val)
     assert len(val_freq) == 13778, len(val_freq)
-    assert len(test) == 6975, len(test)
-    assert len(test_freq) == 14144, len(test_freq)
+    assert len(test) == 6974, len(test)
+    assert len(test_freq) == 14143, len(test_freq)
     assert test_freq['English'] == 33, test_freq['English']
     assert len(train[0][0]) == 35, len(train[0][0])
-    test_no_pad = data.text.lm.WikiText2(root='data/wikitext-2', segment='test', pad=None)
+    test_no_pad = d.text.lm.WikiText2(root='data/wikitext-2', segment='test', pad=None)
     assert len(test_no_pad) == 6974, len(test_no_pad)
 
-    train_paragraphs = data.text.lm.WikiText2(root='data/wikitext-2', segment='train', seq_len=None)
+    train_paragraphs = d.text.lm.WikiText2(root='data/wikitext-2', segment='train', seq_len=None)
     assert len(train_paragraphs) == 23767, len(train_paragraphs)
     assert len(train_paragraphs[0][0]) != 35, len(train_paragraphs[0][0])
 
 
+    vocab = text.vocab.Vocabulary(get_frequencies(train))
+    def index_tokens(data, label):
+        return vocab[data], vocab[label]
+    nbatch_train = len(train) // 80
+    train_data = d.DataLoader(train.transform(index_tokens),
+                              batch_size=80,
+                              sampler=contrib.data.IntervalSampler(len(train),
+                                                                   nbatch_train),
+                              last_batch='discard')
+    sampler = contrib.data.IntervalSampler(len(train), nbatch_train)
+
+    for i, (data, target) in enumerate(train_data):
+        pass
+
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py
index f89a8f7..e97b3b5 100644
--- a/tests/python/unittest/test_gluon_model_zoo.py
+++ b/tests/python/unittest/test_gluon_model_zoo.py
@@ -17,7 +17,7 @@
 
 from __future__ import print_function
 import mxnet as mx
-from mxnet.gluon.model_zoo.vision import get_model
+from mxnet.gluon.model_zoo.vision import get_model as get_vision_model
 import sys
 from common import setup_module, with_seed
 
@@ -28,20 +28,20 @@ def eprint(*args, **kwargs):
 
 @with_seed()
 def test_models():
-    all_models = ['resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1',
-                  'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2',
-                  'vgg11', 'vgg13', 'vgg16', 'vgg19',
-                  'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn',
-                  'alexnet', 'inceptionv3',
-                  'densenet121', 'densenet161', 'densenet169', 'densenet201',
-                  'squeezenet1.0', 'squeezenet1.1',
-                  'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25',
-                  'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25']
+    vision_models = ['resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1',
+                     'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2',
+                     'vgg11', 'vgg13', 'vgg16', 'vgg19',
+                     'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn',
+                     'alexnet', 'inceptionv3',
+                     'densenet121', 'densenet161', 'densenet169', 'densenet201',
+                     'squeezenet1.0', 'squeezenet1.1',
+                     'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25',
+                     'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25']
     pretrained_to_test = set(['squeezenet1.1'])
 
-    for model_name in all_models:
+    for model_name in vision_models:
         test_pretrain = model_name in pretrained_to_test
-        model = get_model(model_name, pretrained=test_pretrain, root='model/')
+        model = get_vision_model(model_name, pretrained=test_pretrain, root='model/')
         data_shape = (2, 3, 224, 224) if 'inception' not in model_name else (2, 3, 299, 299)
         eprint('testing forward for %s' % model_name)
         print(model)

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.