You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/30 02:44:06 UTC

[incubator-mxnet] branch master updated: Update large word language model example (#11405)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 806b41b  Update large word language model example (#11405)
806b41b is described below

commit 806b41bfed33d496a35f0af00997774b662990f5
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Fri Jun 29 19:43:57 2018 -0700

    Update large word language model example (#11405)
    
    * add cython sampler
    
    * remove unused files
    
    * use eval batch size = 1
    
    * update read me
    
    * update read me
    
    * update license
---
 example/rnn/large_word_lm/LogUniformGenerator.cc | 52 ++++++++++++++++++++++
 example/rnn/large_word_lm/LogUniformGenerator.h  | 45 +++++++++++++++++++
 example/rnn/large_word_lm/Makefile               | 25 +++++++++++
 example/rnn/large_word_lm/custom_module.py       |  3 +-
 example/rnn/large_word_lm/log_uniform.pyx        | 38 ++++++++++++++++
 example/rnn/large_word_lm/model.py               | 21 ++++-----
 example/rnn/large_word_lm/readme.md              | 16 +++----
 example/rnn/large_word_lm/run_utils.py           | 11 +++--
 example/rnn/large_word_lm/sampler.py             | 55 ++++++++++++++++++++++++
 example/rnn/large_word_lm/setup.py               | 28 ++++++++++++
 example/rnn/large_word_lm/train.py               | 32 +++++++++-----
 11 files changed, 292 insertions(+), 34 deletions(-)

diff --git a/example/rnn/large_word_lm/LogUniformGenerator.cc b/example/rnn/large_word_lm/LogUniformGenerator.cc
new file mode 100644
index 0000000..ae40659
--- /dev/null
+++ b/example/rnn/large_word_lm/LogUniformGenerator.cc
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file LogUniformGenerator.cc
+ * \brief log uniform distribution generator
+*/
+
+#include <unordered_set>
+#include <unordered_map>
+#include <cmath>
+#include <stddef.h>
+#include <iostream>
+
+#include "LogUniformGenerator.h"
+
+// Samples ids in [0, range_max) with P(k) = log((k + 2) / (k + 1)) / log(range_max + 1).
+LogUniformGenerator::LogUniformGenerator(const int range_max)
+  : range_max_(range_max), log_range_max_(std::log(range_max + 1.0)), generator_(), distribution_(0.0, 1.0) {}
+// Returns `size` distinct ids and writes the number of draws to *num_tries; needs size <= range_max or it never returns.
+std::unordered_set<long> LogUniformGenerator::draw(const size_t size, int* num_tries) {
+  std::unordered_set<long> result;
+  int tries = 0;
+  while (result.size() != size) {
+    ++tries;
+    const double x = distribution_(generator_);
+    // floor(exp(x * log(range_max + 1))) - 1 has exactly that P(k), which sampler.py assumes for expected counts.
+    long value = static_cast<long>(std::floor(std::exp(x * log_range_max_))) - 1;
+    // exp() may round up to range_max + 1 when x is within an ulp of 1; keep the id in range.
+    if (value >= range_max_) value = range_max_ - 1;
+    result.insert(value);  // inserting a duplicate is a no-op: sampling without replacement
+  }
+  *num_tries = tries;
+  return result;
+}
diff --git a/example/rnn/large_word_lm/LogUniformGenerator.h b/example/rnn/large_word_lm/LogUniformGenerator.h
new file mode 100644
index 0000000..b6c4f93
--- /dev/null
+++ b/example/rnn/large_word_lm/LogUniformGenerator.h
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file LogUniformGenerator.h
+ * \brief log uniform distribution generator
+*/
+
+#ifndef LOG_UNIFORM_GENERATOR_H_
+#define LOG_UNIFORM_GENERATOR_H_
+
+#include <cstddef>
+#include <unordered_set>
+#include <utility>
+#include <random>
+
+class LogUniformGenerator {  // draws distinct ids from a log-uniform distribution over [0, range_max)
+ private:
+  const int range_max_;  // number of candidate ids
+  const double log_range_max_;  // log of the range size, precomputed by the constructor
+  std::default_random_engine generator_;  // default-seeded, so the sequence repeats across runs
+  std::uniform_real_distribution<double> distribution_;  // U[0, 1) input to the log transform
+ public:
+  explicit LogUniformGenerator(int range_max);
+  std::unordered_set<long> draw(std::size_t size, int* num_tries);  // *num_tries = draws needed
+};
+#endif  // LOG_UNIFORM_GENERATOR_H_
+
diff --git a/example/rnn/large_word_lm/Makefile b/example/rnn/large_word_lm/Makefile
new file mode 100644
index 0000000..116f7bb
--- /dev/null
+++ b/example/rnn/large_word_lm/Makefile
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+.PHONY: all clean
+# Rebuild the cython sampler from scratch, in place next to the python sources.
+all: clean
+	python setup.py build_ext --inplace
+
+# Remove build artifacts and the C++ file generated by cythonize.
+clean:
+	rm -rf build __pycache__ log_uniform.cpp log_uniform.*.so
diff --git a/example/rnn/large_word_lm/custom_module.py b/example/rnn/large_word_lm/custom_module.py
index 05d0fb7..a117427 100644
--- a/example/rnn/large_word_lm/custom_module.py
+++ b/example/rnn/large_word_lm/custom_module.py
@@ -60,7 +60,7 @@ class CustomModule(Module):
                                           priority=-param_idx)
 
     @staticmethod
-    def load(prefix, epoch, load_optimizer_states=False, **kwargs):
+    def load(prefix, epoch, load_optimizer_states=False, symbol=None, **kwargs):
         """Creates a model from previously saved checkpoint.
 
         Parameters
@@ -90,6 +90,7 @@ class CustomModule(Module):
             Default ``None``, indicating no network parameters are fixed.
         """
         sym, args, auxs = load_checkpoint(prefix, epoch)
+        sym = sym if symbol is None else symbol
         mod = CustomModule(symbol=sym, **kwargs)
         mod._arg_params = args
         mod._aux_params = auxs
diff --git a/example/rnn/large_word_lm/log_uniform.pyx b/example/rnn/large_word_lm/log_uniform.pyx
new file mode 100644
index 0000000..641835a
--- /dev/null
+++ b/example/rnn/large_word_lm/log_uniform.pyx
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from libcpp.unordered_set cimport unordered_set
+import cython
+
+cdef extern from "LogUniformGenerator.h":
+    cdef cppclass LogUniformGenerator:
+        LogUniformGenerator(int) except +
+        unordered_set[long] draw(size_t, int*) except +  # size_t matches the C++ signature; negative sizes raise OverflowError
+
+cdef class LogUniformSampler:
+    """Python-facing wrapper that owns a heap-allocated C++ LogUniformGenerator."""
+    cdef LogUniformGenerator* c_sampler
+
+    def __cinit__(self, N):
+        self.c_sampler = new LogUniformGenerator(N)
+    def __dealloc__(self):
+        del self.c_sampler
+    def sample_unique(self, size):
+        """Return (list of `size` distinct ids, number of draws the generator performed)."""
+        cdef int num_tries
+        drawn = self.c_sampler.draw(size, &num_tries)
+        return list(drawn), num_tries
diff --git a/example/rnn/large_word_lm/model.py b/example/rnn/large_word_lm/model.py
index 3d3c83b..0e9abda 100644
--- a/example/rnn/large_word_lm/model.py
+++ b/example/rnn/large_word_lm/model.py
@@ -58,7 +58,7 @@ def rnn(bptt, vocab_size, num_embed, nhid, num_layers, dropout, num_proj, batch_
         init_h = S.var(prefix + 'init_h', shape=(batch_size, num_proj), init=mx.init.Zero())
         init_c = S.var(prefix + 'init_c', shape=(batch_size, nhid), init=mx.init.Zero())
         state_names += [prefix + 'init_h', prefix + 'init_c']
-        lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj)
+        lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj, prefix=prefix)
         outputs, next_states = lstmp.unroll(bptt, outputs, begin_state=[init_h, init_c], \
                                             layout='NTC', merge_outputs=True)
         outputs = S.Dropout(outputs, p=dropout)
