You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/27 17:19:33 UTC

[GitHub] szha closed pull request #10263: Update language model and update with new sentiment analysis and lm example

szha closed pull request #10263: Update language model and update with new sentiment analysis and lm example
URL: https://github.com/apache/incubator-mxnet/pull/10263
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (since GitHub hides the original diff once the pull request is merged):

diff --git a/example/gluon/word_language_model.py b/example/gluon/word_language_model.py
index 382f0b17ff5..7dd8c697c69 100644
--- a/example/gluon/word_language_model.py
+++ b/example/gluon/word_language_model.py
@@ -22,7 +22,7 @@
 import mxnet as mx
 from mxnet import gluon, autograd
 from mxnet.gluon import data, text
-from mxnet.gluon.model_zoo.text.lm import SimpleRNN, AWDRNN
+from mxnet.gluon.model_zoo.text.lm import StandardRNN, AWDRNN
 
 parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
@@ -68,6 +68,8 @@
                     help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. (the result of multi-gpu training might be slightly different compared to single-gpu training, still need to be finalized)')
 args = parser.parse_args()
 
+print(args)
+
 
 ###############################################################################
 # Load data
@@ -82,7 +84,7 @@
 def get_frequencies(dataset):
     return collections.Counter(x for tup in dataset for x in tup[0] if x)
 
-vocab = text.vocab.Vocabulary(get_frequencies(train_dataset))
+vocab = text.vocab.Vocabulary(get_frequencies(train_dataset), reserved_tokens=['<eos>', '<pad>'])
 def index_tokens(data, label):
     return vocab[data], vocab[label]
 
@@ -124,8 +126,8 @@ def index_tokens(data, label):
     model = AWDRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers,
                    args.tied, args.dropout, args.weight_dropout, args.dropout_h, args.dropout_i)
 else:
-    model = SimpleRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers,
-                      args.tied, args.dropout)
+    model = StandardRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, args.dropout,
+                      args.tied)
 
 model.initialize(mx.init.Xavier(), ctx=context)
 
@@ -169,7 +171,7 @@ def train():
     for epoch in range(args.epochs):
         total_L = 0.0
         start_epoch_time = time.time()
-        hiddens = [model.begin_state(args.batch_size, func=mx.nd.zeros, ctx=ctx) for ctx in context]
+        hiddens = [model.begin_state(args.batch_size//len(context), func=mx.nd.zeros, ctx=ctx) for ctx in context]
         for i, (data, target) in enumerate(train_data):
             start_batch_time = time.time()
             data = data.T
diff --git a/python/mxnet/gluon/model_zoo/text/base.py b/python/mxnet/gluon/model_zoo/text/base.py
index 6f5f5557b88..7d84bdeda73 100644
--- a/python/mxnet/gluon/model_zoo/text/base.py
+++ b/python/mxnet/gluon/model_zoo/text/base.py
@@ -118,7 +118,7 @@ def get_rnn_cell(mode, num_layers, input_size, hidden_size,
 def get_rnn_layer(mode, num_layers, input_size, hidden_size, dropout, weight_dropout):
     """create rnn layer given specs"""
     if mode == 'rnn_relu':
-        block = rnn.RNN(hidden_size, 'relu', num_layers, dropout=dropout,
+        block = rnn.RNN(hidden_size, num_layers, 'relu', dropout=dropout,
                         input_size=input_size)
     elif mode == 'rnn_tanh':
         block = rnn.RNN(hidden_size, num_layers, dropout=dropout,
diff --git a/python/mxnet/gluon/model_zoo/text/lm.py b/python/mxnet/gluon/model_zoo/text/lm.py
index 7607cf5bf13..34f060935ca 100644
--- a/python/mxnet/gluon/model_zoo/text/lm.py
+++ b/python/mxnet/gluon/model_zoo/text/lm.py
@@ -209,7 +209,7 @@ def awd_lstm_lm_1150(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                        'tie_weights': True,
                        'dropout': 0.4,
                        'weight_drop': 0.5,
-                       'drop_h': 0.3,
+                       'drop_h': 0.2,
                        'drop_i': 0.65}
     assert all(k not in kwargs for k in predefined_args), \
            "Cannot override predefined model settings."


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services