You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/22 01:02:00 UTC
[incubator-mxnet] branch master updated: [MXNET-96] Language model
with Google's billion words dataset (#10025)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 57534bf [MXNET-96] Language model with Google's billion words dataset (#10025)
57534bf is described below
commit 57534bfadb06ade5fbbab12cf5d05dd8da4538af
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Mar 21 18:01:56 2018 -0700
[MXNET-96] Language model with Google's billion words dataset (#10025)
* Language model with Google's billion words dataset (#197)
Language model with Google's billion words dataset (#197)
* fix lint
* ffix license
* patch
* fix lint
* cr comment
* update fc fallback
* fix build
* fix temp memory in fc
* fix compilation
---
example/rnn/large_word_lm/custom_module.py | 182 ++++++++++++++++++++++++
example/rnn/large_word_lm/data.py | 202 +++++++++++++++++++++++++++
example/rnn/large_word_lm/get_vocab_file.sh | 34 +++++
example/rnn/large_word_lm/model.py | 181 ++++++++++++++++++++++++
example/rnn/large_word_lm/readme.md | 66 +++++++++
example/rnn/large_word_lm/run_utils.py | 87 ++++++++++++
example/rnn/large_word_lm/train.py | 152 ++++++++++++++++++++
python/mxnet/gluon/contrib/rnn/rnn_cell.py | 128 ++++++++++++++++-
src/executor/graph_executor.cc | 6 +-
src/operator/nn/fully_connected-inl.h | 11 +-
src/operator/nn/fully_connected.cc | 101 +++++++++++---
tests/python/unittest/test_gluon_contrib.py | 16 +++
tests/python/unittest/test_sparse_ndarray.py | 18 +++
13 files changed, 1156 insertions(+), 28 deletions(-)
diff --git a/example/rnn/large_word_lm/custom_module.py b/example/rnn/large_word_lm/custom_module.py
new file mode 100644
index 0000000..05d0fb7
--- /dev/null
+++ b/example/rnn/large_word_lm/custom_module.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+import warnings
+
+import mxnet as mx
+import numpy as np
+from mxnet.module import Module
+from mxnet.model import load_checkpoint
+
+class CustomModule(Module):
+    """A `Module` with helpers for row_sparse parameters and gradient manipulation.
+
+    On top of `mxnet.module.Module` this class adds pulling and saving of
+    row_sparse parameters through the kvstore, per-context gradient clipping by
+    global norm, and per-parameter gradient rescaling.
+    """
+
+    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
+                 logger=logging, context=mx.cpu(), work_load_list=None,
+                 fixed_param_names=None, state_names=None, group2ctxs=None,
+                 compression_params=None):
+        """Forwards every argument unchanged to `Module.__init__`."""
+        super(CustomModule, self).__init__(symbol, data_names=data_names, label_names=label_names,
+                                           logger=logger, context=context, work_load_list=work_load_list,
+                                           fixed_param_names=fixed_param_names, state_names=state_names,
+                                           group2ctxs=group2ctxs, compression_params=compression_params)
+
+    def prepare_sparse_params(self, param_rowids):
+        '''Prepares the module for processing a data batch by pulling row_sparse
+        parameters from kvstore to all devices based on rowids.
+
+        Does nothing when the module has no kvstore.
+
+        Parameters
+        ----------
+        param_rowids : dict of str to NDArray or list of NDArrays
+            Maps a parameter name to the row ids to pull. A list/tuple of
+            NDArrays is flattened, cast to int64 and concatenated into a single
+            1-D array; a single NDArray is passed to the kvstore as-is.
+        '''
+        if not self._kvstore:
+            return
+        assert(isinstance(param_rowids, dict))
+        for param_name, rowids in param_rowids.items():
+            if isinstance(rowids, (tuple, list)):
+                rowids_1d = [r.reshape((-1,)).astype(np.int64) for r in rowids]
+                rowid = mx.nd.concat(*rowids_1d, dim=0)
+            else:
+                rowid = rowids
+            param_idx = self._exec_group.param_names.index(param_name)
+            param_val = self._exec_group.param_arrays[param_idx]
+            # The pull priority is the negated parameter index.
+            self._kvstore.row_sparse_pull(param_name, param_val, row_ids=rowid,
+                                          priority=-param_idx)
+
+    @staticmethod
+    def load(prefix, epoch, load_optimizer_states=False, **kwargs):
+        """Creates a model from previously saved checkpoint.
+
+        Parameters
+        ----------
+        prefix : str
+            path prefix of saved model files. You should have
+            "prefix-symbol.json", "prefix-xxxx.params", and
+            optionally "prefix-xxxx.states", where xxxx is the
+            epoch number.
+        epoch : int
+            epoch to load.
+        load_optimizer_states : bool
+            whether to load optimizer states. Checkpoint needs
+            to have been made with save_optimizer_states=True.
+        data_names : list of str
+            Default is `('data',)` for a typical model used in image classification.
+        label_names : list of str
+            Default is `('softmax_label',)` for a typical model used in image
+            classification.
+        logger : Logger
+            Default is `logging`.
+        context : Context or list of Context
+            Default is ``cpu()``.
+        work_load_list : list of number
+            Default ``None``, indicating uniform workload.
+        fixed_param_names: list of str
+            Default ``None``, indicating no network parameters are fixed.
+
+        Returns
+        -------
+        CustomModule
+            A module whose parameters are marked as initialized from the checkpoint.
+        """
+        sym, args, auxs = load_checkpoint(prefix, epoch)
+        mod = CustomModule(symbol=sym, **kwargs)
+        mod._arg_params = args
+        mod._aux_params = auxs
+        mod.params_initialized = True
+        if load_optimizer_states:
+            # Only the file name is recorded here; nothing is read from it in this method.
+            mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch)
+        return mod
+
+    def save_params(self, fname):
+        """Saves model parameters to file.
+
+        Parameters are first refreshed from the kvstore (see `get_params_from_kv`)
+        and copied to CPU before being written.
+
+        Parameters
+        ----------
+        fname : str
+            Path to output param file.
+
+        Examples
+        --------
+        >>> # An example of saving module parameters.
+        >>> mod.save_params('myfile')
+        """
+        arg_params, aux_params = self.get_params_from_kv(self._arg_params, self._aux_params)
+        save_dict = {('arg:%s' % k) : v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
+        save_dict.update({('aux:%s' % k) : v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
+        mx.nd.save(fname, save_dict)
+
+    def get_params_from_kv(self, arg_params, aux_params):
+        """ Copy data from kvstore to `arg_params` and `aux_params`.
+
+        Parameters
+        ----------
+        arg_params : dict of str to NDArray
+            Target parameter arrays, indexed by parameter name.
+        aux_params : dict of str to NDArray
+            Target aux arrays. Must be empty: aux states are not supported.
+
+        Returns
+        -------
+        tuple
+            The same `(arg_params, aux_params)` objects that were passed in.
+
+        Raises
+        ------
+        NotImplementedError
+            If `aux_params` is not empty.
+
+        Notes
+        -----
+        - This function will inplace update the NDArrays in arg_params and aux_params.
+        """
+        assert(self._kvstore is not None)
+        for name, block in zip(self._exec_group.param_names, self._exec_group.param_arrays):
+            assert(isinstance(block, list))
+            if block[0].stype == 'row_sparse':
+                # Request every row so the full parameter is pulled.
+                row_ids = mx.nd.arange(start=0, stop=block[0].shape[0], dtype='int64')
+                self._kvstore.row_sparse_pull(name, arg_params[name], row_ids=row_ids)
+            else:
+                assert(block[0].stype == 'default')
+                self._kvstore.pull(name, out=arg_params[name])
+        if len(aux_params) > 0:
+            raise NotImplementedError()
+        return arg_params, aux_params
+
+    def clip_by_global_norm_per_ctx(self, max_norm=1.0, param_names=None):
+        """Clips gradient norm.
+
+        The norm is computed over all gradients together, as if they were
+        concatenated into a single vector. Gradients are modified in-place.
+
+        The method is first used in
+        `[ICML2013] On the difficulty of training recurrent neural networks`
+
+        Note that the gradients are concatenated per context in this implementation.
+
+        Parameters
+        ----------
+        max_norm : float
+            Passed to `mx.gluon.utils.clip_global_norm` for every context.
+        param_names : list of str
+            Names of the parameters whose gradients are clipped. Required.
+
+        Returns
+        -------
+        list
+            One entry per context: the value returned by
+            `mx.gluon.utils.clip_global_norm` for that context's gradients.
+
+        Examples
+        --------
+        An example of clipping the gradient before updating the parameters::
+            >>> #Get the gradient via back-propagation
+            >>> net.forward_backward(data_batch=data_batch)
+            >>> norm_vals = net.clip_by_global_norm_per_ctx(max_norm=2.0, param_names=['w0'])
+            >>> net.update()
+        """
+        assert self.binded and self.params_initialized and self.optimizer_initialized
+        num_ctx = len(self._exec_group.grad_arrays[0])
+        grad_array_per_ctx = [[] for _ in range(num_ctx)]
+        assert(param_names is not None)
+        for param_name in param_names:
+            param_idx = self._exec_group.param_names.index(param_name)
+            grad_val = self._exec_group.grad_arrays[param_idx]
+            assert(len(grad_val) == num_ctx)
+            for i in range(num_ctx):
+                grad_array_per_ctx[i].append(grad_val[i])
+        # Clip each context independently and keep the reported norms for the caller.
+        norm_vals = []
+        for grads in grad_array_per_ctx:
+            norm_vals.append(mx.gluon.utils.clip_global_norm(grads, max_norm))
+        return norm_vals
+
+    def rescale_grad(self, scale=None, param_name=None):
+        """ Rescale the gradient of the provided parameter by a certain scale.
+
+        Does nothing when either `scale` or `param_name` is None. The gradient
+        arrays of `param_name` are multiplied by `scale` in place.
+        """
+        if scale is None or param_name is None:
+            return
+        param_idx = self._exec_group.param_names.index(param_name)
+        grad_vals = self._exec_group.grad_arrays[param_idx]
+        for grad in grad_vals:
+            grad[:] *= scale
diff --git a/example/rnn/large_word_lm/data.py b/example/rnn/large_word_lm/data.py
new file mode 100644
index 0000000..b9cc3e8
--- /dev/null
+++ b/example/rnn/large_word_lm/data.py
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+import codecs, glob, random, logging, collections
+
+class Vocabulary(object):
+    """ A dictionary for words.
+    Adapted from @rafaljozefowicz's implementation.
+
+    Tokens receive consecutive integer ids in the order they are added.
+    `finalize` must be called after all tokens are added so that the ids of
+    the sentence marker ("<S>") and the unknown token ("<UNK>") are resolved.
+    """
+    def __init__(self):
+        self._token_to_id = {}
+        self._token_to_count = collections.Counter()
+        self._id_to_token = []
+        self._num_tokens = 0
+        self._total_count = 0
+        self._s_id = None
+        self._unk_id = None
+
+    @property
+    def num_tokens(self):
+        """Number of tokens added so far."""
+        return self._num_tokens
+
+    @property
+    def unk(self):
+        """The token string used for unknown words."""
+        return "<UNK>"
+
+    @property
+    def unk_id(self):
+        """Id of the unknown token; None until `finalize` resolves it."""
+        return self._unk_id
+
+    @property
+    def s(self):
+        """The sentence marker token string."""
+        return "<S>"
+
+    @property
+    def s_id(self):
+        """Id of the sentence marker; None until `finalize` resolves it."""
+        return self._s_id
+
+    def add(self, token, count):
+        """Register `token` with the next free id and record its `count`."""
+        self._token_to_id[token] = self._num_tokens
+        self._token_to_count[token] = count
+        self._id_to_token.append(token)
+        self._num_tokens += 1
+        self._total_count += count
+
+    def finalize(self):
+        """Resolve and cache the ids of the "<S>" and "<UNK>" tokens."""
+        self._s_id = self.get_id(self.s)
+        self._unk_id = self.get_id(self.unk)
+
+    def get_id(self, token):
+        """Return the id of `token`, or the unknown-token id if it was never added."""
+        # Unseen tokens are mapped to UNK
+        return self._token_to_id.get(token, self.unk_id)
+
+    def get_token(self, id_):
+        """Return the token string stored at index `id_`."""
+        return self._id_to_token[id_]
+
+    @staticmethod
+    def from_file(filename):
+        """Build a finalized vocabulary from a UTF-8 file of "<token> <count>" lines.
+
+        Blank lines (for example a trailing newline at the end of the file) are
+        skipped. Any other line that does not split into exactly two
+        whitespace-separated fields raises ValueError.
+        """
+        vocab = Vocabulary()
+        with codecs.open(filename, "r", "utf-8") as f:
+            for line in f:
+                fields = line.strip().split()
+                if not fields:
+                    continue
+                word, count = fields
+                vocab.add(word, int(count))
+        vocab.finalize()
+        return vocab
+
+class Dataset(object):
+    """ A dataset for truncated bptt with multiple sentences.
+    Adapted from @rafaljozefowicz's implementation.
+
+    Parameters
+    ----------
+    vocab : Vocabulary
+        Used to map words to ids and to obtain the sentence marker id.
+    file_pattern : str
+        Glob pattern selecting the text files to read.
+    shuffle : bool
+        When True, the matched files and the lines inside each file are
+        shuffled with `random.shuffle`; when False they are used in the order
+        they are read.
+    """
+    def __init__(self, vocab, file_pattern, shuffle=False):
+        self._vocab = vocab
+        self._file_pattern = file_pattern
+        self._shuffle = shuffle
+
+    def _parse_sentence(self, line):
+        """Convert a line of text to a list of ids wrapped in sentence markers."""
+        s_id = self._vocab.s_id
+        return [s_id] + [self._vocab.get_id(word) for word in line.strip().split()] + [s_id]
+
+    def _parse_file(self, file_name):
+        """Yield the id sequence of every line in `file_name`."""
+        logging.debug("Processing file: %s" % file_name)
+        with codecs.open(file_name, "r", "utf-8") as f:
+            lines = [line.strip() for line in f]
+        if self._shuffle:
+            random.shuffle(lines)
+        logging.debug("Finished processing!")
+        for line in lines:
+            yield self._parse_sentence(line)
+
+    def _sentence_stream(self, file_stream):
+        """Chain the sentences of every file produced by `file_stream`."""
+        for file_name in file_stream:
+            for sentence in self._parse_file(file_name):
+                yield sentence
+
+    def _iterate(self, sentences, batch_size, num_steps):
+        """Pack a sentence stream into `(x, y, w)` batches of shape (batch_size, num_steps).
+
+        `x` holds input ids, `y` the ids shifted by one position (the targets)
+        and `w` is 1 where a position was filled and 0 where it is padding.
+        The same three numpy buffers are reused and overwritten for every batch.
+        Iteration ends when a batch would contain no filled position at all.
+        """
+        streams = [None] * batch_size
+        x = np.zeros([batch_size, num_steps], np.int32)
+        y = np.zeros([batch_size, num_steps], np.int32)
+        w = np.zeros([batch_size, num_steps], np.uint8)
+        while True:
+            x[:] = 0
+            y[:] = 0
+            w[:] = 0
+            for i in range(batch_size):
+                tokens_filled = 0
+                try:
+                    while tokens_filled < num_steps:
+                        # A stream with a single token left has no (input, target) pair to offer.
+                        if streams[i] is None or len(streams[i]) <= 1:
+                            streams[i] = next(sentences)
+                        num_tokens = min(len(streams[i]) - 1, num_steps - tokens_filled)
+                        x[i, tokens_filled:tokens_filled + num_tokens] = streams[i][:num_tokens]
+                        y[i, tokens_filled:tokens_filled + num_tokens] = streams[i][1:num_tokens + 1]
+                        w[i, tokens_filled:tokens_filled + num_tokens] = 1
+                        streams[i] = streams[i][num_tokens:]
+                        tokens_filled += num_tokens
+                except StopIteration:
+                    # Out of sentences: the rest of this row stays zero-padded.
+                    pass
+            if not np.any(w):
+                return
+
+            yield x, y, w
+
+    def iterate_once(self, batch_size, num_steps):
+        """Yield batches for a single pass over the files matching the pattern."""
+        def file_stream():
+            file_patterns = glob.glob(self._file_pattern)
+            if self._shuffle:
+                random.shuffle(file_patterns)
+            for file_name in file_patterns:
+                yield file_name
+        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
+            yield value
+
+    def iterate_forever(self, batch_size, num_steps):
+        """Yield batches endlessly, re-globbing the file pattern after every pass."""
+        def file_stream():
+            while True:
+                file_patterns = glob.glob(self._file_pattern)
+                if self._shuffle:
+                    random.shuffle(file_patterns)
+                for file_name in file_patterns:
+                    yield file_name
+        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
+            yield value
+
+class MultiSentenceIter(mx.io.DataIter):
+    """ An MXNet iterator that returns a batch of sequence data and label each time.
+    It also returns a mask which indicates padded/missing data at the end of the dataset.
+    The first pass reads the data unshuffled; the iterator re-shuffles the data
+    when reset is called.
+    """
+    def __init__(self, data_file, vocab, batch_size, bptt):
+        super(MultiSentenceIter, self).__init__()
+        self.batch_size = batch_size
+        self.bptt = bptt
+        self.provide_data = [('data', (batch_size, bptt), np.int32), ('mask', (batch_size, bptt))]
+        self.provide_label = [('label', (batch_size, bptt))]
+        self.vocab = vocab
+        self.data_file = data_file
+        self._dataset = Dataset(self.vocab, data_file, shuffle=False)
+        self._iter = self._dataset.iterate_once(batch_size, bptt)
+        self._next_data = None
+        self._next_label = None
+        self._next_mask = None
+
+    def iter_next(self):
+        """Load the next batch; return False once the current pass is exhausted."""
+        # The builtin next() works on both Python 2 and 3 generators, and its
+        # default turns exhaustion into the None that is checked below.
+        data = next(self._iter, None)
+        if data is None:
+            return False
+        self._next_data = mx.nd.array(data[0], dtype=np.int32)
+        self._next_label = mx.nd.array(data[1])
+        self._next_mask = mx.nd.array(data[2])
+        return True
+
+    def next(self):
+        """Return the next `DataBatch`, or raise StopIteration at the end of the pass."""
+        if self.iter_next():
+            return mx.io.DataBatch(data=self.getdata(), label=self.getlabel())
+        else:
+            raise StopIteration
+
+    def reset(self):
+        """Start a new pass over the data, this time shuffled."""
+        self._dataset = Dataset(self.vocab, self.data_file, shuffle=True)
+        self._iter = self._dataset.iterate_once(self.batch_size, self.bptt)
+        self._next_data = None
+        self._next_label = None
+        self._next_mask = None
+
+    def getdata(self):
+        """Return `[data, mask]` of the most recently loaded batch."""
+        return [self._next_data, self._next_mask]
+
+    def getlabel(self):
+        """Return `[label]` of the most recently loaded batch."""
+        return [self._next_label]
diff --git a/example/rnn/large_word_lm/get_vocab_file.sh b/example/rnn/large_word_lm/get_vocab_file.sh
new file mode 100755
index 0000000..97fa29b
--- /dev/null
+++ b/example/rnn/large_word_lm/get_vocab_file.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+echo ""
+echo "NOTE: This script only downloads the pre-processed vocabulary file. "
+echo "For the full training and testing dataset, please download from "
+echo "http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz"
+echo ""
+
+# Resolve the directory containing this script so the download location does
+# not depend on the caller's working directory.
+RNN_DIR=$(cd "$(dirname "$0")" && pwd)
+DATA_DIR="${RNN_DIR}/data/"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+    echo "${DATA_DIR} doesn't exist, will create one";
+    mkdir -p "${DATA_DIR}"
+fi
+
+# Quote the target directory so paths containing spaces work, and pass the URL
+# only (a stray extra "wget" argument would be treated as a second URL).
+wget -P "${DATA_DIR}" https://s3-us-west-2.amazonaws.com/sparse-dataset/gbw/1b_word_vocab.txt;
diff --git a/example/rnn/large_word_lm/model.py b/example/rnn/large_word_lm/model.py
new file mode 100644
index 0000000..7ee010e
--- /dev/null
+++ b/example/rnn/large_word_lm/model.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import mxnet.symbol as S
+import numpy as np
+
+def cross_entropy_loss(inputs, labels, rescale_loss=1):
+    """ Masked softmax cross entropy, averaged and wrapped as a loss symbol.
+
+    The per-element loss is multiplied by the flattened `mask` variable before
+    the mean is taken; `rescale_loss` is passed as the `weight` of the
+    underlying `SoftmaxCrossEntropyLoss`.
+    """
+    loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss(weight=rescale_loss)
+    per_element = loss_fn(inputs, labels)
+    flat_mask = S.reshape(S.var('mask'), shape=(-1,))
+    masked = per_element * flat_mask
+    return S.make_loss(masked.mean())
+
+def rnn(bptt, vocab_size, num_embed, nhid, num_layers, dropout, num_proj, batch_size):
+    """ Word embedding followed by a stack of projected LSTM (LSTMP) layers.
+
+    Returns a 4-tuple:
+    - the outputs of the last layer reshaped to (-1, num_proj),
+    - the gradient-stopped final states of every layer,
+    - the names of the arguments containing 'lstmp' but not 'init',
+    - the names of the initial-state variables, two per layer (h then c).
+    """
+    sparse_embedding = mx.sym.contrib.SparseEmbedding
+    state_names = []
+    tokens = S.var('data')
+    embed_weight = S.var("encoder_weight", stype='row_sparse')
+    embedded = sparse_embedding(data=tokens, weight=embed_weight, input_dim=vocab_size,
+                                output_dim=num_embed, name='embed', deterministic=True)
+    states = []
+    outputs = S.Dropout(embedded, p=dropout)
+    for layer in range(num_layers):
+        h_name = 'lstmp%d_init_h' % layer
+        c_name = 'lstmp%d_init_c' % layer
+        init_h = S.var(h_name, shape=(batch_size, num_proj), init=mx.init.Zero())
+        init_c = S.var(c_name, shape=(batch_size, nhid), init=mx.init.Zero())
+        state_names.extend([h_name, c_name])
+        cell = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj)
+        outputs, next_states = cell.unroll(bptt, outputs, begin_state=[init_h, init_c],
+                                           layout='NTC', merge_outputs=True)
+        outputs = S.Dropout(outputs, p=dropout)
+        # Final states are exported without letting gradients flow through them.
+        states.extend([S.stop_gradient(s) for s in next_states])
+    outputs = S.reshape(outputs, shape=(-1, num_proj))
+
+    trainable_lstm_args = [arg for arg in outputs.list_arguments()
+                           if 'lstmp' in arg and 'init' not in arg]
+    return outputs, states, trainable_lstm_args, state_names
+
+def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
+ sampled_values, remove_accidental_hits=True):
+ """ Sampled softmax via importance sampling.
+ This under-estimates the full softmax and is only used for training.
+
+ `sampled_values` is unpacked as (sample, prob_sample, prob_target).
+ Returns `(logits, new_targets)`: `logits` concatenates the true-class
+ score with the `num_samples` sampled-class scores along dim 1, and
+ `new_targets` is all zeros, i.e. the true class sits in column 0.
+ """
+ # inputs = (n, in_dim)
+ embed = mx.sym.contrib.SparseEmbedding
+ sample, prob_sample, prob_target = sampled_values
+
+ # NOTE(review): the `sample` unpacked above is replaced right here by a
+ # fresh variable with the same name 'sample' - confirm this is intended.
+ # (num_samples, )
+ sample = S.var('sample', shape=(num_samples,), dtype='float32')
+ # (n, )
+ label = S.var('label')
+ label = S.reshape(label, shape=(-1,), name="label_reshape")
+ # (num_samples+n, )
+ sample_label = S.concat(sample, label, dim=0)
+ # lookup weights and biases
+ # (num_samples+n, dim)
+ sample_target_w = embed(data=sample_label, weight=weight,
+ input_dim=num_classes, output_dim=in_dim,
+ deterministic=True)
+ # (num_samples+n, 1)
+ sample_target_b = embed(data=sample_label, weight=bias,
+ input_dim=num_classes, output_dim=1, deterministic=True)
+ # The first num_samples rows belong to the samples, the remaining rows to the targets.
+ # (num_samples, dim)
+ sample_w = S.slice(sample_target_w, begin=(0, 0), end=(num_samples, None))
+ target_w = S.slice(sample_target_w, begin=(num_samples, 0), end=(None, None))
+ sample_b = S.slice(sample_target_b, begin=(0, 0), end=(num_samples, None))
+ target_b = S.slice(sample_target_b, begin=(num_samples, 0), end=(None, None))
+
+ # target
+ # (n, 1)
+ true_pred = S.sum(target_w * inputs, axis=1, keepdims=True) + target_b
+ # samples
+ # (n, num_samples)
+ sample_b = S.reshape(sample_b, (-1,))
+ sample_pred = S.FullyConnected(inputs, weight=sample_w, bias=sample_b,
+ num_hidden=num_samples)
+
+ # remove accidental hits
+ if remove_accidental_hits:
+ label_v = S.reshape(label, (-1, 1))
+ sample_v = S.reshape(sample, (1, -1))
+ # Where a sampled class equals the true label, add -1e37 to its score.
+ neg = S.broadcast_equal(label_v, sample_v) * -1e37
+ sample_pred = sample_pred + neg
+
+ # Subtract the log of the sampling probabilities from the raw scores.
+ prob_sample = S.reshape(prob_sample, shape=(1, num_samples))
+ p_target = true_pred - S.log(prob_target)
+ p_sample = S.broadcast_sub(sample_pred, S.log(prob_sample))
+
+ # return logits and new_labels
+ # (n, 1+num_samples)
+ logits = S.concat(p_target, p_sample, dim=1)
+ new_targets = S.zeros_like(label)
+ return logits, new_targets
+
+def generate_samples(label, num_splits, num_samples, num_classes):
+    """ Split `label` into `num_splits` parts along axis 0 and draw candidate
+    classes for each part with `mx.nd.contrib.rand_zipfian`.
+
+    Returns three lists with one float32 entry per split:
+    `(samples, prob_samples, prob_targets)`.
+    """
+    label_splits = label.split(num_splits, axis=0)
+    if not isinstance(label_splits, list):
+        # A non-list result is wrapped so the loop below handles both cases.
+        label_splits = [label_splits]
+    samples, prob_samples, prob_targets = [], [], []
+    for part in label_splits:
+        drawn = mx.nd.contrib.rand_zipfian(part.reshape((-1,1)), num_samples, num_classes)
+        sampled_classes, exp_cnt_true, exp_cnt_sampled = drawn
+        samples.append(sampled_classes.astype(np.float32))
+        prob_targets.append(exp_cnt_true.astype(np.float32))
+        prob_samples.append(exp_cnt_sampled.astype(np.float32))
+    return samples, prob_samples, prob_targets
+
+class Model():
+ """ LSTMP with Importance Sampling
+
+ Builds two loss symbols on top of the same recurrent network:
+ `train_loss` (sampled softmax) and `eval_loss` (full softmax). Both use
+ the variables named "decoder_weight" and "decoder_bias".
+ """
+ def __init__(self, args, ntokens, rescale_loss):
+ # `args` is read for: bptt, emsize, nhid, nlayers, dropout, num_proj,
+ # batch_size and k (the number of sampled classes).
+ out = rnn(args.bptt, ntokens, args.emsize, args.nhid, args.nlayers,
+ args.dropout, args.num_proj, args.batch_size)
+ rnn_out, self.last_states, self.lstm_args, self.state_names = out
+ # decoder weight and bias
+ decoder_w = S.var("decoder_weight", stype='row_sparse')
+ decoder_b = S.var("decoder_bias", shape=(ntokens, 1), stype='row_sparse')
+
+ # sampled softmax for training
+ sample = S.var('sample', shape=(args.k,))
+ prob_sample = S.var("prob_sample", shape=(args.k,))
+ prob_target = S.var("prob_target")
+ self.sample_names = ['sample', 'prob_sample', 'prob_target']
+ logits, new_targets = sampled_softmax(ntokens, args.k, args.num_proj,
+ rnn_out, decoder_w, decoder_b,
+ [sample, prob_sample, prob_target])
+ self.train_loss = cross_entropy_loss(logits, new_targets, rescale_loss=rescale_loss)
+
+ # full softmax for testing
+ eval_logits = S.FullyConnected(data=rnn_out, weight=decoder_w,
+ num_hidden=ntokens, name='decode_fc', bias=decoder_b)
+ label = S.Variable('label')
+ label = S.reshape(label, shape=(-1,))
+ # The evaluation loss keeps the default rescale_loss (no rescaling argument is passed).
+ self.eval_loss = cross_entropy_loss(eval_logits, label)
+
+ def eval(self):
+ """Group the final RNN states with the evaluation loss (the loss is last)."""
+ return S.Group(self.last_states + [self.eval_loss])
+
+ def train(self):
+ """Group the final RNN states with the training loss (the loss is last)."""
+ return S.Group(self.last_states + [self.train_loss])
diff --git a/example/rnn/large_word_lm/readme.md b/example/rnn/large_word_lm/readme.md
new file mode 100644
index 0000000..d74ffbd
--- /dev/null
+++ b/example/rnn/large_word_lm/readme.md
@@ -0,0 +1,66 @@
+# Large-Scale Language Model
+This example implements the baseline model in
+[Exploring the Limits of Language Modeling](https://arxiv.org/abs/1602.02410) on the
+[Google 1-Billion Word](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) (GBW) dataset.
+
+This example reaches **41.97 perplexity** after 5 training epochs on a 1-layer, 2048-unit, 512-projection LSTM Language Model.
+The result is slightly better than the one reported in the paper (43.7 perplexity).
+The main differences with the original implementation include:
+* Synchronous gradient updates instead of asynchronous updates
+* Noise candidates are sampled with replacement
+
+Each epoch for training takes around 80 minutes on a p3.8xlarge instance, which comes with 4 Volta V100 GPUs.
+
+# Setup - Original Data Format
+1. Download 1-Billion Word Dataset - [Link](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
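+2. Download the pre-processed vocabulary file, which maps tokens to ids, by running `get_vocab_file.sh` (it saves `1b_word_vocab.txt` into the `data/` directory next to the script).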
+
+# Run the Script
+```
+usage: train.py [-h] [--data DATA] [--test TEST] [--vocab VOCAB]
+ [--emsize EMSIZE] [--nhid NHID] [--num-proj NUM_PROJ]
+ [--nlayers NLAYERS] [--epochs EPOCHS]
+ [--batch-size BATCH_SIZE] [--dropout DROPOUT] [--eps EPS]
+ [--bptt BPTT] [--k K] [--gpus GPUS]
+ [--log-interval LOG_INTERVAL] [--seed SEED]
+ [--checkpoint-dir CHECKPOINT_DIR] [--lr LR] [--clip CLIP]
+ [--rescale-embed RESCALE_EMBED]
+
+Language Model on GBW
+
+optional arguments:
+ -h, --help show this help message and exit
+ --data DATA location of the training data
+ --test TEST location of the test data
+ --vocab VOCAB location of the corpus vocabulary file
+ --emsize EMSIZE size of word embeddings
+ --nhid NHID number of hidden units per layer
+ --num-proj NUM_PROJ number of projection units per layer
+ --nlayers NLAYERS number of LSTM layers
+ --epochs EPOCHS number of epoch for training
+ --batch-size BATCH_SIZE
+ batch size per gpu
+ --dropout DROPOUT dropout applied to layers (0 = no dropout)
+ --eps EPS epsilon for adagrad
+ --bptt BPTT sequence length
+ --k K number of noise samples for estimation
+ --gpus GPUS list of gpus to run, e.g. 0 or 0,2,5. empty means
+ using gpu(0).
+ --log-interval LOG_INTERVAL
+ report interval
+ --seed SEED random seed
+ --checkpoint-dir CHECKPOINT_DIR
+ dir for checkpoint
+ --lr LR initial learning rate
+ --clip CLIP gradient clipping by global norm.
+ --rescale-embed RESCALE_EMBED
+ scale factor for the gradients of the embedding layer
+```
+
+To reproduce the result, run
+```
+train.py --gpus=0,1,2,3 --clip=1 --lr=0.05 --dropout=0.01 --eps=0.0001 --rescale-embed=128
+--test=/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050
+--data=/path/to/training-monolingual.tokenized.shuffled/*
+# ~42 perplexity for 5 epochs of training
+```
diff --git a/example/rnn/large_word_lm/run_utils.py b/example/rnn/large_word_lm/run_utils.py
new file mode 100644
index 0000000..7650530e
--- /dev/null
+++ b/example/rnn/large_word_lm/run_utils.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse, time, logging, math
+
+def get_parser():
+    """Build the command-line parser for the GBW language model scripts.
+
+    Options are declared in a table of (flag, type, keyword arguments) and
+    registered in that order, which is also the order they appear in `--help`.
+    """
+    option_specs = [
+        ('--data', str,
+         {'default': '/path/to/training-monolingual.tokenized.shuffled/*',
+          'help': 'location of the training data'}),
+        ('--test', str,
+         {'default': '/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050',
+          'help': 'location of the test data'}),
+        ('--vocab', str,
+         {'default': './data/1b_word_vocab.txt',
+          'help': 'location of the corpus vocabulary file'}),
+        ('--emsize', int, {'default': 512, 'help': 'size of word embeddings'}),
+        ('--nhid', int, {'default': 2048, 'help': 'number of hidden units per layer'}),
+        ('--num-proj', int, {'default': 512, 'help': 'number of projection units per layer'}),
+        ('--nlayers', int, {'default': 1, 'help': 'number of LSTM layers'}),
+        ('--epochs', int, {'default': 8, 'help': 'number of epoch for training'}),
+        ('--batch-size', int, {'default': 128, 'help': 'batch size per gpu'}),
+        ('--dropout', float,
+         {'default': 0.1, 'help': 'dropout applied to layers (0 = no dropout)'}),
+        ('--eps', float, {'default': 0.0001, 'help': 'epsilon for adagrad'}),
+        ('--bptt', int, {'default': 20, 'help': 'sequence length'}),
+        ('--k', int, {'default': 8192, 'help': 'number of noise samples for estimation'}),
+        # --gpus deliberately has no default, so it parses to None when omitted.
+        ('--gpus', str,
+         {'help': 'list of gpus to run, e.g. 0 or 0,2,5. empty means using gpu(0).'}),
+        ('--log-interval', int, {'default': 200, 'help': 'report interval'}),
+        ('--seed', int, {'default': 1, 'help': 'random seed'}),
+        ('--checkpoint-dir', str,
+         {'default': './checkpoint/cp', 'help': 'dir for checkpoint'}),
+        ('--lr', float, {'default': 0.1, 'help': 'initial learning rate'}),
+        ('--clip', float, {'default': 1, 'help': 'gradient clipping by global norm.'}),
+        ('--rescale-embed', float,
+         {'default': None,
+          'help': 'scale factor for the gradients of the embedding layer'}),
+    ]
+    cli = argparse.ArgumentParser(description='Language Model on GBW')
+    for flag, value_type, extra in option_specs:
+        cli.add_argument(flag, type=value_type, **extra)
+    return cli
+
+def evaluate(mod, data_iter, epoch, log_interval):
+    """ Run evaluation over `data_iter` and return the mean per-batch loss.
+
+    The module states are zeroed before the first batch and then carried over
+    from one batch to the next. `data_iter` is reset once the pass is finished.
+
+    Parameters
+    ----------
+    mod : module
+        Must provide `set_states`, `forward` and `get_outputs`; the last output
+        is read as the loss and all preceding outputs as the states.
+    data_iter : iterable with a `reset` method
+        Source of evaluation batches.
+    epoch : int
+        Only used in the summary log message.
+    log_interval : int
+        A running average is logged every `log_interval` batches.
+
+    Returns
+    -------
+    float
+        Sum of the per-batch losses divided by the number of batches.
+
+    Raises
+    ------
+    ValueError
+        If `data_iter` yields no batch at all.
+    """
+    start = time.time()
+    total_L = 0.0
+    nbatch = 0
+    mod.set_states(value=0)
+    for batch in data_iter:
+        mod.forward(batch, is_train=False)
+        outputs = mod.get_outputs(merge_multi_context=False)
+        states = outputs[:-1]
+        # Only the first entry of the (unmerged) loss output is accumulated.
+        total_L += outputs[-1][0].asscalar()
+        mod.set_states(states=states)
+        nbatch += 1
+        # nbatch already counts the batch just processed, so test it directly.
+        if nbatch % log_interval == 0:
+            logging.info("Eval batch %d loss : %.7f" % (nbatch, total_L / nbatch))
+    data_iter.reset()
+    if nbatch == 0:
+        raise ValueError("evaluate() received no batches from data_iter")
+    loss = total_L / nbatch
+    # Cap the reported perplexity so math.exp cannot overflow on a huge loss.
+    ppl = math.exp(loss) if loss < 100 else 1e37
+    end = time.time()
+    logging.info('Iter[%d]\t\t CE loss %.7f, ppl %.7f. Eval duration = %.2f seconds ' % \
+                 (epoch, loss, ppl, end - start))
+    return loss
diff --git a/example/rnn/large_word_lm/train.py b/example/rnn/large_word_lm/train.py
new file mode 100644
index 0000000..a1b4e31
--- /dev/null
+++ b/example/rnn/large_word_lm/train.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+import mxnet.symbol as S
+import run_utils
+from data import MultiSentenceIter, Vocabulary
+from model import *
+from custom_module import CustomModule
+import os, math, logging, sys
+
+if __name__ == '__main__':
+ # parser
+ parser = run_utils.get_parser()
+ args = parser.parse_args()
+ head = '%(asctime)-15s %(message)s'
+ ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else [mx.gpu()]
+ ngpus = len(ctx)
+ rescale_loss = args.bptt
+
+ # logging
+ logging.basicConfig(level=logging.INFO, format=head)
+ logging.info(args)
+ logging.debug(sys.argv)
+
+ # seeding
+ mx.random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ # data
+ vocab = Vocabulary.from_file(args.vocab)
+ ntokens = vocab.num_tokens
+ train_data = mx.io.PrefetchingIter(MultiSentenceIter(args.data, vocab,
+ args.batch_size * ngpus, args.bptt))
+ # model
+ model = Model(args, ntokens, rescale_loss)
+ train_loss_and_states = model.train()
+ eval_loss_and_states = model.eval()
+
+ # training module
+ data_names, label_names = ['data', 'mask'], ['label']
+ eval_state_names = model.state_names
+ num_sample_names = len(model.sample_names)
+ train_state_names = model.state_names + model.sample_names
+
+ module = CustomModule(symbol=train_loss_and_states, context=ctx,
+ state_names=train_state_names,
+ data_names=data_names, label_names=label_names)
+ module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+ module.init_params(initializer=mx.init.Xavier(factor_type='out'))
+
+ # create kvstore and sparse optimizer
+ kvstore = mx.kv.create('device')
+ optimizer = mx.optimizer.create('adagrad', learning_rate=args.lr, \
+ rescale_grad=1.0/ngpus, eps=args.eps)
+ module.init_optimizer(optimizer=optimizer, kvstore=kvstore)
+
+ # speedometer
+ num_words_per_batch = args.batch_size * ngpus * args.bptt
+ speedometer = mx.callback.Speedometer(num_words_per_batch, args.log_interval)
+
+ # training loop
+ logging.info("Training started ... ")
+ for epoch in range(args.epochs):
+ total_L = mx.nd.array([0.0])
+ nbatch = 0
+ # reset initial LSTMP states
+ module.set_states(value=0)
+ state_cache = module.get_states(merge_multi_context=False)[:-num_sample_names]
+ next_batch = train_data.next()
+ next_sampled_values = generate_samples(next_batch.label[0], ngpus, args.k, ntokens)
+ stop_iter = False
+ while not stop_iter:
+ batch = next_batch
+ state_cache += next_sampled_values
+ # propagate LSTMP states from the previous batch
+ module.set_states(states=state_cache)
+ # selectively pull row_sparse weight to multiple devices based on the data batch
+ target_ids = [batch.label[0]]
+ sampled_ids = next_sampled_values[0]
+ param_rowids = {'encoder_weight': batch.data[0],
+ 'decoder_weight': sampled_ids + target_ids,
+ 'decoder_bias': sampled_ids + target_ids}
+ module.prepare_sparse_params(param_rowids)
+ # forward
+ module.forward(batch)
+ try:
+ # prefetch the next batch of data and samples
+ next_batch = train_data.next()
+ next_sampled_values = generate_samples(next_batch.label[0], ngpus,
+ args.k, ntokens)
+ except StopIteration:
+ stop_iter = True
+ # cache LSTMP states of the current batch
+ outputs = module.get_outputs(merge_multi_context=False)
+ state_cache = outputs[:-1]
+ # backward
+ module.backward()
+ for g in range(ngpus):
+ total_L += outputs[-1][g].copyto(mx.cpu()) / ngpus
+
+ # rescaling the gradient for embedding layer empirically leads to faster convergence
+ module.rescale_grad(args.rescale_embed, 'encoder_weight')
+ # clip lstm params on each device based on norm
+ norm = module.clip_by_global_norm_per_ctx(max_norm=args.clip, param_names=model.lstm_args)
+ # update parameters
+ module.update()
+ speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+ eval_metric=None, locals=locals())
+ speedometer(speedometer_param)
+ # update training metric
+ if nbatch % args.log_interval == 0 and nbatch > 0:
+ cur_L = total_L.asscalar() / args.log_interval / rescale_loss
+ ppl = math.exp(cur_L) if cur_L < 100 else 1e36
+ logging.info('Iter[%d] Batch [%d] \tloss %.7f, ppl %.7f'%(epoch, nbatch, cur_L, ppl))
+ total_L[:] = 0.0
+ nbatch += 1
+
+ # run evaluation with full softmax on cpu
+ module.save_checkpoint(args.checkpoint_dir, epoch, save_optimizer_states=False)
+ cpu_train_mod = CustomModule.load(args.checkpoint_dir, epoch, context=mx.cpu(),
+ state_names=train_state_names,
+ data_names=data_names, label_names=label_names)
+ # eval data iter
+ eval_data = mx.io.PrefetchingIter(MultiSentenceIter(args.test, vocab,
+ args.batch_size, args.bptt))
+ cpu_train_mod.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label)
+
+ # eval module
+ eval_module = CustomModule(symbol=eval_loss_and_states, context=mx.cpu(), data_names=data_names,
+ label_names=label_names, state_names=eval_state_names)
+ # use `shared_module` to share parameter with the training module
+ eval_module.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label,
+ shared_module=cpu_train_mod, for_training=False)
+ val_L = run_utils.evaluate(eval_module, eval_data, epoch, 20)
+ train_data.reset()
+ logging.info("Training completed. ")
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index b964c71..1b9afee 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -17,13 +17,12 @@
# coding: utf-8
"""Definition of various recurrent neural network cells."""
-__all__ = ['VariationalDropoutCell']
+__all__ = ['VariationalDropoutCell', 'LSTMPCell']
-from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
+from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
from ... import tensor_types
-
class VariationalDropoutCell(ModifierCell):
"""
Applies Variational Dropout on base cell.
@@ -193,3 +192,126 @@ class VariationalDropoutCell(ModifierCell):
outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
merge_outputs)
return outputs, states
+
+
+class LSTMPCell(HybridRecurrentCell):
+ r"""Long-Short Term Memory Projected (LSTMP) network cell.
+ (https://arxiv.org/abs/1402.1128)
+ Each call computes the following function:
+ .. math::
+ \begin{array}{ll}
+ i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
+ f_t = sigmoid(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
+ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
+ o_t = sigmoid(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
+ c_t = f_t * c_{(t-1)} + i_t * g_t \\
+ h_t = o_t * \tanh(c_t) \\
+ r_t = W_{hr} h_t
+ \end{array}
+ where :math:`r_t` is the projected recurrent activation at time `t`,
+ :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
+ cell state at time `t`, :math:`x_t` is the input at time `t`, and :math:`i_t`,
+ :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
+ out gates, respectively.
+ Parameters
+ ----------
+ hidden_size : int
+ Number of units in cell state symbol.
+ projection_size : int
+ Number of units in output symbol.
+ i2h_weight_initializer : str or Initializer
+ Initializer for the input weights matrix, used for the linear
+ transformation of the inputs.
+ h2h_weight_initializer : str or Initializer
+ Initializer for the recurrent weights matrix, used for the linear
+ transformation of the hidden state.
+ h2r_weight_initializer : str or Initializer
+ Initializer for the projection weights matrix, used for the linear
+ transformation of the recurrent state.
+ i2h_bias_initializer : str or Initializer, default 'zeros'
+ Initializer for the bias vector added to the linear
+ transformation of the inputs. By default, all of these
+ biases are initialized to zero.
+ h2h_bias_initializer : str or Initializer
+ Initializer for the bias vector.
+ prefix : str, default 'lstmp_'
+ Prefix for name of `Block`s
+ (and name of weight if params is `None`).
+ params : Parameter or None
+ Container for weight sharing between cells.
+ Created if `None`.
+ Inputs:
+ - **data**: input tensor with shape `(batch_size, input_size)`.
+ - **states**: a list of two initial recurrent state tensors, with shape
+ `(batch_size, projection_size)` and `(batch_size, hidden_size)` respectively.
+ Outputs:
+ - **out**: output tensor with shape `(batch_size, projection_size)`.
+ - **next_states**: a list of two output recurrent state tensors. Each has
+ the same shape as `states`.
+ """
+ def __init__(self, hidden_size, projection_size,
+ i2h_weight_initializer=None, h2h_weight_initializer=None,
+ h2r_weight_initializer=None,
+ i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
+ input_size=0, prefix=None, params=None):
+ super(LSTMPCell, self).__init__(prefix=prefix, params=params)
+
+ self._hidden_size = hidden_size
+ self._input_size = input_size
+ self._projection_size = projection_size
+ self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
+ init=i2h_weight_initializer,
+ allow_deferred_init=True)
+ self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, projection_size),
+ init=h2h_weight_initializer,
+ allow_deferred_init=True)
+ self.h2r_weight = self.params.get('h2r_weight', shape=(projection_size, hidden_size),
+ init=h2r_weight_initializer,
+ allow_deferred_init=True)
+ self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
+ init=i2h_bias_initializer,
+ allow_deferred_init=True)
+ self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
+ init=h2h_bias_initializer,
+ allow_deferred_init=True)
+
+ def state_info(self, batch_size=0):
+ return [{'shape': (batch_size, self._projection_size), '__layout__': 'NC'},
+ {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]
+
+ def _alias(self):
+ return 'lstmp'
+
+ def __repr__(self):
+ s = '{name}({mapping})'
+ shape = self.i2h_weight.shape
+ proj_shape = self.h2r_weight.shape
+ mapping = '{0} -> {1} -> {2}'.format(shape[1] if shape[1] else None,
+ shape[0], proj_shape[0])
+ return s.format(name=self.__class__.__name__,
+ mapping=mapping,
+ **self.__dict__)
+
+ # pylint: disable= arguments-differ
+ def hybrid_forward(self, F, inputs, states, i2h_weight,
+ h2h_weight, h2r_weight, i2h_bias, h2h_bias):
+ prefix = 't%d_'%self._counter
+ i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
+ num_hidden=self._hidden_size*4, name=prefix+'i2h')
+ h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
+ num_hidden=self._hidden_size*4, name=prefix+'h2h')
+ gates = i2h + h2h
+ slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
+ in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
+ forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
+ in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
+ out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
+ next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+ name=prefix+'state')
+ hidden = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
+ name=prefix+'hidden')
+ next_r = F.FullyConnected(data=hidden, num_hidden=self._projection_size,
+ weight=h2r_weight, no_bias=True, name=prefix+'out')
+
+ return next_r, [next_r, next_c]
+ # pylint: enable= arguments-differ
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7d31a31..ee97649 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -802,17 +802,17 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
CHECK_EQ(inferred_stype, arg_nd_stype)
<< "Inferred stype does not match shared_exec.arg_array's stype"
" Therefore, the allocated memory for shared_exec.arg_array cannot"
- " be resued for creating NDArray of the argument"
+ " be resued for creating NDArray of the argument "
<< arg_name << " for the current executor";
CHECK_EQ(inferred_shape, in_arg_nd.shape())
<< "Inferred shape does not match shared_exec.arg_array's shape"
" Therefore, the allocated memory for shared_exec.arg_array cannot"
- " be resued for creating NDArray of the argument"
+ " be resued for creating NDArray of the argument "
<< arg_name << " for the current executor";
CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
<< "Inferred dtype does not match shared_exec.arg_array's dtype"
" Therefore, the allocated memory for shared_exec.arg_array cannot"
- " be resued for creating NDArray of the argument"
+ " be resued for creating NDArray of the argument "
<< arg_name << " for the current executor";
in_arg_vec->emplace_back(in_arg_nd);
} else {
diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index e8e9564..7eba2e2 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -95,11 +95,20 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
}
+ CHECK_EQ(data.shape_[1], wmat.shape_[1])
+ << "Incomplete weight tensor detected: weight.data().shape[1] != prod(data.data().shape[1:])."
+ " This is not supported by FCForward. If weight is in row_sparse format,"
+ " please make sure all row ids are present.";
// Legacy approach shown here for comparison:
// out = dot(data, wmat.T());
linalg_gemm(data, wmat, out, false, true, s);
if (!param.no_bias) {
- Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
+ Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get_with_shape<xpu, 1, DType>(
+ Shape1(wmat.shape_[0]), s);
+ CHECK_EQ(bias.shape_[0], wmat.shape_[0])
+ << "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
+ " This is not supported by FCForward. If bias is in row_sparse format, please"
+ " make sure all row ids are present.";
out += repmat(bias, data.size(0));
}
}
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 4362408..75d594f 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -56,7 +56,10 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
}
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param.num_hidden, num_input));
if (!param.no_bias) {
- SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param.num_hidden));
+ if (!shape_assign(&(*in_shape)[fullc::kBias], Shape1(param.num_hidden)) &&
+ !shape_assign(&(*in_shape)[fullc::kBias], Shape2(param.num_hidden, 1))) {
+ LOG(FATAL) << "Unexpected shape for bias " << (*in_shape)[fullc::kBias];
+ }
}
if (!param.flatten) {
@@ -73,22 +76,67 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
return true;
}
-#if MXNET_USE_MKLDNN == 1
void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
- if (SupportMKLDNN(inputs[0])) {
- MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
- MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
- MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
- outputs);
+ const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+ const bool valid_data = inputs[0].storage_type() == kDefaultStorage;
+ const bool valid_weight = inputs[1].storage_type() == kDefaultStorage ||
+ inputs[1].storage_type() == kRowSparseStorage;
+ const bool valid_out = outputs[0].storage_type() == kDefaultStorage;
+ bool valid_bias = true;
+ if (!param.no_bias) {
+ valid_bias = inputs[2].storage_type() == kDefaultStorage ||
+ inputs[2].storage_type() == kRowSparseStorage;
+ }
+#if MXNET_USE_MKLDNN == 1
+ if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
+ common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
+ if (SupportMKLDNN(inputs[0])) {
+ MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+ MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
+ MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
+ outputs);
+ } else {
+ FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
+ }
return;
+ } else if (valid_data && valid_weight && valid_bias && valid_out) {
+ // inputs
+ std::vector<NDArray> temp_ndarrays;
+ std::vector<TBlob> in_blobs;
+ for (const NDArray& in : inputs) {
+ // if ndarray is in default storage and MKLDNN is available,
+ // need to make sure cpu layout data is used, instead of MKL layout
+ if (in.storage_type() == kDefaultStorage) {
+ temp_ndarrays.push_back(in.Reorder2Default());
+ in_blobs.emplace_back(temp_ndarrays.back().data());
+ } else {
+ in_blobs.emplace_back(in.data());
+ }
+ }
+ // output
+ if (req[0] == kWriteTo) const_cast<NDArray &>(outputs[0]).InvalidateMKLDNNData();
+ FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, {outputs[0].data()});
+ } else {
+ LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+ }
+#else
+ if (valid_data && valid_weight && valid_bias && valid_out) {
+ std::vector<TBlob> in_blobs(inputs.size());
+ for (size_t i = 0; i < in_blobs.size(); i++) in_blobs[i] = inputs[i].data();
+ std::vector<TBlob> out_blobs(outputs.size());
+ for (size_t i = 0; i < out_blobs.size(); i++) out_blobs[i] = outputs[i].data();
+ FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+ } else {
+ LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
- FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
+#endif
}
+#if MXNET_USE_MKLDNN == 1
void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
@@ -129,19 +177,27 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
- uint32_t in_expected = param.no_bias ? 2 : 3;
+ const bool valid_data = in_attrs->at(0) == kDefaultStorage;
+ const bool valid_weight = in_attrs->at(1) == kDefaultStorage ||
+ in_attrs->at(1) == kRowSparseStorage;
+ bool valid_bias = true;
+ uint32_t in_expected = 2;
+ if (!param.no_bias) {
+ in_expected = 3;
+ valid_bias = in_attrs->at(2) == kDefaultStorage || in_attrs->at(2) == kRowSparseStorage;
+ }
CHECK_EQ(in_attrs->size(), in_expected);
CHECK_EQ(out_attrs->size(), 1);
-
- DispatchMode wanted_mode;
-#if MXNET_USE_MKLDNN == 1
- if (dev_mask == mshadow::cpu::kDevMask)
- wanted_mode = DispatchMode::kFComputeEx;
- else
-#endif
- wanted_mode = DispatchMode::kFCompute;
- return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
- dispatch_mode, wanted_mode);
+ // dispatch to kFComputeEx is fine even if all inputs are dense and no MKL is present
+ bool dispatched = false;
+ if (!dispatched && valid_data && valid_weight && valid_bias) {
+ dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+ dispatch_mode, DispatchMode::kFComputeEx);
+ }
+ if (!dispatched) {
+ dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+ }
+ return dispatched;
}
inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
@@ -170,6 +226,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
DMLC_REGISTER_PARAMETER(FullyConnectedParam);
NNVM_REGISTER_OP(FullyConnected)
+MXNET_ADD_SPARSE_OP_ALIAS(FullyConnected)
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`.
If ``flatten`` is set to be true, then the shapes are:
@@ -190,6 +247,10 @@ The learnable parameters include both ``weight`` and ``bias``.
If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
+Note that the operator also supports forward computation with `row_sparse` weight and bias,
+where the length of `weight.indices` and `bias.indices` must be equal to `num_hidden`.
+This could be used for model inference with `row_sparse` weights trained with `SparseEmbedding`.
+
)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
@@ -214,9 +275,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
.set_attr<nnvm::FInferShape>("FInferShape", FullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", FullyConnectedType)
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedComputeExCPU)
-#endif
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedGrad{"_backward_FullyConnected"})
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "Weight matrix.")
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 29850dc..729ec84 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -108,6 +108,22 @@ def test_conv_fill_shape():
check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
+@with_seed()
+def test_lstmp():
+ nhid = 100
+ nproj = 64
+ cell = contrib.rnn.LSTMPCell(nhid, nproj, prefix='rnn_')
+ inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+ outputs, _ = cell.unroll(3, inputs)
+ outputs = mx.sym.Group(outputs)
+ expected_params = ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_h2r_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+ expected_outputs = ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
+ assert sorted(cell.collect_params().keys()) == expected_params
+ assert outputs.list_outputs() == expected_outputs, outputs.list_outputs()
+
+ args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
+ assert outs == [(10, nproj), (10, nproj), (10, nproj)]
+
@with_seed()
def test_vardrop():
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 3d6f9d0..182e70c 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -872,6 +872,7 @@ def test_sparse_nd_check_format():
a = mx.nd.sparse.row_sparse_array((data_list, indices_list), shape=shape)
assertRaises(mx.base.MXNetError, a.check_format)
+@with_seed()
def test_sparse_nd_norm():
def check_sparse_nd_norm(stype, shape, density):
data, _ = rand_sparse_ndarray(shape, stype, density)
@@ -886,6 +887,23 @@ def test_sparse_nd_norm():
for density in densities:
check_sparse_nd_norm(stype, shape, density)
+@with_seed()
+def test_sparse_fc():
+ def check_sparse_fc(batch_size, dim_in, dim_out, stype):
+ data = rand_ndarray((batch_size, dim_in), stype, density=0.5)
+ weight = rand_ndarray((dim_out, dim_in), 'row_sparse', density=1)
+ bias = rand_ndarray((dim_out, 1), 'row_sparse', density=1)
+ out = mx.nd.sparse.FullyConnected(data, weight, num_hidden=dim_out, bias=bias)
+ data_dns = data.tostype('default')
+ weight_dns = weight.tostype('default')
+ out_dns = mx.nd.FullyConnected(data_dns, weight_dns, num_hidden=dim_out, bias=bias)
+ assert_almost_equal(out.asnumpy(), out_dns.asnumpy())
+
+ # test FC with row_sparse weight w/ density=1, dense data
+ check_sparse_fc(5, 10, 8, 'default')
+ # test FC with row_sparse weight w/ density=1, csr data (fallback)
+ check_sparse_fc(5, 10, 8, 'csr')
+
if __name__ == '__main__':
import nose
nose.runmodule()
--
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.