You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/22 01:01:58 UTC

[GitHub] szha closed pull request #10025: [MXNET-96] Language model with Google's billion words dataset

szha closed pull request #10025: [MXNET-96] Language model with Google's billion words dataset
URL: https://github.com/apache/incubator-mxnet/pull/10025
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/example/rnn/large_word_lm/custom_module.py b/example/rnn/large_word_lm/custom_module.py
new file mode 100644
index 00000000000..05d0fb75af7
--- /dev/null
+++ b/example/rnn/large_word_lm/custom_module.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+import warnings
+
+import mxnet as mx
+import numpy as np
+from mxnet.module import Module
+from mxnet.model import load_checkpoint
+
+class CustomModule(Module):
+
+    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
+                 logger=logging, context=mx.cpu(), work_load_list=None,
+                 fixed_param_names=None, state_names=None, group2ctxs=None,
+                 compression_params=None):
+
+        super(CustomModule, self).__init__(symbol, data_names=data_names, label_names=label_names,
+                                           logger=logger, context=context, work_load_list=work_load_list,
+                                           fixed_param_names=fixed_param_names, state_names=state_names,
+                                           group2ctxs=group2ctxs, compression_params=compression_params)
+
+    def prepare_sparse_params(self, param_rowids):
+        '''Prepares the module for processing a data batch by pulling row_sparse
+        parameters from kvstore to all devices based on rowids.
+
+        Parameters
+        ----------
+        param_rowids : dict of str to NDArray of list of NDArrays
+        '''
+        if not self._kvstore:
+            return
+        assert(isinstance(param_rowids, dict))
+        for param_name, rowids in param_rowids.items():
+            if isinstance(rowids, (tuple, list)):
+                rowids_1d = []
+                for r in rowids:
+                    rowids_1d.append(r.reshape((-1,)).astype(np.int64))
+                rowid = mx.nd.concat(*rowids_1d, dim=0)
+            else:
+                rowid = rowids
+            param_idx = self._exec_group.param_names.index(param_name)
+            param_val = self._exec_group.param_arrays[param_idx]
+            self._kvstore.row_sparse_pull(param_name, param_val, row_ids=rowid,
+                                          priority=-param_idx)
+
+    @staticmethod
+    def load(prefix, epoch, load_optimizer_states=False, **kwargs):
+        """Creates a model from previously saved checkpoint.
+
+        Parameters
+        ----------
+        prefix : str
+            path prefix of saved model files. You should have
+            "prefix-symbol.json", "prefix-xxxx.params", and
+            optionally "prefix-xxxx.states", where xxxx is the
+            epoch number.
+        epoch : int
+            epoch to load.
+        load_optimizer_states : bool
+            whether to load optimizer states. Checkpoint needs
+            to have been made with save_optimizer_states=True.
+        data_names : list of str
+            Default is `('data')` for a typical model used in image classification.
+        label_names : list of str
+            Default is `('softmax_label')` for a typical model used in image
+            classification.
+        logger : Logger
+            Default is `logging`.
+        context : Context or list of Context
+            Default is ``cpu()``.
+        work_load_list : list of number
+            Default ``None``, indicating uniform workload.
+        fixed_param_names: list of str
+            Default ``None``, indicating no network parameters are fixed.
+        """
+        sym, args, auxs = load_checkpoint(prefix, epoch)
+        mod = CustomModule(symbol=sym, **kwargs)
+        mod._arg_params = args
+        mod._aux_params = auxs
+        mod.params_initialized = True
+        if load_optimizer_states:
+            mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch)
+        return mod
+
+    def save_params(self, fname):
+        """Saves model parameters to file.
+        Parameters
+        ----------
+        fname : str
+            Path to output param file.
+        Examples
+        --------
+        >>> # An example of saving module parameters.
+        >>> mod.save_params('myfile')
+        """
+        arg_params, aux_params = self.get_params_from_kv(self._arg_params, self._aux_params)
+        save_dict = {('arg:%s' % k) : v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
+        save_dict.update({('aux:%s' % k) : v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
+        mx.nd.save(fname, save_dict)
+
+    def get_params_from_kv(self, arg_params, aux_params):
+        """ Copy data from kvstore to `arg_params` and `aux_params`.
+        Parameters
+        ----------
+        arg_params : list of NDArray
+            Target parameter arrays.
+        aux_params : list of NDArray
+            Target aux arrays.
+        Notes
+        -----
+        - This function will inplace update the NDArrays in arg_params and aux_params.
+        """
+        assert(self._kvstore is not None)
+        for name, block in zip(self._exec_group.param_names, self._exec_group.param_arrays):
+            assert(isinstance(block, list))
+            if block[0].stype == 'row_sparse':
+                row_ids = mx.nd.arange(start=0, stop=block[0].shape[0], dtype='int64')
+                self._kvstore.row_sparse_pull(name, arg_params[name], row_ids=row_ids)
+            else:
+                assert(block[0].stype == 'default')
+                self._kvstore.pull(name, out=arg_params[name])
+        if len(aux_params) > 0:
+            raise NotImplementedError()
+        return arg_params, aux_params
+
+    def clip_by_global_norm_per_ctx(self, max_norm=1.0, param_names=None):
+        """Clips gradient norm.
+
+        The norm is computed over all gradients together, as if they were
+         concatenated into a single vector. Gradients are modified in-place.
+
+        The method is first used in
+         `[ICML2013] On the difficulty of training recurrent neural networks`
+
+        Note that the gradients are concatenated per context in this implementation.
+
+        Examples
+        --------
+        An example of using clip_grad_norm to clip the gradient before updating the parameters::
+            >>> #Get the gradient via back-propagation
+            >>> net.forward_backward(data_batch=data_batch)
+            >>> norm_val = net.clip_by_global_norm(max_norm=2.0, param_names='w0')
+            >>> net.update()
+        """
+        assert self.binded and self.params_initialized and self.optimizer_initialized
+        num_ctx = len(self._exec_group.grad_arrays[0])
+        grad_array_per_ctx = [[] for i in range(num_ctx)]
+        assert(param_names is not None)
+        for param_name in param_names:
+            param_idx = self._exec_group.param_names.index(param_name)
+            grad_val = self._exec_group.grad_arrays[param_idx]
+            assert(len(grad_val) == num_ctx)
+            for i in range(num_ctx):
+                grad_array_per_ctx[i].append(grad_val[i])
+        norm_vals = []
+        for i in range(num_ctx):
+            mx.gluon.utils.clip_global_norm(grad_array_per_ctx[i], max_norm)
+
+    def rescale_grad(self, scale=None, param_name=None):
+        """ Rescale the gradient of provided parameters by a certain scale """
+        if scale is None or param_name is None:
+            return
+        param_idx = self._exec_group.param_names.index(param_name)
+        grad_vals = self._exec_group.grad_arrays[param_idx]
+        for grad in grad_vals:
+            grad[:] *= scale
diff --git a/example/rnn/large_word_lm/data.py b/example/rnn/large_word_lm/data.py
new file mode 100644
index 00000000000..b9cc3e8a89e
--- /dev/null
+++ b/example/rnn/large_word_lm/data.py
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+import codecs, glob, random, logging, collections
+
+class Vocabulary(object):
+    """ A dictionary for words.
+        Adapeted from @rafaljozefowicz's implementation.
+    """
+    def __init__(self):
+        self._token_to_id = {}
+        self._token_to_count = collections.Counter()
+        self._id_to_token = []
+        self._num_tokens = 0
+        self._total_count = 0
+        self._s_id = None
+        self._unk_id = None
+
+    @property
+    def num_tokens(self):
+        return self._num_tokens
+
+    @property
+    def unk(self):
+        return "<UNK>"
+
+    @property
+    def unk_id(self):
+        return self._unk_id
+
+    @property
+    def s(self):
+        return "<S>"
+
+    @property
+    def s_id(self):
+        return self._s_id
+
+    def add(self, token, count):
+        self._token_to_id[token] = self._num_tokens
+        self._token_to_count[token] = count
+        self._id_to_token.append(token)
+        self._num_tokens += 1
+        self._total_count += count
+
+    def finalize(self):
+        self._s_id = self.get_id(self.s)
+        self._unk_id = self.get_id(self.unk)
+
+    def get_id(self, token):
+        # Unseen token are mapped to UNK
+        return self._token_to_id.get(token, self.unk_id)
+
+    def get_token(self, id_):
+        return self._id_to_token[id_]
+
+    @staticmethod
+    def from_file(filename):
+        vocab = Vocabulary()
+        with codecs.open(filename, "r", "utf-8") as f:
+            for line in f:
+                word, count = line.strip().split()
+                vocab.add(word, int(count))
+        vocab.finalize()
+        return vocab
+
+class Dataset(object):
+    """ A dataset for truncated bptt with multiple sentences.
+        Adapeted from @rafaljozefowicz's implementation.
+     """
+    def __init__(self, vocab, file_pattern, shuffle=False):
+        self._vocab = vocab
+        self._file_pattern = file_pattern
+        self._shuffle = shuffle
+
+    def _parse_sentence(self, line):
+        s_id = self._vocab.s_id
+        return [s_id] + [self._vocab.get_id(word) for word in line.strip().split()] + [s_id]
+
+    def _parse_file(self, file_name):
+        logging.debug("Processing file: %s" % file_name)
+        with codecs.open(file_name, "r", "utf-8") as f:
+            lines = [line.strip() for line in f]
+            if not self._shuffle:
+                random.shuffle(lines)
+            logging.debug("Finished processing!")
+            for line in lines:
+                yield self._parse_sentence(line)
+
+    def _sentence_stream(self, file_stream):
+        for file_name in file_stream:
+            for sentence in self._parse_file(file_name):
+                yield sentence
+
+    def _iterate(self, sentences, batch_size, num_steps):
+        streams = [None] * batch_size
+        x = np.zeros([batch_size, num_steps], np.int32)
+        y = np.zeros([batch_size, num_steps], np.int32)
+        w = np.zeros([batch_size, num_steps], np.uint8)
+        while True:
+            x[:] = 0
+            y[:] = 0
+            w[:] = 0
+            for i in range(batch_size):
+                tokens_filled = 0
+                try:
+                    while tokens_filled < num_steps:
+                        if streams[i] is None or len(streams[i]) <= 1:
+                            streams[i] = next(sentences)
+                        num_tokens = min(len(streams[i]) - 1, num_steps - tokens_filled)
+                        x[i, tokens_filled:tokens_filled+num_tokens] = streams[i][:num_tokens]
+                        y[i, tokens_filled:tokens_filled + num_tokens] = streams[i][1:num_tokens+1]
+                        w[i, tokens_filled:tokens_filled + num_tokens] = 1
+                        streams[i] = streams[i][num_tokens:]
+                        tokens_filled += num_tokens
+                except StopIteration:
+                    pass
+            if not np.any(w):
+                return
+
+            yield x, y, w
+
+    def iterate_once(self, batch_size, num_steps):
+        def file_stream():
+            file_patterns = glob.glob(self._file_pattern)
+            if not self._shuffle:
+                random.shuffle(file_patterns)
+            for file_name in file_patterns:
+                yield file_name
+        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
+            yield value
+
+    def iterate_forever(self, batch_size, num_steps):
+        def file_stream():
+            while True:
+                file_patterns = glob.glob(self._file_pattern)
+                if not self._shuffle:
+                    random.shuffle(file_patterns)
+                for file_name in file_patterns:
+                    yield file_name
+        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
+            yield value
+
+class MultiSentenceIter(mx.io.DataIter):
+    """ An MXNet iterator that returns the a batch of sequence data and label each time.
+        It also returns a mask which indicates padded/missing data at the end of the dataset.
+        The iterator re-shuffles the data when reset is called.
+    """
+    def __init__(self, data_file, vocab, batch_size, bptt):
+        super(MultiSentenceIter, self).__init__()
+        self.batch_size = batch_size
+        self.bptt = bptt
+        self.provide_data = [('data', (batch_size, bptt), np.int32), ('mask', (batch_size, bptt))]
+        self.provide_label = [('label', (batch_size, bptt))]
+        self.vocab = vocab
+        self.data_file = data_file
+        self._dataset = Dataset(self.vocab, data_file, shuffle=True)
+        self._iter = self._dataset.iterate_once(batch_size, bptt)
+
+    def iter_next(self):
+        data = self._iter.next()
+        if data is None:
+            return False
+        self._next_data = mx.nd.array(data[0], dtype=np.int32)
+        self._next_label = mx.nd.array(data[1])
+        self._next_mask = mx.nd.array(data[2])
+        return True
+
+    def next(self):
+        if self.iter_next():
+            return mx.io.DataBatch(data=self.getdata(), label=self.getlabel())
+        else:
+            raise StopIteration
+
+    def reset(self):
+        self._dataset = Dataset(self.vocab, self.data_file, shuffle=False)
+        self._iter = self._dataset.iterate_once(self.batch_size, self.bptt)
+        self._next_data = None
+        self._next_label = None
+        self._next_mask = None
+
+    def getdata(self):
+        return [self._next_data, self._next_mask]
+
+    def getlabel(self):
+        return [self._next_label]
diff --git a/example/rnn/large_word_lm/get_vocab_file.sh b/example/rnn/large_word_lm/get_vocab_file.sh
new file mode 100755
index 00000000000..97fa29bf884
--- /dev/null
+++ b/example/rnn/large_word_lm/get_vocab_file.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+echo ""
+echo "NOTE: This script only downloads the pre-processed vocabulary file. "
+echo "For the full training and testing dataset, please download from "
+echo "http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz"
+echo ""
+
+# Resolve the directory containing this script so it works from any cwd.
+RNN_DIR=$(cd "$(dirname "$0")" && pwd)
+DATA_DIR="${RNN_DIR}/data/"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+  echo "${DATA_DIR} doesn't exist, will create one";
+  mkdir -p "${DATA_DIR}"
+fi
+
+# Download the vocabulary file into DATA_DIR.
+wget -P "${DATA_DIR}" https://s3-us-west-2.amazonaws.com/sparse-dataset/gbw/1b_word_vocab.txt
diff --git a/example/rnn/large_word_lm/model.py b/example/rnn/large_word_lm/model.py
new file mode 100644
index 00000000000..7ee010efb71
--- /dev/null
+++ b/example/rnn/large_word_lm/model.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import mxnet.symbol as S
+import numpy as np
+
+def cross_entropy_loss(inputs, labels, rescale_loss=1):
+    """ cross entropy loss with a mask """
+    criterion = mx.gluon.loss.SoftmaxCrossEntropyLoss(weight=rescale_loss)
+    loss = criterion(inputs, labels)
+    mask = S.var('mask')
+    loss = loss * S.reshape(mask, shape=(-1,))
+    return S.make_loss(loss.mean())
+
+def rnn(bptt, vocab_size, num_embed, nhid, num_layers, dropout, num_proj, batch_size):
+    """ word embedding + LSTM Projected """
+    embed = mx.sym.contrib.SparseEmbedding
+    state_names = []
+    data = S.var('data')
+    weight = S.var("encoder_weight", stype='row_sparse')
+    embed = embed(data=data, weight=weight, input_dim=vocab_size,
+                  output_dim=num_embed, name='embed', deterministic=True)
+    states = []
+    outputs = S.Dropout(embed, p=dropout)
+    for i in range(num_layers):
+        prefix = 'lstmp%d_' % i
+        init_h = S.var(prefix + 'init_h', shape=(batch_size, num_proj), init=mx.init.Zero())
+        init_c = S.var(prefix + 'init_c', shape=(batch_size, nhid), init=mx.init.Zero())
+        state_names += [prefix + 'init_h', prefix + 'init_c']
+        lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj)
+        outputs, next_states = lstmp.unroll(bptt, outputs, begin_state=[init_h, init_c], \
+                                            layout='NTC', merge_outputs=True)
+        outputs = S.Dropout(outputs, p=dropout)
+        states += [S.stop_gradient(s) for s in next_states]
+    outputs = S.reshape(outputs, shape=(-1, num_proj))
+
+    trainable_lstm_args = []
+    for arg in outputs.list_arguments():
+        if 'lstmp' in arg and 'init' not in arg:
+            trainable_lstm_args.append(arg)
+    return outputs, states, trainable_lstm_args, state_names
+
+def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
+                    sampled_values, remove_accidental_hits=True):
+        """ Sampled softmax via importance sampling.
+            This under-estimates the full softmax and is only used for training.
+        """
+        # inputs = (n, in_dim)
+        embed = mx.sym.contrib.SparseEmbedding
+        sample, prob_sample, prob_target = sampled_values
+
+        # (num_samples, )
+        sample = S.var('sample', shape=(num_samples,), dtype='float32')
+        # (n, )
+        label = S.var('label')
+        label = S.reshape(label, shape=(-1,), name="label_reshape")
+        # (num_samples+n, )
+        sample_label = S.concat(sample, label, dim=0)
+        # lookup weights and biases
+        # (num_samples+n, dim)
+        sample_target_w = embed(data=sample_label, weight=weight,
+                                     input_dim=num_classes, output_dim=in_dim,
+                                     deterministic=True)
+        # (num_samples+n, 1)
+        sample_target_b = embed(data=sample_label, weight=bias,
+                                input_dim=num_classes, output_dim=1, deterministic=True)
+        # (num_samples, dim)
+        sample_w = S.slice(sample_target_w, begin=(0, 0), end=(num_samples, None))
+        target_w = S.slice(sample_target_w, begin=(num_samples, 0), end=(None, None))
+        sample_b = S.slice(sample_target_b, begin=(0, 0), end=(num_samples, None))
+        target_b = S.slice(sample_target_b, begin=(num_samples, 0), end=(None, None))
+
+        # target
+        # (n, 1)
+        true_pred = S.sum(target_w * inputs, axis=1, keepdims=True) + target_b
+        # samples
+        # (n, num_samples)
+        sample_b = S.reshape(sample_b, (-1,))
+        sample_pred = S.FullyConnected(inputs, weight=sample_w, bias=sample_b,
+                                       num_hidden=num_samples)
+
+        # remove accidental hits
+        if remove_accidental_hits:
+            label_v = S.reshape(label, (-1, 1))
+            sample_v = S.reshape(sample, (1, -1))
+            neg = S.broadcast_equal(label_v, sample_v) * -1e37
+            sample_pred = sample_pred + neg
+
+        prob_sample = S.reshape(prob_sample, shape=(1, num_samples))
+        p_target = true_pred - S.log(prob_target)
+        p_sample = S.broadcast_sub(sample_pred, S.log(prob_sample))
+
+        # return logits and new_labels
+        # (n, 1+num_samples)
+        logits = S.concat(p_target, p_sample, dim=1)
+        new_targets = S.zeros_like(label)
+        return logits, new_targets
+
+def generate_samples(label, num_splits, num_samples, num_classes):
+    """ Split labels into `num_splits` and
+        generate candidates based on log-uniform distribution.
+    """
+    def listify(x):
+        return x if isinstance(x, list) else [x]
+    label_splits = listify(label.split(num_splits, axis=0))
+    prob_samples = []
+    prob_targets = []
+    samples = []
+    for label_split in label_splits:
+        label_split_2d = label_split.reshape((-1,1))
+        sampled_value = mx.nd.contrib.rand_zipfian(label_split_2d, num_samples, num_classes)
+        sampled_classes, exp_cnt_true, exp_cnt_sampled = sampled_value
+        samples.append(sampled_classes.astype(np.float32))
+        prob_targets.append(exp_cnt_true.astype(np.float32))
+        prob_samples.append(exp_cnt_sampled.astype(np.float32))
+    return samples, prob_samples, prob_targets
+
+class Model():
+    """ LSTMP with Importance Sampling """
+    def __init__(self, args, ntokens, rescale_loss):
+        out = rnn(args.bptt, ntokens, args.emsize, args.nhid, args.nlayers,
+                  args.dropout, args.num_proj, args.batch_size)
+        rnn_out, self.last_states, self.lstm_args, self.state_names = out
+        # decoder weight and bias
+        decoder_w = S.var("decoder_weight", stype='row_sparse')
+        decoder_b = S.var("decoder_bias", shape=(ntokens, 1), stype='row_sparse')
+
+        # sampled softmax for training
+        sample = S.var('sample', shape=(args.k,))
+        prob_sample = S.var("prob_sample", shape=(args.k,))
+        prob_target = S.var("prob_target")
+        self.sample_names = ['sample', 'prob_sample', 'prob_target']
+        logits, new_targets = sampled_softmax(ntokens, args.k, args.num_proj,
+                                              rnn_out, decoder_w, decoder_b,
+                                              [sample, prob_sample, prob_target])
+        self.train_loss = cross_entropy_loss(logits, new_targets, rescale_loss=rescale_loss)
+
+        # full softmax for testing
+        eval_logits = S.FullyConnected(data=rnn_out, weight=decoder_w,
+                                       num_hidden=ntokens, name='decode_fc', bias=decoder_b)
+        label = S.Variable('label')
+        label = S.reshape(label, shape=(-1,))
+        self.eval_loss = cross_entropy_loss(eval_logits, label)
+
+    def eval(self):
+        return S.Group(self.last_states + [self.eval_loss])
+
+    def train(self):
+        return S.Group(self.last_states + [self.train_loss])
diff --git a/example/rnn/large_word_lm/readme.md b/example/rnn/large_word_lm/readme.md
new file mode 100644
index 00000000000..d74ffbd1a21
--- /dev/null
+++ b/example/rnn/large_word_lm/readme.md
@@ -0,0 +1,66 @@
+# Large-Scale Language Model
+This example implements the baseline model in
+[Exploring the Limits of Language Modeling](https://arxiv.org/abs/1602.02410) on the
+[Google 1-Billion Word](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) (GBW) dataset.
+
+This example reaches **41.97 perplexity** after 5 training epochs on a 1-layer, 2048-unit, 512-projection LSTM Language Model.
+The result is slightly better than the one reported in the paper (43.7 perplexity).
+The main differences with the original implementation include:
+* Synchronous gradient updates instead of asynchronous updates
+* Noise candidates are sampled with replacement
+
+Each epoch for training takes around 80 minutes on a p3.8xlarge instance, which comes with 4 Volta V100 GPUs.
+
+# Setup - Original Data Format
+1. Download 1-Billion Word Dataset - [Link](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
+2. Download the pre-processed vocabulary file, which maps tokens to ids, by running `./get_vocab_file.sh` (it saves `1b_word_vocab.txt` under `data/`).
+
+# Run the Script
+```
+usage: train.py [-h] [--data DATA] [--test TEST] [--vocab VOCAB]
+                [--emsize EMSIZE] [--nhid NHID] [--num-proj NUM_PROJ]
+                [--nlayers NLAYERS] [--epochs EPOCHS]
+                [--batch-size BATCH_SIZE] [--dropout DROPOUT] [--eps EPS]
+                [--bptt BPTT] [--k K] [--gpus GPUS]
+                [--log-interval LOG_INTERVAL] [--seed SEED]
+                [--checkpoint-dir CHECKPOINT_DIR] [--lr LR] [--clip CLIP]
+                [--rescale-embed RESCALE_EMBED]
+
+Language Model on GBW
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --data DATA           location of the training data
+  --test TEST           location of the test data
+  --vocab VOCAB         location of the corpus vocabulary file
+  --emsize EMSIZE       size of word embeddings
+  --nhid NHID           number of hidden units per layer
+  --num-proj NUM_PROJ   number of projection units per layer
+  --nlayers NLAYERS     number of LSTM layers
+  --epochs EPOCHS       number of epoch for training
+  --batch-size BATCH_SIZE
+                        batch size per gpu
+  --dropout DROPOUT     dropout applied to layers (0 = no dropout)
+  --eps EPS             epsilon for adagrad
+  --bptt BPTT           sequence length
+  --k K                 number of noise samples for estimation
+  --gpus GPUS           list of gpus to run, e.g. 0 or 0,2,5. empty means
+                        using gpu(0).
+  --log-interval LOG_INTERVAL
+                        report interval
+  --seed SEED           random seed
+  --checkpoint-dir CHECKPOINT_DIR
+                        dir for checkpoint
+  --lr LR               initial learning rate
+  --clip CLIP           gradient clipping by global norm.
+  --rescale-embed RESCALE_EMBED
+                        scale factor for the gradients of the embedding layer
+```
+
+To reproduce the result, run
+```
+python train.py --gpus=0,1,2,3 --clip=1 --lr=0.05 --dropout=0.01 --eps=0.0001 --rescale-embed=128 \
+    --test=/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050 \
+    --data='/path/to/training-monolingual.tokenized.shuffled/*'
+# ~42 perplexity for 5 epochs of training
+```
diff --git a/example/rnn/large_word_lm/run_utils.py b/example/rnn/large_word_lm/run_utils.py
new file mode 100644
index 00000000000..7650530e80d
--- /dev/null
+++ b/example/rnn/large_word_lm/run_utils.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse, time, logging, math
+
+def get_parser():
+    parser = argparse.ArgumentParser(description='Language Model on GBW')
+    parser.add_argument('--data', type=str,
+                        default='/path/to/training-monolingual.tokenized.shuffled/*',
+                        help='location of the training data')
+    parser.add_argument('--test', type=str,
+                        default='/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050',
+                        help='location of the test data')
+    parser.add_argument('--vocab', type=str, default='./data/1b_word_vocab.txt',
+                        help='location of the corpus vocabulary file')
+    parser.add_argument('--emsize', type=int, default=512,
+                        help='size of word embeddings')
+    parser.add_argument('--nhid', type=int, default=2048,
+                        help='number of hidden units per layer')
+    parser.add_argument('--num-proj', type=int, default=512,
+                        help='number of projection units per layer')
+    parser.add_argument('--nlayers', type=int, default=1,
+                        help='number of LSTM layers')
+    parser.add_argument('--epochs', type=int, default=8,
+                        help='number of epoch for training')
+    parser.add_argument('--batch-size', type=int, default=128,
+                        help='batch size per gpu')
+    parser.add_argument('--dropout', type=float, default=0.1,
+                        help='dropout applied to layers (0 = no dropout)')
+    parser.add_argument('--eps', type=float, default=0.0001,
+                        help='epsilon for adagrad')
+    parser.add_argument('--bptt', type=int, default=20,
+                        help='sequence length')
+    parser.add_argument('--k', type=int, default=8192,
+                        help='number of noise samples for estimation')
+    parser.add_argument('--gpus', type=str,
+                        help='list of gpus to run, e.g. 0 or 0,2,5. empty means using gpu(0).')
+    parser.add_argument('--log-interval', type=int, default=200,
+                        help='report interval')
+    parser.add_argument('--seed', type=int, default=1,
+                        help='random seed')
+    parser.add_argument('--checkpoint-dir', type=str, default='./checkpoint/cp',
+                        help='dir for checkpoint')
+    parser.add_argument('--lr', type=float, default=0.1,
+                        help='initial learning rate')
+    parser.add_argument('--clip', type=float, default=1,
+                        help='gradient clipping by global norm.')
+    parser.add_argument('--rescale-embed', type=float, default=None,
+                        help='scale factor for the gradients of the embedding layer')
+    return parser
+
+def evaluate(mod, data_iter, epoch, log_interval):
+    """ Run evaluation on cpu. """
+    start = time.time()
+    total_L = 0.0
+    nbatch = 0
+    mod.set_states(value=0)
+    for batch in data_iter:
+        mod.forward(batch, is_train=False)
+        outputs = mod.get_outputs(merge_multi_context=False)
+        states = outputs[:-1]
+        total_L += outputs[-1][0].asscalar()
+        mod.set_states(states=states)
+        nbatch += 1
+        if (nbatch + 1) % log_interval == 0:
+            logging.info("Eval batch %d loss : %.7f" % (nbatch, total_L / nbatch))
+    data_iter.reset()
+    loss = total_L / nbatch
+    ppl = math.exp(loss) if loss < 100 else 1e37
+    end = time.time()
+    logging.info('Iter[%d]\t\t CE loss %.7f, ppl %.7f. Eval duration = %.2f seconds ' % \
+                 (epoch, loss, ppl, end - start))
+    return loss
diff --git a/example/rnn/large_word_lm/train.py b/example/rnn/large_word_lm/train.py
new file mode 100644
index 00000000000..a1b4e3140df
--- /dev/null
+++ b/example/rnn/large_word_lm/train.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+import mxnet.symbol as S
+import run_utils
+from data import MultiSentenceIter, Vocabulary
+from model import *
+from custom_module import CustomModule
+import os, math, logging, sys
+
+if __name__ == '__main__':
+    # parser
+    parser = run_utils.get_parser()
+    args = parser.parse_args()
+    head = '%(asctime)-15s %(message)s'
+    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else [mx.gpu()]
+    ngpus = len(ctx)
+    rescale_loss = args.bptt
+
+    # logging
+    logging.basicConfig(level=logging.INFO, format=head)
+    logging.info(args)
+    logging.debug(sys.argv)
+
+    # seeding
+    mx.random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    # data
+    vocab = Vocabulary.from_file(args.vocab)
+    ntokens = vocab.num_tokens
+    train_data = mx.io.PrefetchingIter(MultiSentenceIter(args.data, vocab,
+                                       args.batch_size * ngpus, args.bptt))
+    # model
+    model = Model(args, ntokens, rescale_loss)
+    train_loss_and_states = model.train()
+    eval_loss_and_states = model.eval()
+
+    # training module
+    data_names, label_names = ['data', 'mask'], ['label']
+    eval_state_names = model.state_names
+    num_sample_names = len(model.sample_names)
+    train_state_names = model.state_names + model.sample_names
+
+    module = CustomModule(symbol=train_loss_and_states, context=ctx,
+                          state_names=train_state_names,
+                          data_names=data_names, label_names=label_names)
+    module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    module.init_params(initializer=mx.init.Xavier(factor_type='out'))
+
+    # create kvstore and sparse optimizer
+    kvstore = mx.kv.create('device')
+    optimizer = mx.optimizer.create('adagrad', learning_rate=args.lr, \
+                                    rescale_grad=1.0/ngpus, eps=args.eps)
+    module.init_optimizer(optimizer=optimizer, kvstore=kvstore)
+
+    # speedometer
+    num_words_per_batch = args.batch_size * ngpus * args.bptt
+    speedometer = mx.callback.Speedometer(num_words_per_batch, args.log_interval)
+
+    # training loop
+    logging.info("Training started ... ")
+    for epoch in range(args.epochs):
+        total_L = mx.nd.array([0.0])
+        nbatch = 0
+        # reset initial LSTMP states
+        module.set_states(value=0)
+        state_cache = module.get_states(merge_multi_context=False)[:-num_sample_names]
+        next_batch = train_data.next()
+        next_sampled_values = generate_samples(next_batch.label[0], ngpus, args.k, ntokens)
+        stop_iter = False
+        while not stop_iter:
+            batch = next_batch
+            state_cache += next_sampled_values
+            # propagate LSTMP states from the previous batch
+            module.set_states(states=state_cache)
+            # selectively pull row_sparse weight to multiple devices based on the data batch
+            target_ids = [batch.label[0]]
+            sampled_ids = next_sampled_values[0]
+            param_rowids = {'encoder_weight': batch.data[0],
+                            'decoder_weight': sampled_ids + target_ids,
+                            'decoder_bias': sampled_ids + target_ids}
+            module.prepare_sparse_params(param_rowids)
+            # forward
+            module.forward(batch)
+            try:
+                # prefetch the next batch of data and samples
+                next_batch = train_data.next()
+                next_sampled_values = generate_samples(next_batch.label[0], ngpus,
+                                                       args.k, ntokens)
+            except StopIteration:
+                stop_iter = True
+            # cache LSTMP states of the current batch
+            outputs = module.get_outputs(merge_multi_context=False)
+            state_cache = outputs[:-1]
+            # backward
+            module.backward()
+            for g in range(ngpus):
+                total_L += outputs[-1][g].copyto(mx.cpu()) / ngpus
+
+            # rescaling the gradient for embedding layer emperically leads to faster convergence
+            module.rescale_grad(args.rescale_embed, 'encoder_weight')
+            # clip lstm params on each device based on norm
+            norm = module.clip_by_global_norm_per_ctx(max_norm=args.clip, param_names=model.lstm_args)
+            # update parameters
+            module.update()
+            speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                                       eval_metric=None, locals=locals())
+            speedometer(speedometer_param)
+            # update training metric
+            if nbatch % args.log_interval == 0 and nbatch > 0:
+                cur_L = total_L.asscalar() / args.log_interval / rescale_loss
+                ppl = math.exp(cur_L) if cur_L < 100 else 1e36
+                logging.info('Iter[%d] Batch [%d] \tloss %.7f, ppl %.7f'%(epoch, nbatch, cur_L, ppl))
+                total_L[:] = 0.0
+            nbatch += 1
+
+        # run evaluation with full softmax on cpu
+        module.save_checkpoint(args.checkpoint_dir, epoch, save_optimizer_states=False)
+        cpu_train_mod = CustomModule.load(args.checkpoint_dir, epoch, context=mx.cpu(),
+                                          state_names=train_state_names,
+                                          data_names=data_names, label_names=label_names)
+        # eval data iter
+        eval_data = mx.io.PrefetchingIter(MultiSentenceIter(args.test, vocab,
+                                          args.batch_size, args.bptt))
+        cpu_train_mod.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label)
+
+        # eval module
+        eval_module = CustomModule(symbol=eval_loss_and_states, context=mx.cpu(), data_names=data_names,
+                                   label_names=label_names, state_names=eval_state_names)
+        # use `shared_module` to share parameter with the training module
+        eval_module.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label,
+                         shared_module=cpu_train_mod, for_training=False)
+        val_L = run_utils.evaluate(eval_module, eval_data, epoch, 20)
+        train_data.reset()
+    logging.info("Training completed. ")
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index b964c712ace..1b9afee14bf 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -17,13 +17,12 @@
 
 # coding: utf-8
 """Definition of various recurrent neural network cells."""