@@ -127,7 +127,7 @@ def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
         new_targets = S.zeros_like(label)
         return logits, new_targets
 
-def generate_samples(label, num_splits, num_samples, num_classes):
+def generate_samples(label, num_splits, sampler):
     """ Split labels into `num_splits` and
         generate candidates based on log-uniform distribution.
     """
@@ -139,29 +139,30 @@ def generate_samples(label, num_splits, num_samples, num_classes):
     samples = []
     for label_split in label_splits:
         label_split_2d = label_split.reshape((-1,1))
-        sampled_value = mx.nd.contrib.rand_zipfian(label_split_2d, num_samples, num_classes)
+        sampled_value = sampler.draw(label_split_2d)
         sampled_classes, exp_cnt_true, exp_cnt_sampled = sampled_value
         samples.append(sampled_classes.astype(np.float32))
-        prob_targets.append(exp_cnt_true.astype(np.float32))
+        prob_targets.append(exp_cnt_true.astype(np.float32).reshape((-1,1)))
         prob_samples.append(exp_cnt_sampled.astype(np.float32))
     return samples, prob_samples, prob_targets
 
 class Model():
     """ LSTMP with Importance Sampling """
-    def __init__(self, args, ntokens, rescale_loss):
-        out = rnn(args.bptt, ntokens, args.emsize, args.nhid, args.nlayers,
-                  args.dropout, args.num_proj, args.batch_size)
+    def __init__(self, ntokens, rescale_loss, bptt, emsize,
+                 nhid, nlayers, dropout, num_proj, batch_size, k):
+        out = rnn(bptt, ntokens, emsize, nhid, nlayers,
+                  dropout, num_proj, batch_size)
         rnn_out, self.last_states, self.lstm_args, self.state_names = out
         # decoder weight and bias
         decoder_w = S.var("decoder_weight", stype='row_sparse')
         decoder_b = S.var("decoder_bias", shape=(ntokens, 1), stype='row_sparse')
 
         # sampled softmax for training
