You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/30 02:44:06 UTC
[incubator-mxnet] branch master updated: Update large word language model example (#11405)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 806b41b Update large word language model example (#11405)
806b41b is described below
commit 806b41bfed33d496a35f0af00997774b662990f5
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Fri Jun 29 19:43:57 2018 -0700
Update large word language model example (#11405)
* add cython sampler
* remove unused files
* use eval batch size = 1
* update readme
* update readme
* update license
---
example/rnn/large_word_lm/LogUniformGenerator.cc | 52 ++++++++++++++++++++++
example/rnn/large_word_lm/LogUniformGenerator.h | 45 +++++++++++++++++++
example/rnn/large_word_lm/Makefile | 25 +++++++++++
example/rnn/large_word_lm/custom_module.py | 3 +-
example/rnn/large_word_lm/log_uniform.pyx | 38 ++++++++++++++++
example/rnn/large_word_lm/model.py | 21 ++++-----
example/rnn/large_word_lm/readme.md | 16 +++----
example/rnn/large_word_lm/run_utils.py | 11 +++--
example/rnn/large_word_lm/sampler.py | 55 ++++++++++++++++++++++++
example/rnn/large_word_lm/setup.py | 28 ++++++++++++
example/rnn/large_word_lm/train.py | 32 +++++++++-----
11 files changed, 292 insertions(+), 34 deletions(-)
diff --git a/example/rnn/large_word_lm/LogUniformGenerator.cc b/example/rnn/large_word_lm/LogUniformGenerator.cc
new file mode 100644
index 0000000..ae40659
--- /dev/null
+++ b/example/rnn/large_word_lm/LogUniformGenerator.cc
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file LogUniformGenerator.cc
+ * \brief log uniform distribution generator
+*/
+
+#include "LogUniformGenerator.h"
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <cmath>
+#include <iostream>
+#include <stdexcept>
+#include <unordered_map>
+#include <unordered_set>
+
+/*!
+ * \brief Construct a sampler over the integer range [0, range_max).
+ * \param range_max number of candidate values; must be positive.
+ * \throws std::invalid_argument if range_max <= 0.
+ *
+ * Values are drawn with P(k) = log((k + 2) / (k + 1)) / log(range_max + 1),
+ * the same probability model sampler.py uses to compute expected counts,
+ * so log(range_max + 1) is cached here.
+ */
+LogUniformGenerator::LogUniformGenerator(const int range_max)
+  : range_max_(range_max),
+    log_range_max_(std::log(static_cast<double>(range_max) + 1.0)),
+    generator_(), distribution_(0.0, 1.0) {
+  if (range_max <= 0) {
+    throw std::invalid_argument("LogUniformGenerator: range_max must be positive");
+  }
+}
+
+/*!
+ * \brief Draw `size` distinct values in [0, range_max) without replacement.
+ * \param size number of distinct values to return; at most range_max.
+ * \param num_tries if non-null, receives the number of raw draws made,
+ *        including duplicates that were rejected.
+ * \return the set of sampled values.
+ * \throws std::invalid_argument if size > range_max: only range_max distinct
+ *         values exist, so the rejection loop below could never finish.
+ */
+std::unordered_set<long> LogUniformGenerator::draw(const size_t size, int* num_tries) {
+  if (size > static_cast<size_t>(range_max_)) {
+    throw std::invalid_argument("LogUniformGenerator: size exceeds range_max");
+  }
+  std::unordered_set<long> result;
+  result.reserve(size);
+  int tries = 0;
+  while (result.size() != size) {
+    ++tries;
+    const double x = distribution_(generator_);
+    // Inverse CDF of the log-uniform distribution: floor(e^(x * log(R + 1))) - 1
+    // yields P(k) = log((k + 2) / (k + 1)) / log(R + 1) for k in [0, R).
+    long value = static_cast<long>(std::floor(std::exp(x * log_range_max_))) - 1;
+    // Clamp defensively in case the uniform draw (or exp rounding) reaches the
+    // upper end of the interval, which would otherwise yield range_max_.
+    value = std::min(value, static_cast<long>(range_max_) - 1);
+    // Sampling without replacement: insert() is a no-op for values already drawn.
+    result.insert(value);
+  }
+  if (num_tries != nullptr) {
+    *num_tries = tries;
+  }
+  return result;
+}
diff --git a/example/rnn/large_word_lm/LogUniformGenerator.h b/example/rnn/large_word_lm/LogUniformGenerator.h
new file mode 100644
index 0000000..b6c4f93
--- /dev/null
+++ b/example/rnn/large_word_lm/LogUniformGenerator.h
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file LogUniformGenerator.h
+ * \brief log uniform distribution generator
+*/
+
+// Identifiers starting with an underscore followed by an uppercase letter are
+// reserved for the implementation, so the guard uses a project-style name.
+#ifndef MXNET_EXAMPLE_LOG_UNIFORM_GENERATOR_H_
+#define MXNET_EXAMPLE_LOG_UNIFORM_GENERATOR_H_
+
+#include <cstddef>
+#include <random>
+#include <unordered_set>
+#include <utility>
+
+/*!
+ * \brief Draws distinct integers in [0, range_max) from a log-uniform
+ *        (Zipf-like) distribution, so smaller ids are sampled more often.
+ */
+class LogUniformGenerator {
+ private:
+  const int range_max_;         // number of candidate values
+  const double log_range_max_;  // logarithm of the range, cached by the constructor
+  // NOTE(review): default-constructed, so every instance replays the same
+  // default-seeded sequence; consider accepting a seed from the caller.
+  std::default_random_engine generator_;
+  std::uniform_real_distribution<double> distribution_;
+
+ public:
+  /*! \param range_max number of candidate values; sampled ids lie in [0, range_max). */
+  explicit LogUniformGenerator(int range_max);
+
+  /*!
+   * \brief Draw `size` distinct values without replacement.
+   * \param size number of distinct values to draw; must not exceed range_max,
+   *        since only range_max distinct values exist.
+   * \param num_tries receives the number of raw draws, including rejected duplicates.
+   * \return the sampled values.
+   */
+  std::unordered_set<long> draw(size_t size, int* num_tries);
+};
+
+#endif  // MXNET_EXAMPLE_LOG_UNIFORM_GENERATOR_H_
+
diff --git a/example/rnn/large_word_lm/Makefile b/example/rnn/large_word_lm/Makefile
new file mode 100644
index 0000000..116f7bb
--- /dev/null
+++ b/example/rnn/large_word_lm/Makefile
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+.PHONY: all clean
+
+# Interpreter used to build the extension; override with `make PYTHON=python3`.
+PYTHON ?= python
+
+# Rebuild the Cython sampler extension in place, always from a clean tree.
+all: clean
+	$(PYTHON) setup.py build_ext --inplace
+
+# Remove build artifacts, the Cython-generated C++ source and the built module
+# (matches both `log_uniform.so` and ABI-tagged `log_uniform.cpython-*.so`).
+clean:
+	rm -rf build
+	rm -rf __pycache__
+	rm -rf log_uniform.cpp
+	rm -rf log_uniform*.so
diff --git a/example/rnn/large_word_lm/custom_module.py b/example/rnn/large_word_lm/custom_module.py
index 05d0fb7..a117427 100644
--- a/example/rnn/large_word_lm/custom_module.py
+++ b/example/rnn/large_word_lm/custom_module.py
@@ -60,7 +60,7 @@ class CustomModule(Module):
priority=-param_idx)
@staticmethod
- def load(prefix, epoch, load_optimizer_states=False, **kwargs):
+ def load(prefix, epoch, load_optimizer_states=False, symbol=None, **kwargs):
"""Creates a model from previously saved checkpoint.
Parameters
@@ -90,6 +90,7 @@ class CustomModule(Module):
Default ``None``, indicating no network parameters are fixed.
"""
sym, args, auxs = load_checkpoint(prefix, epoch)
+ sym = sym if symbol is None else symbol
mod = CustomModule(symbol=sym, **kwargs)
mod._arg_params = args
mod._aux_params = auxs
diff --git a/example/rnn/large_word_lm/log_uniform.pyx b/example/rnn/large_word_lm/log_uniform.pyx
new file mode 100644
index 0000000..641835a
--- /dev/null
+++ b/example/rnn/large_word_lm/log_uniform.pyx
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from libcpp.unordered_set cimport unordered_set
+import cython
+
+cdef extern from "LogUniformGenerator.h":
+    cdef cppclass LogUniformGenerator:
+        LogUniformGenerator(int) except +
+        # size_t matches the C++ signature; Cython then rejects negative sizes
+        # with OverflowError instead of letting them wrap to a huge value.
+        unordered_set[long] draw(size_t, int*) except +
+
+cdef class LogUniformSampler:
+    """Python wrapper around the C++ LogUniformGenerator.
+
+    Parameters
+    ----------
+    N : int
+        Number of candidate classes; samples are drawn from [0, N).
+    """
+    cdef LogUniformGenerator* c_sampler
+
+    def __cinit__(self, N):
+        self.c_sampler = new LogUniformGenerator(N)
+
+    def __dealloc__(self):
+        # Cython zero-initialises C attributes and deleting NULL is a no-op,
+        # so this is safe even if construction failed.
+        del self.c_sampler
+
+    def sample_unique(self, size):
+        """Draw `size` distinct classes without replacement.
+
+        Returns
+        -------
+        (list of int, int)
+            The sampled classes (in arbitrary order) and the number of raw
+            draws the C++ sampler needed, including rejected duplicates.
+        """
+        cdef int num_tries = 0
+        samples = list(self.c_sampler.draw(size, &num_tries))
+        return samples, num_tries
diff --git a/example/rnn/large_word_lm/model.py b/example/rnn/large_word_lm/model.py
index 3d3c83b..0e9abda 100644
--- a/example/rnn/large_word_lm/model.py
+++ b/example/rnn/large_word_lm/model.py
@@ -58,7 +58,7 @@ def rnn(bptt, vocab_size, num_embed, nhid, num_layers, dropout, num_proj, batch_
init_h = S.var(prefix + 'init_h', shape=(batch_size, num_proj), init=mx.init.Zero())
init_c = S.var(prefix + 'init_c', shape=(batch_size, nhid), init=mx.init.Zero())
state_names += [prefix + 'init_h', prefix + 'init_c']
- lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj)
+ lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj, prefix=prefix)
outputs, next_states = lstmp.unroll(bptt, outputs, begin_state=[init_h, init_c], \
layout='NTC', merge_outputs=True)
outputs = S.Dropout(outputs, p=dropout)
@@ -127,7 +127,7 @@ def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
new_targets = S.zeros_like(label)
return logits, new_targets
-def generate_samples(label, num_splits, num_samples, num_classes):
+def generate_samples(label, num_splits, sampler):
""" Split labels into `num_splits` and
generate candidates based on log-uniform distribution.
"""
@@ -139,29 +139,30 @@ def generate_samples(label, num_splits, num_samples, num_classes):
samples = []
for label_split in label_splits:
label_split_2d = label_split.reshape((-1,1))
- sampled_value = mx.nd.contrib.rand_zipfian(label_split_2d, num_samples, num_classes)
+ sampled_value = sampler.draw(label_split_2d)
sampled_classes, exp_cnt_true, exp_cnt_sampled = sampled_value
samples.append(sampled_classes.astype(np.float32))
- prob_targets.append(exp_cnt_true.astype(np.float32))
+ prob_targets.append(exp_cnt_true.astype(np.float32).reshape((-1,1)))
prob_samples.append(exp_cnt_sampled.astype(np.float32))
return samples, prob_samples, prob_targets
class Model():
""" LSTMP with Importance Sampling """
- def __init__(self, args, ntokens, rescale_loss):
- out = rnn(args.bptt, ntokens, args.emsize, args.nhid, args.nlayers,
- args.dropout, args.num_proj, args.batch_size)
+ def __init__(self, ntokens, rescale_loss, bptt, emsize,
+ nhid, nlayers, dropout, num_proj, batch_size, k):
+ out = rnn(bptt, ntokens, emsize, nhid, nlayers,
+ dropout, num_proj, batch_size)
rnn_out, self.last_states, self.lstm_args, self.state_names = out
# decoder weight and bias
decoder_w = S.var("decoder_weight", stype='row_sparse')
decoder_b = S.var("decoder_bias", shape=(ntokens, 1), stype='row_sparse')
# sampled softmax for training
- sample = S.var('sample', shape=(args.k,))
- prob_sample = S.var("prob_sample", shape=(args.k,))
+ sample = S.var('sample', shape=(k,))
+ prob_sample = S.var("prob_sample", shape=(k,))
prob_target = S.var("prob_target")
self.sample_names = ['sample', 'prob_sample', 'prob_target']
- logits, new_targets = sampled_softmax(ntokens, args.k, args.num_proj,
+ logits, new_targets = sampled_softmax(ntokens, k, num_proj,
rnn_out, decoder_w, decoder_b,
[sample, prob_sample, prob_target])
self.train_loss = cross_entropy_loss(logits, new_targets, rescale_loss=rescale_loss)
diff --git a/example/rnn/large_word_lm/readme.md b/example/rnn/large_word_lm/readme.md
index d74ffbd..465aaa1 100644
--- a/example/rnn/large_word_lm/readme.md
+++ b/example/rnn/large_word_lm/readme.md
@@ -3,17 +3,18 @@ This example implements the baseline model in
[Exploring the Limits of Language Modeling](https://arxiv.org/abs/1602.02410) on the
[Google 1-Billion Word](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) (GBW) dataset.
-This example reaches **41.97 perplexity** after 5 training epochs on a 1-layer, 2048-unit, 512-projection LSTM Language Model.
-The result is slightly better than the one reported in the paper(43.7 perplexity).
+This example reaches 48.0 test perplexity after 6 training epochs on a 1-layer, 2048-unit, 512-projection LSTM Language Model.
+It reaches 44.2 test perplexity after 35 epochs of training.
+
The main differences with the original implementation include:
* Synchronized gradient updates instead of asynchronized updates
-* Noise candidates are sampled with replacement
-Each epoch for training takes around 80 minutes on a p3.8xlarge instance, which comes with 4 Volta V100 GPUs.
+Each training epoch (excluding evaluation on the test set) takes around 80 minutes on a p3.8xlarge instance, which comes with 4 Volta V100 GPUs.
-# Setup - Original Data Format
-1. Download 1-Billion Word Dataset - [Link](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
+# Set up the dataset and build the sampler
+1. Download 1-Billion Word Dataset: [Link](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
2. Download pre-processed vocabulary file which maps tokens into ids.
+3. Build the sampler with Cython by running `make` in the current directory. If Cython is not installed, run `pip install cython` first.
# Run the Script
```
@@ -59,8 +60,7 @@ optional arguments:
To reproduce the result, run
```
-train.py --gpus=0,1,2,3 --clip=1 --lr=0.05 --dropout=0.01 --eps=0.0001 --rescale-embed=128
+train.py --gpus=0,1,2,3 --clip=10 --lr=0.2 --dropout=0.1 --eps=1 --rescale-embed=256
--test=/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050
--data=/path/to/training-monolingual.tokenized.shuffled/*
-# ~42 perplexity for 5 epochs of training
```
diff --git a/example/rnn/large_word_lm/run_utils.py b/example/rnn/large_word_lm/run_utils.py
index 7650530e..bd1412d 100644
--- a/example/rnn/large_word_lm/run_utils.py
+++ b/example/rnn/large_word_lm/run_utils.py
@@ -53,7 +53,7 @@ def get_parser():
help='report interval')
parser.add_argument('--seed', type=int, default=1,
help='random seed')
- parser.add_argument('--checkpoint-dir', type=str, default='./checkpoint/cp',
+ parser.add_argument('--checkpoint-dir', type=str, default='./checkpoint',
help='dir for checkpoint')
parser.add_argument('--lr', type=float, default=0.1,
help='initial learning rate')
@@ -68,18 +68,21 @@ def evaluate(mod, data_iter, epoch, log_interval):
start = time.time()
total_L = 0.0
nbatch = 0
+ density = 0
mod.set_states(value=0)
for batch in data_iter:
mod.forward(batch, is_train=False)
outputs = mod.get_outputs(merge_multi_context=False)
states = outputs[:-1]
- total_L += outputs[-1][0].asscalar()
+ total_L += outputs[-1][0]
mod.set_states(states=states)
nbatch += 1
+ # don't include padding data in the test perplexity
+ density += batch.data[1].mean()
if (nbatch + 1) % log_interval == 0:
- logging.info("Eval batch %d loss : %.7f" % (nbatch, total_L / nbatch))
+ logging.info("Eval batch %d loss : %.7f" % (nbatch, (total_L / density).asscalar()))
data_iter.reset()
- loss = total_L / nbatch
+ loss = (total_L / density).asscalar()
ppl = math.exp(loss) if loss < 100 else 1e37
end = time.time()
logging.info('Iter[%d]\t\t CE loss %.7f, ppl %.7f. Eval duration = %.2f seconds ' % \
diff --git a/example/rnn/large_word_lm/sampler.py b/example/rnn/large_word_lm/sampler.py
new file mode 100644
index 0000000..047e516
--- /dev/null
+++ b/example/rnn/large_word_lm/sampler.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import math
+import os
+import numpy as np
+import mxnet as mx
+import log_uniform
+from mxnet import ndarray
+
+class LogUniformSampler():
+ def __init__(self, range_max, num_sampled):
+ self.range_max = range_max
+ self.num_sampled = num_sampled
+ self.sampler = log_uniform.LogUniformSampler(range_max)
+
+ def _prob_helper(self, num_tries, num_sampled, prob):
+ if num_tries == num_sampled:
+ return prob * num_sampled
+ return (num_tries * (-prob).log1p()).expm1() * -1
+
+ def draw(self, true_classes):
+ """Draw samples from log uniform distribution and returns sampled candidates,
+ expected count for true classes and sampled classes."""
+ range_max = self.range_max
+ num_sampled = self.num_sampled
+ ctx = true_classes.context
+ log_range = math.log(range_max + 1)
+ num_tries = 0
+ true_classes = true_classes.reshape((-1,))
+ sampled_classes, num_tries = self.sampler.sample_unique(num_sampled)
+
+ true_cls = true_classes.as_in_context(ctx).astype('float64')
+ prob_true = ((true_cls + 2.0) / (true_cls + 1.0)).log() / log_range
+ count_true = self._prob_helper(num_tries, num_sampled, prob_true)
+
+ sampled_classes = ndarray.array(sampled_classes, ctx=ctx, dtype='int64')
+ sampled_cls_fp64 = sampled_classes.astype('float64')
+ prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
+ count_sampled = self._prob_helper(num_tries, num_sampled, prob_sampled)
+ return [sampled_classes, count_true, count_sampled]
diff --git a/example/rnn/large_word_lm/setup.py b/example/rnn/large_word_lm/setup.py
new file mode 100644
index 0000000..09c4fb0
--- /dev/null
+++ b/example/rnn/large_word_lm/setup.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.core import setup, Extension
+from Cython.Build import cythonize
+import numpy
+
+extension_name = "log_uniform"
+sources = ["log_uniform.pyx", "LogUniformGenerator.cc"]
+setup(ext_modules = cythonize(Extension(extension_name,
+ sources=sources,
+ language="c++",
+ extra_compile_args=["-std=c++11"],
+ include_dirs=[numpy.get_include()])))
diff --git a/example/rnn/large_word_lm/train.py b/example/rnn/large_word_lm/train.py
index a1b4e31..a815914 100644
--- a/example/rnn/large_word_lm/train.py
+++ b/example/rnn/large_word_lm/train.py
@@ -23,6 +23,7 @@ from data import MultiSentenceIter, Vocabulary
from model import *
from custom_module import CustomModule
import os, math, logging, sys
+from sampler import LogUniformSampler
if __name__ == '__main__':
# parser
@@ -48,9 +49,11 @@ if __name__ == '__main__':
train_data = mx.io.PrefetchingIter(MultiSentenceIter(args.data, vocab,
args.batch_size * ngpus, args.bptt))
# model
- model = Model(args, ntokens, rescale_loss)
+ model = Model(ntokens, rescale_loss, args.bptt, args.emsize, args.nhid,
+ args.nlayers, args.dropout, args.num_proj, args.batch_size, args.k)
train_loss_and_states = model.train()
eval_loss_and_states = model.eval()
+ sampler = LogUniformSampler(ntokens, args.k)
# training module
data_names, label_names = ['data', 'mask'], ['label']
@@ -83,7 +86,7 @@ if __name__ == '__main__':
module.set_states(value=0)
state_cache = module.get_states(merge_multi_context=False)[:-num_sample_names]
next_batch = train_data.next()
- next_sampled_values = generate_samples(next_batch.label[0], ngpus, args.k, ntokens)
+ next_sampled_values = generate_samples(next_batch.label[0], ngpus, sampler)
stop_iter = False
while not stop_iter:
batch = next_batch
@@ -102,8 +105,7 @@ if __name__ == '__main__':
try:
# prefetch the next batch of data and samples
next_batch = train_data.next()
- next_sampled_values = generate_samples(next_batch.label[0], ngpus,
- args.k, ntokens)
+ next_sampled_values = generate_samples(next_batch.label[0], ngpus, sampler)
except StopIteration:
stop_iter = True
# cache LSTMP states of the current batch
@@ -132,21 +134,29 @@ if __name__ == '__main__':
nbatch += 1
# run evaluation with full softmax on cpu
- module.save_checkpoint(args.checkpoint_dir, epoch, save_optimizer_states=False)
- cpu_train_mod = CustomModule.load(args.checkpoint_dir, epoch, context=mx.cpu(),
- state_names=train_state_names,
- data_names=data_names, label_names=label_names)
+ if not os.path.exists(args.checkpoint_dir):
+ os.mkdir(args.checkpoint_dir)
+ ckp = os.path.join(args.checkpoint_dir, 'ckp')
+ module.save_checkpoint(ckp, epoch, save_optimizer_states=False)
+
+ # use batch_size = 1 for testing
+ eval_batch_size = 1
+ load_model = Model(ntokens, rescale_loss, args.bptt, args.emsize, args.nhid,
+ args.nlayers, args.dropout, args.num_proj, eval_batch_size, args.k)
+ cpu_train_mod = CustomModule.load(ckp, epoch, context=mx.cpu(),
+ state_names=train_state_names, data_names=data_names,
+ label_names=label_names, symbol=load_model.train())
# eval data iter
eval_data = mx.io.PrefetchingIter(MultiSentenceIter(args.test, vocab,
- args.batch_size, args.bptt))
+ eval_batch_size, args.bptt))
cpu_train_mod.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label)
# eval module
- eval_module = CustomModule(symbol=eval_loss_and_states, context=mx.cpu(), data_names=data_names,
+ eval_module = CustomModule(symbol=load_model.eval(), context=mx.cpu(), data_names=data_names,
label_names=label_names, state_names=eval_state_names)
# use `shared_module` to share parameter with the training module
eval_module.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label,
shared_module=cpu_train_mod, for_training=False)
- val_L = run_utils.evaluate(eval_module, eval_data, epoch, 20)
+ val_L = run_utils.evaluate(eval_module, eval_data, epoch, 1000)
train_data.reset()
logging.info("Training completed. ")