-__all__ = ['VariationalDropoutCell']
+__all__ = ['VariationalDropoutCell', 'LSTMPCell']
 
-from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
+from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
 from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
 from ... import tensor_types
 
-
 class VariationalDropoutCell(ModifierCell):
     """
     Applies Variational Dropout on base cell.
@@ -193,3 +192,126 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
             outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
                                                      merge_outputs)
         return outputs, states
+
+
+class LSTMPCell(HybridRecurrentCell):
+    r"""Long-Short Term Memory Projected (LSTMP) network cell.
+    (https://arxiv.org/abs/1402.1128)
+    Each call computes the following function:
+    .. math::
+        \begin{array}{ll}
+        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
+        f_t = sigmoid(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
+        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rc} r_{(t-1)} + b_{rg}}) \\
+        o_t = sigmoid(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
+        c_t = f_t * c_{(t-1)} + i_t * g_t \\
+        h_t = o_t * \tanh(c_t) \\
+        r_t = W_{hr} h_t
+        \end{array}
+    where :math:`r_t` is the projected recurrent activation at time `t`,
+    math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
+    cell state at time `t`, :math:`x_t` is the input at time `t`, and :math:`i_t`,
+    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
+    out gates, respectively.
+    Parameters
+    ----------
+    hidden_size : int
+        Number of units in cell state symbol.
+    projection_size : int
+        Number of units in output symbol.
+    i2h_weight_initializer : str or Initializer
+        Initializer for the input weights matrix, used for the linear
+        transformation of the inputs.
+    h2h_weight_initializer : str or Initializer
+        Initializer for the recurrent weights matrix, used for the linear
+        transformation of the hidden state.
+    h2r_weight_initializer : str or Initializer
+        Initializer for the projection weights matrix, used for the linear
+        transformation of the recurrent state.
+    i2h_bias_initializer : str or Initializer, default 'lstmbias'
+        Initializer for the bias vector. By default, bias for the forget
+        gate is initialized to 1 while all other biases are initialized
+        to zero.
+    h2h_bias_initializer : str or Initializer
+        Initializer for the bias vector.
+    prefix : str, default 'lstmp_'
+        Prefix for name of `Block`s
+        (and name of weight if params is `None`).
+    params : Parameter or None
+        Container for weight sharing between cells.
+        Created if `None`.
+    Inputs:
+        - **data**: input tensor with shape `(batch_size, input_size)`.
+        - **states**: a list of two initial recurrent state tensors, with shape
+          `(batch_size, projection_size)` and `(batch_size, hidden_size)` respectively.
+    Outputs:
+        - **out**: output tensor with shape `(batch_size, num_hidden)`.
+        - **next_states**: a list of two output recurrent state tensors. Each has
+          the same shape as `states`.
+    """
+    def __init__(self, hidden_size, projection_size,
+                 i2h_weight_initializer=None, h2h_weight_initializer=None,
+                 h2r_weight_initializer=None,
+                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
+                 input_size=0, prefix=None, params=None):
+        super(LSTMPCell, self).__init__(prefix=prefix, params=params)
+
+        self._hidden_size = hidden_size
+        self._input_size = input_size
+        self._projection_size = projection_size
+        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
+                                          init=i2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, projection_size),
+                                          init=h2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.h2r_weight = self.params.get('h2r_weight', shape=(projection_size, hidden_size),
+                                          init=h2r_weight_initializer,
+                                          allow_deferred_init=True)
+        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
+                                        init=i2h_bias_initializer,
+                                        allow_deferred_init=True)
+        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
+                                        init=h2h_bias_initializer,
+                                        allow_deferred_init=True)
+
+    def state_info(self, batch_size=0):
+        return [{'shape': (batch_size, self._projection_size), '__layout__': 'NC'},
+                {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]
+
+    def _alias(self):
+        return 'lstmp'
+
+    def __repr__(self):
+        s = '{name}({mapping})'
+        shape = self.i2h_weight.shape
+        proj_shape = self.h2r_weight.shape
+        mapping = '{0} -> {1} -> {2}'.format(shape[1] if shape[1] else None,
+                                             shape[0], proj_shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
+    # pylint: disable= arguments-differ
+    def hybrid_forward(self, F, inputs, states, i2h_weight,
+                       h2h_weight, h2r_weight, i2h_bias, h2h_bias):
+        prefix = 't%d_'%self._counter
+        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
+                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
+        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
+                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
+        gates = i2h + h2h
+        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
+        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
+        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
+        in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
+        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
+        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+                                   name=prefix+'state')
+        hidden = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
+                                  name=prefix+'hidden')
+        next_r = F.FullyConnected(data=hidden, num_hidden=self._projection_size,
+                                  weight=h2r_weight, no_bias=True, name=prefix+'out')
+
+        return next_r, [next_r, next_c]
+    # pylint: enable= arguments-differ
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7d31a31b839..ee97649768b 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -802,17 +802,17 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
           CHECK_EQ(inferred_stype, arg_nd_stype)
             << "Inferred stype does not match shared_exec.arg_array's stype"
                " Therefore, the allocated memory for shared_exec.arg_array cannot"
-               " be resued for creating NDArray of the argument"
+               " be resued for creating NDArray of the argument "
             << arg_name << " for the current executor";
           CHECK_EQ(inferred_shape, in_arg_nd.shape())
             << "Inferred shape does not match shared_exec.arg_array's shape"
                " Therefore, the allocated memory for shared_exec.arg_array cannot"
-               " be resued for creating NDArray of the argument"
+               " be resued for creating NDArray of the argument "
             << arg_name << " for the current executor";
           CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
             << "Inferred dtype does not match shared_exec.arg_array's dtype"
                " Therefore, the allocated memory for shared_exec.arg_array cannot"
-               " be resued for creating NDArray of the argument"
+               " be resued for creating NDArray of the argument "
             << arg_name << " for the current executor";
           in_arg_vec->emplace_back(in_arg_nd);
         } else {
diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index e8e95643e64..7eba2e20e57 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -95,11 +95,20 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
         Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
   }
 
+  CHECK_EQ(data.shape_[1], wmat.shape_[1])
+    << "Incomplete weight tensor detected: weight.data().shape[1] != prod(data.data().shape[1:])."
+       " This is not supported by FCForward. If weight is in row_sparse format,"
+       " please make sure all row ids are present.";
   // Legacy approach shown here for comparison:
   //   out = dot(data, wmat.T());
   linalg_gemm(data, wmat, out, false, true, s);
   if (!param.no_bias) {
-    Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get_with_shape<xpu, 1, DType>(
+      Shape1(wmat.shape_[0]), s);
+    CHECK_EQ(bias.shape_[0], wmat.shape_[0])
+      << "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
+         " This is not supported by FCForward. If bias is in row_sparse format, please"
+         " make sure all row ids are present.";
     out += repmat(bias, data.size(0));
   }
 }
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 4362408a23a..75d594ffd91 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -56,7 +56,10 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
   }
   SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param.num_hidden, num_input));
   if (!param.no_bias) {
-    SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param.num_hidden));
+    if (!shape_assign(&(*in_shape)[fullc::kBias], Shape1(param.num_hidden)) &&
+        !shape_assign(&(*in_shape)[fullc::kBias], Shape2(param.num_hidden, 1))) {
+      LOG(FATAL) << "Unexpected shape for bias " << (*in_shape)[fullc::kBias];
+    }
   }
 
   if (!param.flatten) {
@@ -73,22 +76,67 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-#if MXNET_USE_MKLDNN == 1
 void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const OpContext &ctx,
                                 const std::vector<NDArray> &inputs,
                                 const std::vector<OpReqType> &req,
                                 const std::vector<NDArray> &outputs) {
-  if (SupportMKLDNN(inputs[0])) {
-    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
-    MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
-                       outputs);
+  // CPU forward for NDArray inputs. Accepted storage: dense data, dense or
+  // row_sparse weight and bias, dense output. Any other combination is
+  // reported through LogUnimplementedOp.
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  const bool valid_data = inputs[0].storage_type() == kDefaultStorage;
+  const bool valid_weight = inputs[1].storage_type() == kDefaultStorage ||
+                            inputs[1].storage_type() == kRowSparseStorage;
+  const bool valid_out = outputs[0].storage_type() == kDefaultStorage;
+  bool valid_bias = true;
+  if (!param.no_bias) {
+    valid_bias = inputs[2].storage_type() == kDefaultStorage ||
+                 inputs[2].storage_type() == kRowSparseStorage;
+  }
+#if MXNET_USE_MKLDNN == 1
+  // All-dense case: use MKLDNN when it supports the data, otherwise fall back
+  // to the generic dense implementation.
+  if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
+      common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
+    if (SupportMKLDNN(inputs[0])) {
+      MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+      MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
+      MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
+                         outputs);
+    } else {
+      FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    }
     return;
+  } else if (valid_data && valid_weight && valid_bias && valid_out) {
+    // Some weight/bias is row_sparse: run FullyConnectedCompute on raw TBlobs.
+    // NOTE(review): a row_sparse blob holds only its stored rows; presumably
+    // FCForward (via FullyConnectedCompute) rejects incomplete weight/bias.
+    // inputs
+    std::vector<NDArray> temp_ndarrays;
+    std::vector<TBlob> in_blobs;
+    for (const NDArray& in : inputs) {
+      // if ndarray is in default storage and MKLDNN is available,
+      // need to make sure cpu layout data is used, instead of MKL layout
+      if (in.storage_type() == kDefaultStorage) {
+        temp_ndarrays.push_back(in.Reorder2Default());
+        in_blobs.emplace_back(temp_ndarrays.back().data());
+      } else {
+        in_blobs.emplace_back(in.data());
+      }
+    }
+    // output
+    if (req[0] == kWriteTo) const_cast<NDArray &>(outputs[0]).InvalidateMKLDNNData();
+    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, {outputs[0].data()});
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+#else
+  // Non-MKLDNN build: every accepted storage combination goes straight to
+  // FullyConnectedCompute on the underlying TBlobs.
+  if (valid_data && valid_weight && valid_bias && valid_out) {
+    std::vector<TBlob> in_blobs(inputs.size());
+    for (size_t i = 0; i < in_blobs.size(); i++) in_blobs[i] = inputs[i].data();
+    std::vector<TBlob> out_blobs(outputs.size());
+    for (size_t i = 0; i < out_blobs.size(); i++) out_blobs[i] = outputs[i].data();
+    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
-  FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
+#endif
 }
 
+#if MXNET_USE_MKLDNN == 1
 void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const OpContext &ctx,
                                     const std::vector<NDArray> &inputs,
@@ -129,19 +177,27 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
   const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
-  uint32_t in_expected = param.no_bias ? 2 : 3;
+  const bool valid_data = in_attrs->at(0) == kDefaultStorage;
+  const bool valid_weight = in_attrs->at(1) == kDefaultStorage ||
+                            in_attrs->at(1) == kRowSparseStorage;
+  bool valid_bias = true;
+  uint32_t in_expected = 2;
+  if (!param.no_bias) {
+    in_expected = 3;
+    valid_bias = in_attrs->at(2) == kDefaultStorage || in_attrs->at(2) == kRowSparseStorage;
+  }
   CHECK_EQ(in_attrs->size(), in_expected);
   CHECK_EQ(out_attrs->size(), 1);
-
-  DispatchMode wanted_mode;
-#if MXNET_USE_MKLDNN == 1
-  if (dev_mask == mshadow::cpu::kDevMask)
-    wanted_mode = DispatchMode::kFComputeEx;
-  else
-#endif
-    wanted_mode = DispatchMode::kFCompute;
-  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                             dispatch_mode, wanted_mode);
+  // dispatch to kFComputeEx is fine even if all inputs are dense and no MKL is present
+  bool dispatched = false;
+  if (!dispatched && valid_data && valid_weight && valid_bias) {
+    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
 }
 
 inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
@@ -170,6 +226,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
 DMLC_REGISTER_PARAMETER(FullyConnectedParam);
 
 NNVM_REGISTER_OP(FullyConnected)
+MXNET_ADD_SPARSE_OP_ALIAS(FullyConnected)
 .describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`.
 
 If ``flatten`` is set to be true, then the shapes are:
@@ -190,6 +247,10 @@ The learnable parameters include both ``weight`` and ``bias``.
 
 If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
 
+Note that the operator also supports forward computation with `row_sparse` weight and bias,
+where the length of `weight.indices` and `bias.indices` must be equal to `num_hidden`.
+This could be used for model inference with `row_sparse` weights trained with `SparseEmbedding`.
+
 )code" ADD_FILELINE)
 .set_num_inputs([](const NodeAttrs& attrs) {
   const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
@@ -214,9 +275,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
 .set_attr<nnvm::FInferShape>("FInferShape", FullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", FullyConnectedType)
 .set_attr<FCompute>("FCompute<cpu>", FullyConnectedCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
 .set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedComputeExCPU)
-#endif
 .set_attr<nnvm::FGradient>("FGradient", FullyConnectedGrad{"_backward_FullyConnected"})
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "Weight matrix.")
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 29850dce6ae..729ec8407f2 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -108,6 +108,22 @@ def test_conv_fill_shape():
     check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
     assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
 
+@with_seed()
+def test_lstmp():
+    nhid = 100
+    nproj = 64
+    cell = contrib.rnn.LSTMPCell(nhid, nproj, prefix='rnn_')
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)
+    outputs = mx.sym.Group(outputs)
+    expected_params = ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_h2r_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    expected_outputs = ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
+    assert sorted(cell.collect_params().keys()) == expected_params
+    assert outputs.list_outputs() == expected_outputs, outputs.list_outputs()
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
+    assert outs == [(10, nproj), (10, nproj), (10, nproj)]
+
 
 @with_seed()
 def test_vardrop():
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 3d6f9d0711f..182e70c8d7b 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -872,6 +872,7 @@ def test_sparse_nd_check_format():
     a = mx.nd.sparse.row_sparse_array((data_list, indices_list), shape=shape)
     assertRaises(mx.base.MXNetError, a.check_format)
 
+@with_seed()
 def test_sparse_nd_norm():
     def check_sparse_nd_norm(stype, shape, density):
         data, _ = rand_sparse_ndarray(shape, stype, density)
@@ -886,6 +887,23 @@ def check_sparse_nd_norm(stype, shape, density):
         for density in densities:
             check_sparse_nd_norm(stype, shape, density)
 
+@with_seed()
+def test_sparse_fc():
+    def check_sparse_fc(batch_size, dim_in, dim_out, stype):
+        data = rand_ndarray((batch_size, dim_in), stype, density=0.5)
+        weight = rand_ndarray((dim_out, dim_in), 'row_sparse', density=1)
+        bias = rand_ndarray((dim_out, 1), 'row_sparse', density=1)
+        out = mx.nd.sparse.FullyConnected(data, weight, num_hidden=dim_out, bias=bias)
+        data_dns = data.tostype('default')
+        weight_dns = weight.tostype('default')
+        out_dns = mx.nd.FullyConnected(data_dns, weight_dns, num_hidden=dim_out, bias=bias)
+        assert_almost_equal(out.asnumpy(), out_dns.asnumpy())
+
+    # test FC with row_sparse weight w/ density=1, dense data
+    check_sparse_fc(5, 10, 8, 'default')
+    # test FC with row_sparse weight w/ density=1, csr data (fallback)
+    check_sparse_fc(5, 10, 8, 'csr')
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services