-        sample = S.var('sample', shape=(args.k,))
-        prob_sample = S.var("prob_sample", shape=(args.k,))
+        sample = S.var('sample', shape=(k,))
+        prob_sample = S.var("prob_sample", shape=(k,))
         prob_target = S.var("prob_target")
         self.sample_names = ['sample', 'prob_sample', 'prob_target']
-        logits, new_targets = sampled_softmax(ntokens, args.k, args.num_proj,
+        logits, new_targets = sampled_softmax(ntokens, k, num_proj,
                                               rnn_out, decoder_w, decoder_b,
                                               [sample, prob_sample, prob_target])
         self.train_loss = cross_entropy_loss(logits, new_targets, rescale_loss=rescale_loss)
diff --git a/example/rnn/large_word_lm/readme.md b/example/rnn/large_word_lm/readme.md
index d74ffbd..465aaa1 100644
--- a/example/rnn/large_word_lm/readme.md
+++ b/example/rnn/large_word_lm/readme.md
@@ -3,17 +3,18 @@ This example implements the baseline model in
 [Exploring the Limits of Language Modeling](https://arxiv.org/abs/1602.02410) on the
 [Google 1-Billion Word](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) (GBW) dataset.
 
-This example reaches **41.97 perplexity** after 5 training epochs on a 1-layer, 2048-unit, 512-projection LSTM Language Model.
-The result is slightly better than the one reported in the paper(43.7 perplexity).
+This example reaches 48.0 test perplexity after 6 training epochs on a 1-layer, 2048-unit, 512-projection LSTM Language Model.
+It reaches 44.2 test perplexity after 35 epochs of training.
+
 The main differences with the original implementation include:
 * Synchronized gradient updates instead of asynchronized updates
-* Noise candidates are sampled with replacement
 
-Each epoch for training takes around 80 minutes on a p3.8xlarge instance, which comes with 4 Volta V100 GPUs.
+Each training epoch (excluding evaluation on the test set) takes around 80 minutes on a p3.8xlarge instance, which comes with 4 Volta V100 GPUs.
 
-# Setup - Original Data Format
-1. Download 1-Billion Word Dataset - [Link](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
+# Setup dataset and build sampler
+1. Download 1-Billion Word Dataset: [Link](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
 2. Download pre-processed vocabulary file which maps tokens into ids.
+3. Build the sampler with Cython by running `make` in the current directory. If Cython is not installed, run `pip install cython` first.
 
 # Run the Script
 ```
@@ -59,8 +60,7 @@ optional arguments:
 
 To reproduce the result, run
 ```
-train.py --gpus=0,1,2,3 --clip=1 --lr=0.05 --dropout=0.01 --eps=0.0001 --rescale-embed=128
+train.py --gpus=0,1,2,3 --clip=10 --lr=0.2 --dropout=0.1 --eps=1 --rescale-embed=256
 --test=/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050
 --data=/path/to/training-monolingual.tokenized.shuffled/*
-# ~42 perplexity for 5 epochs of training
 ```
diff --git a/example/rnn/large_word_lm/run_utils.py b/example/rnn/large_word_lm/run_utils.py
index 7650530e..bd1412d 100644
--- a/example/rnn/large_word_lm/run_utils.py
+++ b/example/rnn/large_word_lm/run_utils.py
@@ -53,7 +53,7 @@ def get_parser():
                         help='report interval')
     parser.add_argument('--seed', type=int, default=1,
                         help='random seed')
-    parser.add_argument('--checkpoint-dir', type=str, default='./checkpoint/cp',
+    parser.add_argument('--checkpoint-dir', type=str, default='./checkpoint',
                         help='dir for checkpoint')
     parser.add_argument('--lr', type=float, default=0.1,
                         help='initial learning rate')
@@ -68,18 +68,21 @@ def evaluate(mod, data_iter, epoch, log_interval):
     start = time.time()
     total_L = 0.0
     nbatch = 0
+    density = 0
     mod.set_states(value=0)
     for batch in data_iter:
         mod.forward(batch, is_train=False)
         outputs = mod.get_outputs(merge_multi_context=False)
         states = outputs[:-1]
-        total_L += outputs[-1][0].asscalar()
+        total_L += outputs[-1][0]
         mod.set_states(states=states)
         nbatch += 1
+        # don't include padding data in the test perplexity
+        density += batch.data[1].mean()
         if (nbatch + 1) % log_interval == 0:
-            logging.info("Eval batch %d loss : %.7f" % (nbatch, total_L / nbatch))
+            logging.info("Eval batch %d loss : %.7f" % (nbatch, (total_L / density).asscalar()))
     data_iter.reset()
-    loss = total_L / nbatch
+    loss = (total_L / density).asscalar()
     ppl = math.exp(loss) if loss < 100 else 1e37
     end = time.time()
     logging.info('Iter[%d]\t\t CE loss %.7f, ppl %.7f. Eval duration = %.2f seconds ' % \
diff --git a/example/rnn/large_word_lm/sampler.py b/example/rnn/large_word_lm/sampler.py
new file mode 100644
index 0000000..047e516
--- /dev/null
+++ b/example/rnn/large_word_lm/sampler.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import math
+import os
+import numpy as np
+import mxnet as mx
+import log_uniform
+from mxnet import ndarray
+
+class LogUniformSampler():
+    def __init__(self, range_max, num_sampled):
+        self.range_max = range_max
+        self.num_sampled = num_sampled
+        self.sampler = log_uniform.LogUniformSampler(range_max)
+
+    def _prob_helper(self, num_tries, num_sampled, prob):
+        if num_tries == num_sampled:
+            return prob * num_sampled
+        return (num_tries * (-prob).log1p()).expm1() * -1
+
+    def draw(self, true_classes):
+        """Draw samples from log uniform distribution and returns sampled candidates,
+        expected count for true classes and sampled classes."""
+        range_max = self.range_max
+        num_sampled = self.num_sampled
+        ctx = true_classes.context
+        log_range = math.log(range_max + 1)
+        num_tries = 0
+        true_classes = true_classes.reshape((-1,))
+        sampled_classes, num_tries = self.sampler.sample_unique(num_sampled)
+
+        true_cls = true_classes.as_in_context(ctx).astype('float64')
+        prob_true = ((true_cls + 2.0) / (true_cls + 1.0)).log() / log_range
+        count_true = self._prob_helper(num_tries, num_sampled, prob_true)
+
+        sampled_classes = ndarray.array(sampled_classes, ctx=ctx, dtype='int64')
+        sampled_cls_fp64 = sampled_classes.astype('float64')
+        prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
+        count_sampled = self._prob_helper(num_tries, num_sampled, prob_sampled)
+        return [sampled_classes, count_true, count_sampled]
diff --git a/example/rnn/large_word_lm/setup.py b/example/rnn/large_word_lm/setup.py
new file mode 100644
index 0000000..09c4fb0
--- /dev/null
+++ b/example/rnn/large_word_lm/setup.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.core import setup, Extension
+from Cython.Build import cythonize
+import numpy
+
+extension_name = "log_uniform"
+sources = ["log_uniform.pyx", "LogUniformGenerator.cc"]
+setup(ext_modules = cythonize(Extension(extension_name,
+                                        sources=sources,
+                                        language="c++",
+                                        extra_compile_args=["-std=c++11"],
+                                        include_dirs=[numpy.get_include()])))
diff --git a/example/rnn/large_word_lm/train.py b/example/rnn/large_word_lm/train.py
index a1b4e31..a815914 100644
--- a/example/rnn/large_word_lm/train.py
+++ b/example/rnn/large_word_lm/train.py
@@ -23,6 +23,7 @@ from data import MultiSentenceIter, Vocabulary
 from model import *
 from custom_module import CustomModule
 import os, math, logging, sys
+from sampler import LogUniformSampler
 
 if __name__ == '__main__':
     # parser
@@ -48,9 +49,11 @@ if __name__ == '__main__':
     train_data = mx.io.PrefetchingIter(MultiSentenceIter(args.data, vocab,
                                        args.batch_size * ngpus, args.bptt))
     # model
-    model = Model(args, ntokens, rescale_loss)
+    model = Model(ntokens, rescale_loss, args.bptt, args.emsize, args.nhid,
+                  args.nlayers, args.dropout, args.num_proj, args.batch_size, args.k)
     train_loss_and_states = model.train()
     eval_loss_and_states = model.eval()
+    sampler = LogUniformSampler(ntokens, args.k)
 
     # training module
     data_names, label_names = ['data', 'mask'], ['label']
@@ -83,7 +86,7 @@ if __name__ == '__main__':
         module.set_states(value=0)
         state_cache = module.get_states(merge_multi_context=False)[:-num_sample_names]
         next_batch = train_data.next()
-        next_sampled_values = generate_samples(next_batch.label[0], ngpus, args.k, ntokens)
+        next_sampled_values = generate_samples(next_batch.label[0], ngpus, sampler)
         stop_iter = False
         while not stop_iter:
             batch = next_batch
@@ -102,8 +105,7 @@ if __name__ == '__main__':
             try:
                 # prefetch the next batch of data and samples
                 next_batch = train_data.next()
-                next_sampled_values = generate_samples(next_batch.label[0], ngpus,
-                                                       args.k, ntokens)
+                next_sampled_values = generate_samples(next_batch.label[0], ngpus, sampler)
             except StopIteration:
                 stop_iter = True
             # cache LSTMP states of the current batch
@@ -132,21 +134,29 @@ if __name__ == '__main__':
             nbatch += 1
 
         # run evaluation with full softmax on cpu
-        module.save_checkpoint(args.checkpoint_dir, epoch, save_optimizer_states=False)
-        cpu_train_mod = CustomModule.load(args.checkpoint_dir, epoch, context=mx.cpu(),
-                                          state_names=train_state_names,
-                                          data_names=data_names, label_names=label_names)
+        if not os.path.exists(args.checkpoint_dir):
+            os.mkdir(args.checkpoint_dir)
+        ckp = os.path.join(args.checkpoint_dir, 'ckp')
+        module.save_checkpoint(ckp, epoch, save_optimizer_states=False)
+
+        # use batch_size = 1 for testing
+        eval_batch_size = 1
+        load_model = Model(ntokens, rescale_loss, args.bptt, args.emsize, args.nhid,
+                           args.nlayers, args.dropout, args.num_proj, eval_batch_size, args.k)
+        cpu_train_mod = CustomModule.load(ckp, epoch, context=mx.cpu(),
+                                          state_names=train_state_names, data_names=data_names,
+                                          label_names=label_names, symbol=load_model.train())
         # eval data iter
         eval_data = mx.io.PrefetchingIter(MultiSentenceIter(args.test, vocab,
-                                          args.batch_size, args.bptt))
+                                          eval_batch_size, args.bptt))
         cpu_train_mod.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label)
 
         # eval module
-        eval_module = CustomModule(symbol=eval_loss_and_states, context=mx.cpu(), data_names=data_names,
+        eval_module = CustomModule(symbol=load_model.eval(), context=mx.cpu(), data_names=data_names,
                                    label_names=label_names, state_names=eval_state_names)
         # use `shared_module` to share parameter with the training module
         eval_module.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label,
                          shared_module=cpu_train_mod, for_training=False)
-        val_L = run_utils.evaluate(eval_module, eval_data, epoch, 20)
+        val_L = run_utils.evaluate(eval_module, eval_data, epoch, 1000)
         train_data.reset()
     logging.info("Training completed. ")