You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/22 01:02:00 UTC

[incubator-mxnet] branch master updated: [MXNET-96] Language model with Google's billion words dataset (#10025)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 57534bf  [MXNET-96] Language model with Google's billion words dataset (#10025)
57534bf is described below

commit 57534bfadb06ade5fbbab12cf5d05dd8da4538af
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Mar 21 18:01:56 2018 -0700

    [MXNET-96] Language model with Google's billion words dataset (#10025)
    
    * Language model with Google's billion words dataset (#197)
    
    Language model with Google's billion words dataset (#197)
    
    * fix lint
    
    * ffix license
    
    * patch
    
    * fix lint
    
    * cr comment
    
    * update fc fallback
    
    * fix build
    
    * fix temp memory in fc
    
    * fix compilation
---
 example/rnn/large_word_lm/custom_module.py   | 182 ++++++++++++++++++++++++
 example/rnn/large_word_lm/data.py            | 202 +++++++++++++++++++++++++++
 example/rnn/large_word_lm/get_vocab_file.sh  |  34 +++++
 example/rnn/large_word_lm/model.py           | 181 ++++++++++++++++++++++++
 example/rnn/large_word_lm/readme.md          |  66 +++++++++
 example/rnn/large_word_lm/run_utils.py       |  87 ++++++++++++
 example/rnn/large_word_lm/train.py           | 152 ++++++++++++++++++++
 python/mxnet/gluon/contrib/rnn/rnn_cell.py   | 128 ++++++++++++++++-
 src/executor/graph_executor.cc               |   6 +-
 src/operator/nn/fully_connected-inl.h        |  11 +-
 src/operator/nn/fully_connected.cc           | 101 +++++++++++---
 tests/python/unittest/test_gluon_contrib.py  |  16 +++
 tests/python/unittest/test_sparse_ndarray.py |  18 +++
 13 files changed, 1156 insertions(+), 28 deletions(-)

diff --git a/example/rnn/large_word_lm/custom_module.py b/example/rnn/large_word_lm/custom_module.py
new file mode 100644
index 0000000..05d0fb7
--- /dev/null
+++ b/example/rnn/large_word_lm/custom_module.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+import warnings
+
+import mxnet as mx
+import numpy as np
+from mxnet.module import Module
+from mxnet.model import load_checkpoint
+
+class CustomModule(Module):
+
+    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
+                 logger=logging, context=mx.cpu(), work_load_list=None,
+                 fixed_param_names=None, state_names=None, group2ctxs=None,
+                 compression_params=None):
+
+        super(CustomModule, self).__init__(symbol, data_names=data_names, label_names=label_names,
+                                           logger=logger, context=context, work_load_list=work_load_list,
+                                           fixed_param_names=fixed_param_names, state_names=state_names,
+                                           group2ctxs=group2ctxs, compression_params=compression_params)
+
+    def prepare_sparse_params(self, param_rowids):
+        '''Prepares the module for processing a data batch by pulling row_sparse
+        parameters from kvstore to all devices based on rowids.
+
+        Parameters
+        ----------
+        param_rowids : dict of str to NDArray or list/tuple of NDArrays
+        '''
+        if not self._kvstore:  # no kvstore attached: nothing to pull
+            return
+        assert(isinstance(param_rowids, dict))
+        for param_name, rowids in param_rowids.items():
+            if isinstance(rowids, (tuple, list)):
+                rowids_1d = []
+                for r in rowids:
+                    rowids_1d.append(r.reshape((-1,)).astype(np.int64))  # flatten to 1-D int64
+                rowid = mx.nd.concat(*rowids_1d, dim=0)
+            else:
+                rowid = rowids  # a single array is passed through unchanged
+            param_idx = self._exec_group.param_names.index(param_name)
+            param_val = self._exec_group.param_arrays[param_idx]
+            self._kvstore.row_sparse_pull(param_name, param_val, row_ids=rowid,
+                                          priority=-param_idx)
+
+    @staticmethod
+    def load(prefix, epoch, load_optimizer_states=False, **kwargs):
+        """Creates a model from previously saved checkpoint.
+
+        Parameters
+        ----------
+        prefix : str
+            path prefix of saved model files. You should have
+            "prefix-symbol.json", "prefix-xxxx.params", and
+            optionally "prefix-xxxx.states", where xxxx is the
+            epoch number.
+        epoch : int
+            epoch to load.
+        load_optimizer_states : bool
+            whether to load optimizer states. Checkpoint needs
+            to have been made with save_optimizer_states=True.
+        data_names : list of str
+            Default is `('data')` for a typical model used in image classification.
+        label_names : list of str
+            Default is `('softmax_label')` for a typical model used in image
+            classification.
+        logger : Logger
+            Default is `logging`.
+        context : Context or list of Context
+            Default is ``cpu()``.
+        work_load_list : list of number
+            Default ``None``, indicating uniform workload.
+        fixed_param_names: list of str
+            Default ``None``, indicating no network parameters are fixed.
+        """
+        sym, args, auxs = load_checkpoint(prefix, epoch)
+        mod = CustomModule(symbol=sym, **kwargs)
+        mod._arg_params = args
+        mod._aux_params = auxs
+        mod.params_initialized = True
+        if load_optimizer_states:
+            mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch)
+        return mod
+
+    def save_params(self, fname):
+        """Saves model parameters to file.
+        Parameters
+        ----------
+        fname : str
+            Path to output param file.
+        Examples
+        --------
+        >>> # An example of saving module parameters.
+        >>> mod.save_params('myfile')
+        """
+        arg_params, aux_params = self.get_params_from_kv(self._arg_params, self._aux_params)
+        save_dict = {('arg:%s' % k) : v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
+        save_dict.update({('aux:%s' % k) : v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
+        mx.nd.save(fname, save_dict)
+
+    def get_params_from_kv(self, arg_params, aux_params):
+        """ Copy data from kvstore to `arg_params`; `aux_params` must be empty.
+        Parameters
+        ----------
+        arg_params : dict of str to NDArray
+            Target parameter arrays, looked up by parameter name.
+        aux_params : dict of str to NDArray
+            Target aux arrays; anything non-empty raises NotImplementedError.
+        Notes
+        -----
+        - This function will inplace update the NDArrays in arg_params.
+        """
+        assert(self._kvstore is not None)
+        for name, block in zip(self._exec_group.param_names, self._exec_group.param_arrays):
+            assert(isinstance(block, list))
+            if block[0].stype == 'row_sparse':
+                row_ids = mx.nd.arange(start=0, stop=block[0].shape[0], dtype='int64')  # every row
+                self._kvstore.row_sparse_pull(name, arg_params[name], row_ids=row_ids)
+            else:
+                assert(block[0].stype == 'default')
+                self._kvstore.pull(name, out=arg_params[name])
+        if len(aux_params) > 0:
+            raise NotImplementedError()
+        return arg_params, aux_params
+
+    def clip_by_global_norm_per_ctx(self, max_norm=1.0, param_names=None):
+        """Clips gradient norm.
+
+        The norm is computed over all gradients together, as if they were
+         concatenated into a single vector. Gradients are modified in-place.
+
+        The method is first used in
+         `[ICML2013] On the difficulty of training recurrent neural networks`
+
+        Note that the gradients are concatenated per context in this implementation.
+
+        Examples
+        --------
+        An example of using clip_grad_norm to clip the gradient before updating the parameters::
+            >>> #Get the gradient via back-propagation
+            >>> net.forward_backward(data_batch=data_batch)
+            >>> norm_val = net.clip_by_global_norm(max_norm=2.0, param_names='w0')
+            >>> net.update()
+        """
+        assert self.binded and self.params_initialized and self.optimizer_initialized
+        num_ctx = len(self._exec_group.grad_arrays[0])
+        grad_array_per_ctx = [[] for i in range(num_ctx)]
+        assert(param_names is not None)
+        for param_name in param_names:
+            param_idx = self._exec_group.param_names.index(param_name)
+            grad_val = self._exec_group.grad_arrays[param_idx]
+            assert(len(grad_val) == num_ctx)
+            for i in range(num_ctx):
+                grad_array_per_ctx[i].append(grad_val[i])
+        norm_vals = []
+        for i in range(num_ctx):
+            mx.gluon.utils.clip_global_norm(grad_array_per_ctx[i], max_norm)
+
+    def rescale_grad(self, scale=None, param_name=None):
+        """ Rescale the gradient of provided parameters by a certain scale """
+        if scale is None or param_name is None:
+            return
+        param_idx = self._exec_group.param_names.index(param_name)
+        grad_vals = self._exec_group.grad_arrays[param_idx]
+        for grad in grad_vals:
+            grad[:] *= scale
diff --git a/example/rnn/large_word_lm/data.py b/example/rnn/large_word_lm/data.py
new file mode 100644
index 0000000..b9cc3e8
--- /dev/null
+++ b/example/rnn/large_word_lm/data.py
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+import codecs, glob, random, logging, collections
+
+class Vocabulary(object):
+    """ A dictionary for words.
+        Adapted from @rafaljozefowicz's implementation.
+    """
+    def __init__(self):
+        self._token_to_id = {}
+        self._token_to_count = collections.Counter()
+        self._id_to_token = []
+        self._num_tokens = 0
+        self._total_count = 0
+        self._s_id = None
+        self._unk_id = None
+
+    @property
+    def num_tokens(self):
+        return self._num_tokens
+
+    @property
+    def unk(self):
+        return "<UNK>"
+
+    @property
+    def unk_id(self):
+        return self._unk_id
+
+    @property
+    def s(self):
+        return "<S>"
+
+    @property
+    def s_id(self):
+        return self._s_id
+
+    def add(self, token, count):
+        self._token_to_id[token] = self._num_tokens
+        self._token_to_count[token] = count
+        self._id_to_token.append(token)
+        self._num_tokens += 1
+        self._total_count += count
+
+    def finalize(self):
+        self._s_id = self.get_id(self.s)
+        self._unk_id = self.get_id(self.unk)
+
+    def get_id(self, token):
+        # Unseen tokens are mapped to the UNK id (which is None until finalize() sets it)
+        return self._token_to_id.get(token, self.unk_id)
+
+    def get_token(self, id_):
+        return self._id_to_token[id_]
+
+    @staticmethod
+    def from_file(filename):
+        vocab = Vocabulary()
+        with codecs.open(filename, "r", "utf-8") as f:
+            for line in f:
+                word, count = line.strip().split()
+                vocab.add(word, int(count))
+        vocab.finalize()
+        return vocab
+
+class Dataset(object):
+    """ A dataset for truncated bptt with multiple sentences.
+        Adapted from @rafaljozefowicz's implementation.
+     """
+    def __init__(self, vocab, file_pattern, shuffle=False):
+        self._vocab = vocab
+        self._file_pattern = file_pattern
+        self._shuffle = shuffle
+
+    def _parse_sentence(self, line):
+        s_id = self._vocab.s_id
+        return [s_id] + [self._vocab.get_id(word) for word in line.strip().split()] + [s_id]
+
+    def _parse_file(self, file_name):  # yields one parsed sentence per line of the file
+        logging.debug("Processing file: %s" % file_name)
+        with codecs.open(file_name, "r", "utf-8") as f:
+            lines = [line.strip() for line in f]
+            if not self._shuffle:  # NOTE(review): shuffles when shuffle is False -- looks inverted, confirm
+                random.shuffle(lines)
+            logging.debug("Finished processing!")
+            for line in lines:
+                yield self._parse_sentence(line)
+
+    def _sentence_stream(self, file_stream):
+        for file_name in file_stream:
+            for sentence in self._parse_file(file_name):
+                yield sentence
+
+    def _iterate(self, sentences, batch_size, num_steps):
+        streams = [None] * batch_size
+        x = np.zeros([batch_size, num_steps], np.int32)
+        y = np.zeros([batch_size, num_steps], np.int32)
+        w = np.zeros([batch_size, num_steps], np.uint8)
+        while True:
+            x[:] = 0
+            y[:] = 0
+            w[:] = 0
+            for i in range(batch_size):
+                tokens_filled = 0
+                try:
+                    while tokens_filled < num_steps:
+                        if streams[i] is None or len(streams[i]) <= 1:
+                            streams[i] = next(sentences)
+                        num_tokens = min(len(streams[i]) - 1, num_steps - tokens_filled)
+                        x[i, tokens_filled:tokens_filled+num_tokens] = streams[i][:num_tokens]
+                        y[i, tokens_filled:tokens_filled + num_tokens] = streams[i][1:num_tokens+1]
+                        w[i, tokens_filled:tokens_filled + num_tokens] = 1
+                        streams[i] = streams[i][num_tokens:]
+                        tokens_filled += num_tokens
+                except StopIteration:
+                    pass
+            if not np.any(w):
+                return
+
+            yield x, y, w
+
+    def iterate_once(self, batch_size, num_steps):  # one pass over the files matching the pattern
+        def file_stream():
+            file_patterns = glob.glob(self._file_pattern)
+            if not self._shuffle:  # NOTE(review): shuffles when shuffle is False -- looks inverted, confirm
+                random.shuffle(file_patterns)
+            for file_name in file_patterns:
+                yield file_name
+        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
+            yield value
+
+    def iterate_forever(self, batch_size, num_steps):
+        def file_stream():
+            while True:  # after the last file, glob again and start over
+                file_patterns = glob.glob(self._file_pattern)
+                if not self._shuffle:  # NOTE(review): shuffles when shuffle is False -- looks inverted, confirm
+                    random.shuffle(file_patterns)
+                for file_name in file_patterns:
+                    yield file_name
+        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
+            yield value
+
+class MultiSentenceIter(mx.io.DataIter):
+    """ An MXNet iterator that returns a batch of sequence data and label each time.
+        It also returns a mask which indicates padded/missing data at the end of the dataset.
+        The iterator re-shuffles the data when reset is called.
+    """
+    def __init__(self, data_file, vocab, batch_size, bptt):
+        super(MultiSentenceIter, self).__init__()
+        self.batch_size = batch_size
+        self.bptt = bptt
+        self.provide_data = [('data', (batch_size, bptt), np.int32), ('mask', (batch_size, bptt))]
+        self.provide_label = [('label', (batch_size, bptt))]
+        self.vocab = vocab
+        self.data_file = data_file
+        self._dataset = Dataset(self.vocab, data_file, shuffle=True)
+        self._iter = self._dataset.iterate_once(batch_size, bptt)
+
+    def iter_next(self):
+        data = self._iter.next()
+        if data is None:
+            return False
+        self._next_data = mx.nd.array(data[0], dtype=np.int32)
+        self._next_label = mx.nd.array(data[1])
+        self._next_mask = mx.nd.array(data[2])
+        return True
+
+    def next(self):
+        if self.iter_next():
+            return mx.io.DataBatch(data=self.getdata(), label=self.getlabel())
+        else:
+            raise StopIteration
+
+    def reset(self):
+        self._dataset = Dataset(self.vocab, self.data_file, shuffle=False)
+        self._iter = self._dataset.iterate_once(self.batch_size, self.bptt)
+        self._next_data = None
+        self._next_label = None
+        self._next_mask = None
+
+    def getdata(self):
+        return [self._next_data, self._next_mask]
+
+    def getlabel(self):
+        return [self._next_label]
diff --git a/example/rnn/large_word_lm/get_vocab_file.sh b/example/rnn/large_word_lm/get_vocab_file.sh
new file mode 100755
index 0000000..97fa29b
--- /dev/null
+++ b/example/rnn/large_word_lm/get_vocab_file.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+echo ""
+echo "NOTE: This script only downloads the pre-processed vocabulary file. "
+echo "For the full training and testing dataset, please download from "
+echo "http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz"
+echo ""
+
+RNN_DIR=$(cd "$(dirname "$0")"; pwd)  # absolute directory containing this script
+DATA_DIR="${RNN_DIR}/data/"
+
+if [[ ! -d "${DATA_DIR}" ]]; then
+  echo "${DATA_DIR} doesn't exist, will create one";
+  mkdir -p "${DATA_DIR}"
+fi
+
+wget -P "${DATA_DIR}" https://s3-us-west-2.amazonaws.com/sparse-dataset/gbw/1b_word_vocab.txt;
diff --git a/example/rnn/large_word_lm/model.py b/example/rnn/large_word_lm/model.py
new file mode 100644
index 0000000..7ee010e
--- /dev/null
+++ b/example/rnn/large_word_lm/model.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import mxnet.symbol as S
+import numpy as np
+
+def cross_entropy_loss(inputs, labels, rescale_loss=1):
+    """ cross entropy loss with a mask """
+    criterion = mx.gluon.loss.SoftmaxCrossEntropyLoss(weight=rescale_loss)
+    loss = criterion(inputs, labels)
+    mask = S.var('mask')
+    loss = loss * S.reshape(mask, shape=(-1,))
+    return S.make_loss(loss.mean())
+
+def rnn(bptt, vocab_size, num_embed, nhid, num_layers, dropout, num_proj, batch_size):
+    """ word embedding + LSTM Projected """
+    embed = mx.sym.contrib.SparseEmbedding
+    state_names = []
+    data = S.var('data')
+    weight = S.var("encoder_weight", stype='row_sparse')
+    embed = embed(data=data, weight=weight, input_dim=vocab_size,
+                  output_dim=num_embed, name='embed', deterministic=True)
+    states = []
+    outputs = S.Dropout(embed, p=dropout)
+    for i in range(num_layers):
+        prefix = 'lstmp%d_' % i
+        init_h = S.var(prefix + 'init_h', shape=(batch_size, num_proj), init=mx.init.Zero())
+        init_c = S.var(prefix + 'init_c', shape=(batch_size, nhid), init=mx.init.Zero())
+        state_names += [prefix + 'init_h', prefix + 'init_c']
+        lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj)
+        outputs, next_states = lstmp.unroll(bptt, outputs, begin_state=[init_h, init_c], \
+                                            layout='NTC', merge_outputs=True)
+        outputs = S.Dropout(outputs, p=dropout)
+        states += [S.stop_gradient(s) for s in next_states]
+    outputs = S.reshape(outputs, shape=(-1, num_proj))
+
+    trainable_lstm_args = []
+    for arg in outputs.list_arguments():
+        if 'lstmp' in arg and 'init' not in arg:
+            trainable_lstm_args.append(arg)
+    return outputs, states, trainable_lstm_args, state_names
+
+def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
+                    sampled_values, remove_accidental_hits=True):
+        """ Sampled softmax via importance sampling.
+            This under-estimates the full softmax and is only used for training.
+        """
+        # inputs = (n, in_dim)
+        embed = mx.sym.contrib.SparseEmbedding
+        sample, prob_sample, prob_target = sampled_values
+
+        # (num_samples, )
+        sample = S.var('sample', shape=(num_samples,), dtype='float32')
+        # (n, )
+        label = S.var('label')
+        label = S.reshape(label, shape=(-1,), name="label_reshape")
+        # (num_samples+n, )
+        sample_label = S.concat(sample, label, dim=0)
+        # lookup weights and biases
+        # (num_samples+n, dim)
+        sample_target_w = embed(data=sample_label, weight=weight,
+                                     input_dim=num_classes, output_dim=in_dim,
+                                     deterministic=True)
+        # (num_samples+n, 1)
+        sample_target_b = embed(data=sample_label, weight=bias,
+                                input_dim=num_classes, output_dim=1, deterministic=True)
+        # (num_samples, dim)
+        sample_w = S.slice(sample_target_w, begin=(0, 0), end=(num_samples, None))
+        target_w = S.slice(sample_target_w, begin=(num_samples, 0), end=(None, None))
+        sample_b = S.slice(sample_target_b, begin=(0, 0), end=(num_samples, None))
+        target_b = S.slice(sample_target_b, begin=(num_samples, 0), end=(None, None))
+
+        # target
+        # (n, 1)
+        true_pred = S.sum(target_w * inputs, axis=1, keepdims=True) + target_b
+        # samples
+        # (n, num_samples)
+        sample_b = S.reshape(sample_b, (-1,))
+        sample_pred = S.FullyConnected(inputs, weight=sample_w, bias=sample_b,
+                                       num_hidden=num_samples)
+
+        # remove accidental hits
+        if remove_accidental_hits:
+            label_v = S.reshape(label, (-1, 1))
+            sample_v = S.reshape(sample, (1, -1))
+            neg = S.broadcast_equal(label_v, sample_v) * -1e37
+            sample_pred = sample_pred + neg
+
+        prob_sample = S.reshape(prob_sample, shape=(1, num_samples))
+        p_target = true_pred - S.log(prob_target)
+        p_sample = S.broadcast_sub(sample_pred, S.log(prob_sample))
+
+        # return logits and new_labels
+        # (n, 1+num_samples)
+        logits = S.concat(p_target, p_sample, dim=1)
+        new_targets = S.zeros_like(label)
+        return logits, new_targets
+
+def generate_samples(label, num_splits, num_samples, num_classes):
+    """ Split labels into `num_splits` and
+        generate candidates based on log-uniform distribution.
+    """
+    def listify(x):
+        return x if isinstance(x, list) else [x]
+    label_splits = listify(label.split(num_splits, axis=0))
+    prob_samples = []
+    prob_targets = []
+    samples = []
+    for label_split in label_splits:
+        label_split_2d = label_split.reshape((-1,1))
+        sampled_value = mx.nd.contrib.rand_zipfian(label_split_2d, num_samples, num_classes)
+        sampled_classes, exp_cnt_true, exp_cnt_sampled = sampled_value
+        samples.append(sampled_classes.astype(np.float32))
+        prob_targets.append(exp_cnt_true.astype(np.float32))
+        prob_samples.append(exp_cnt_sampled.astype(np.float32))
+    return samples, prob_samples, prob_targets
+
+class Model():
+    """ LSTMP with Importance Sampling """
+    def __init__(self, args, ntokens, rescale_loss):
+        out = rnn(args.bptt, ntokens, args.emsize, args.nhid, args.nlayers,
+                  args.dropout, args.num_proj, args.batch_size)
+        rnn_out, self.last_states, self.lstm_args, self.state_names = out
+        # decoder weight and bias
+        decoder_w = S.var("decoder_weight", stype='row_sparse')
+        decoder_b = S.var("decoder_bias", shape=(ntokens, 1), stype='row_sparse')
+
+        # sampled softmax for training
+        sample = S.var('sample', shape=(args.k,))
+        prob_sample = S.var("prob_sample", shape=(args.k,))
+        prob_target = S.var("prob_target")
+        self.sample_names = ['sample', 'prob_sample', 'prob_target']
+        logits, new_targets = sampled_softmax(ntokens, args.k, args.num_proj,
+                                              rnn_out, decoder_w, decoder_b,
+                                              [sample, prob_sample, prob_target])
+        self.train_loss = cross_entropy_loss(logits, new_targets, rescale_loss=rescale_loss)
+
+        # full softmax for testing
+        eval_logits = S.FullyConnected(data=rnn_out, weight=decoder_w,
+                                       num_hidden=ntokens, name='decode_fc', bias=decoder_b)
+        label = S.Variable('label')
+        label = S.reshape(label, shape=(-1,))
+        self.eval_loss = cross_entropy_loss(eval_logits, label)
+
+    def eval(self):
+        return S.Group(self.last_states + [self.eval_loss])
+
+    def train(self):
+        return S.Group(self.last_states + [self.train_loss])
diff --git a/example/rnn/large_word_lm/readme.md b/example/rnn/large_word_lm/readme.md
new file mode 100644
index 0000000..d74ffbd
--- /dev/null
+++ b/example/rnn/large_word_lm/readme.md
@@ -0,0 +1,66 @@
+# Large-Scale Language Model
+This example implements the baseline model in
+[Exploring the Limits of Language Modeling](https://arxiv.org/abs/1602.02410) on the
+[Google 1-Billion Word](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) (GBW) dataset.
+
+This example reaches **41.97 perplexity** after 5 training epochs on a 1-layer, 2048-unit, 512-projection LSTM Language Model.
+The result is slightly better than the one reported in the paper (43.7 perplexity).
+The main differences with the original implementation include:
+* Synchronized gradient updates instead of asynchronous updates
+* Noise candidates are sampled with replacement
+
+Each epoch for training takes around 80 minutes on a p3.8xlarge instance, which comes with 4 Volta V100 GPUs.
+
+# Setup - Original Data Format
+1. Download 1-Billion Word Dataset - [Link](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz)
+2. Download pre-processed vocabulary file which maps tokens into ids.
+
+# Run the Script
+```
+usage: train.py [-h] [--data DATA] [--test TEST] [--vocab VOCAB]
+                [--emsize EMSIZE] [--nhid NHID] [--num-proj NUM_PROJ]
+                [--nlayers NLAYERS] [--epochs EPOCHS]
+                [--batch-size BATCH_SIZE] [--dropout DROPOUT] [--eps EPS]
+                [--bptt BPTT] [--k K] [--gpus GPUS]
+                [--log-interval LOG_INTERVAL] [--seed SEED]
+                [--checkpoint-dir CHECKPOINT_DIR] [--lr LR] [--clip CLIP]
+                [--rescale-embed RESCALE_EMBED]
+
+Language Model on GBW
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --data DATA           location of the training data
+  --test TEST           location of the test data
+  --vocab VOCAB         location of the corpus vocabulary file
+  --emsize EMSIZE       size of word embeddings
+  --nhid NHID           number of hidden units per layer
+  --num-proj NUM_PROJ   number of projection units per layer
+  --nlayers NLAYERS     number of LSTM layers
+  --epochs EPOCHS       number of epoch for training
+  --batch-size BATCH_SIZE
+                        batch size per gpu
+  --dropout DROPOUT     dropout applied to layers (0 = no dropout)
+  --eps EPS             epsilon for adagrad
+  --bptt BPTT           sequence length
+  --k K                 number of noise samples for estimation
+  --gpus GPUS           list of gpus to run, e.g. 0 or 0,2,5. empty means
+                        using gpu(0).
+  --log-interval LOG_INTERVAL
+                        report interval
+  --seed SEED           random seed
+  --checkpoint-dir CHECKPOINT_DIR
+                        dir for checkpoint
+  --lr LR               initial learning rate
+  --clip CLIP           gradient clipping by global norm.
+  --rescale-embed RESCALE_EMBED
+                        scale factor for the gradients of the embedding layer
+```
+
+To reproduce the result, run
+```
+train.py --gpus=0,1,2,3 --clip=1 --lr=0.05 --dropout=0.01 --eps=0.0001 --rescale-embed=128
+--test=/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050
+--data=/path/to/training-monolingual.tokenized.shuffled/*
+# ~42 perplexity for 5 epochs of training
+```
diff --git a/example/rnn/large_word_lm/run_utils.py b/example/rnn/large_word_lm/run_utils.py
new file mode 100644
index 0000000..7650530e
--- /dev/null
+++ b/example/rnn/large_word_lm/run_utils.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse, time, logging, math
+
+def get_parser():
+    parser = argparse.ArgumentParser(description='Language Model on GBW')
+    parser.add_argument('--data', type=str,
+                        default='/path/to/training-monolingual.tokenized.shuffled/*',
+                        help='location of the training data')
+    parser.add_argument('--test', type=str,
+                        default='/path/to/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050',
+                        help='location of the test data')
+    parser.add_argument('--vocab', type=str, default='./data/1b_word_vocab.txt',
+                        help='location of the corpus vocabulary file')
+    parser.add_argument('--emsize', type=int, default=512,
+                        help='size of word embeddings')
+    parser.add_argument('--nhid', type=int, default=2048,
+                        help='number of hidden units per layer')
+    parser.add_argument('--num-proj', type=int, default=512,
+                        help='number of projection units per layer')
+    parser.add_argument('--nlayers', type=int, default=1,
+                        help='number of LSTM layers')
+    parser.add_argument('--epochs', type=int, default=8,
+                        help='number of epoch for training')
+    parser.add_argument('--batch-size', type=int, default=128,
+                        help='batch size per gpu')
+    parser.add_argument('--dropout', type=float, default=0.1,
+                        help='dropout applied to layers (0 = no dropout)')
+    parser.add_argument('--eps', type=float, default=0.0001,
+                        help='epsilon for adagrad')
+    parser.add_argument('--bptt', type=int, default=20,
+                        help='sequence length')
+    parser.add_argument('--k', type=int, default=8192,
+                        help='number of noise samples for estimation')
+    parser.add_argument('--gpus', type=str,
+                        help='list of gpus to run, e.g. 0 or 0,2,5. empty means using gpu(0).')
+    parser.add_argument('--log-interval', type=int, default=200,
+                        help='report interval')
+    parser.add_argument('--seed', type=int, default=1,
+                        help='random seed')
+    parser.add_argument('--checkpoint-dir', type=str, default='./checkpoint/cp',
+                        help='dir for checkpoint')
+    parser.add_argument('--lr', type=float, default=0.1,
+                        help='initial learning rate')
+    parser.add_argument('--clip', type=float, default=1,
+                        help='gradient clipping by global norm.')
+    parser.add_argument('--rescale-embed', type=float, default=None,
+                        help='scale factor for the gradients of the embedding layer')
+    return parser
+
+def evaluate(mod, data_iter, epoch, log_interval):
+    """ Run evaluation on cpu. """
+    start = time.time()
+    total_L = 0.0
+    nbatch = 0
+    mod.set_states(value=0)
+    for batch in data_iter:
+        mod.forward(batch, is_train=False)
+        outputs = mod.get_outputs(merge_multi_context=False)
+        states = outputs[:-1]
+        total_L += outputs[-1][0].asscalar()
+        mod.set_states(states=states)
+        nbatch += 1
+        if (nbatch + 1) % log_interval == 0:
+            logging.info("Eval batch %d loss : %.7f" % (nbatch, total_L / nbatch))
+    data_iter.reset()
+    loss = total_L / nbatch
+    ppl = math.exp(loss) if loss < 100 else 1e37
+    end = time.time()
+    logging.info('Iter[%d]\t\t CE loss %.7f, ppl %.7f. Eval duration = %.2f seconds ' % \
+                 (epoch, loss, ppl, end - start))
+    return loss
diff --git a/example/rnn/large_word_lm/train.py b/example/rnn/large_word_lm/train.py
new file mode 100644
index 0000000..a1b4e31
--- /dev/null
+++ b/example/rnn/large_word_lm/train.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+import mxnet.symbol as S
+import run_utils
+from data import MultiSentenceIter, Vocabulary
+from model import *
+from custom_module import CustomModule
+import os, math, logging, sys
+
+if __name__ == '__main__':
+    # parser
+    parser = run_utils.get_parser()
+    args = parser.parse_args()
+    head = '%(asctime)-15s %(message)s'
+    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else [mx.gpu()]
+    ngpus = len(ctx)
+    rescale_loss = args.bptt
+
+    # logging
+    logging.basicConfig(level=logging.INFO, format=head)
+    logging.info(args)
+    logging.debug(sys.argv)
+
+    # seeding
+    mx.random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    # data
+    vocab = Vocabulary.from_file(args.vocab)
+    ntokens = vocab.num_tokens
+    train_data = mx.io.PrefetchingIter(MultiSentenceIter(args.data, vocab,
+                                       args.batch_size * ngpus, args.bptt))
+    # model
+    model = Model(args, ntokens, rescale_loss)
+    train_loss_and_states = model.train()
+    eval_loss_and_states = model.eval()
+
+    # training module
+    data_names, label_names = ['data', 'mask'], ['label']
+    eval_state_names = model.state_names
+    num_sample_names = len(model.sample_names)
+    train_state_names = model.state_names + model.sample_names
+
+    module = CustomModule(symbol=train_loss_and_states, context=ctx,
+                          state_names=train_state_names,
+                          data_names=data_names, label_names=label_names)
+    module.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    module.init_params(initializer=mx.init.Xavier(factor_type='out'))
+
+    # create kvstore and sparse optimizer
+    kvstore = mx.kv.create('device')
+    optimizer = mx.optimizer.create('adagrad', learning_rate=args.lr, \
+                                    rescale_grad=1.0/ngpus, eps=args.eps)
+    module.init_optimizer(optimizer=optimizer, kvstore=kvstore)
+
+    # speedometer
+    num_words_per_batch = args.batch_size * ngpus * args.bptt
+    speedometer = mx.callback.Speedometer(num_words_per_batch, args.log_interval)
+
+    # training loop
+    logging.info("Training started ... ")
+    for epoch in range(args.epochs):
+        total_L = mx.nd.array([0.0])
+        nbatch = 0
+        # reset initial LSTMP states
+        module.set_states(value=0)
+        state_cache = module.get_states(merge_multi_context=False)[:-num_sample_names]
+        next_batch = train_data.next()
+        next_sampled_values = generate_samples(next_batch.label[0], ngpus, args.k, ntokens)
+        stop_iter = False
+        while not stop_iter:
+            batch = next_batch
+            state_cache += next_sampled_values
+            # propagate LSTMP states from the previous batch
+            module.set_states(states=state_cache)
+            # selectively pull row_sparse weight to multiple devices based on the data batch
+            target_ids = [batch.label[0]]
+            sampled_ids = next_sampled_values[0]
+            param_rowids = {'encoder_weight': batch.data[0],
+                            'decoder_weight': sampled_ids + target_ids,
+                            'decoder_bias': sampled_ids + target_ids}
+            module.prepare_sparse_params(param_rowids)
+            # forward
+            module.forward(batch)
+            try:
+                # prefetch the next batch of data and samples
+                next_batch = train_data.next()
+                next_sampled_values = generate_samples(next_batch.label[0], ngpus,
+                                                       args.k, ntokens)
+            except StopIteration:
+                stop_iter = True
+            # cache LSTMP states of the current batch
+            outputs = module.get_outputs(merge_multi_context=False)
+            state_cache = outputs[:-1]
+            # backward
+            module.backward()
+            for g in range(ngpus):
+                total_L += outputs[-1][g].copyto(mx.cpu()) / ngpus
+
+            # rescaling the gradient for embedding layer empirically leads to faster convergence
+            module.rescale_grad(args.rescale_embed, 'encoder_weight')
+            # clip lstm params on each device based on norm
+            norm = module.clip_by_global_norm_per_ctx(max_norm=args.clip, param_names=model.lstm_args)
+            # update parameters
+            module.update()
+            speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                                       eval_metric=None, locals=locals())
+            speedometer(speedometer_param)
+            # update training metric
+            if nbatch % args.log_interval == 0 and nbatch > 0:
+                cur_L = total_L.asscalar() / args.log_interval / rescale_loss
+                ppl = math.exp(cur_L) if cur_L < 100 else 1e36
+                logging.info('Iter[%d] Batch [%d] \tloss %.7f, ppl %.7f'%(epoch, nbatch, cur_L, ppl))
+                total_L[:] = 0.0
+            nbatch += 1
+
+        # run evaluation with full softmax on cpu
+        module.save_checkpoint(args.checkpoint_dir, epoch, save_optimizer_states=False)
+        cpu_train_mod = CustomModule.load(args.checkpoint_dir, epoch, context=mx.cpu(),
+                                          state_names=train_state_names,
+                                          data_names=data_names, label_names=label_names)
+        # eval data iter
+        eval_data = mx.io.PrefetchingIter(MultiSentenceIter(args.test, vocab,
+                                          args.batch_size, args.bptt))
+        cpu_train_mod.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label)
+
+        # eval module
+        eval_module = CustomModule(symbol=eval_loss_and_states, context=mx.cpu(), data_names=data_names,
+                                   label_names=label_names, state_names=eval_state_names)
+        # use `shared_module` to share parameter with the training module
+        eval_module.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label,
+                         shared_module=cpu_train_mod, for_training=False)
+        val_L = run_utils.evaluate(eval_module, eval_data, epoch, 20)
+        train_data.reset()
+    logging.info("Training completed. ")
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index b964c71..1b9afee 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -17,13 +17,12 @@
 
 # coding: utf-8
 """Definition of various recurrent neural network cells."""
-__all__ = ['VariationalDropoutCell']
+__all__ = ['VariationalDropoutCell', 'LSTMPCell']
 
-from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
+from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
 from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
 from ... import tensor_types
 
-
 class VariationalDropoutCell(ModifierCell):
     """
     Applies Variational Dropout on base cell.
@@ -193,3 +192,126 @@ class VariationalDropoutCell(ModifierCell):
             outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
                                                      merge_outputs)
         return outputs, states
+
+
+class LSTMPCell(HybridRecurrentCell):
+    r"""Long-Short Term Memory Projected (LSTMP) network cell.
+    (https://arxiv.org/abs/1402.1128)
+    Each call computes the following function:
+    .. math::
+        \begin{array}{ll}
+        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
+        f_t = sigmoid(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
+        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
+        o_t = sigmoid(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
+        c_t = f_t * c_{(t-1)} + i_t * g_t \\
+        h_t = o_t * \tanh(c_t) \\
+        r_t = W_{hr} h_t
+        \end{array}
+    where :math:`r_t` is the projected recurrent activation at time `t`,
+    :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
+    cell state at time `t`, :math:`x_t` is the input at time `t`, and :math:`i_t`,
+    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
+    out gates, respectively.
+    Parameters
+    ----------
+    hidden_size : int
+        Number of units in cell state symbol.
+    projection_size : int
+        Number of units in output symbol.
+    i2h_weight_initializer : str or Initializer
+        Initializer for the input weights matrix, used for the linear
+        transformation of the inputs.
+    h2h_weight_initializer : str or Initializer
+        Initializer for the recurrent weights matrix, used for the linear
+        transformation of the projected recurrent state.
+    h2r_weight_initializer : str or Initializer
+        Initializer for the projection weights matrix, used for the linear
+        transformation of the hidden state into the projected recurrent state.
+    i2h_bias_initializer : str or Initializer, default 'zeros'
+        Initializer for the input-to-hidden bias vector. Note that the
+        default initializes all biases, including the forget gate bias,
+        to zero.
+    h2h_bias_initializer : str or Initializer
+        Initializer for the bias vector.
+    prefix : str, default 'lstmp_'
+        Prefix for name of `Block`s
+        (and name of weight if params is `None`).
+    params : Parameter or None
+        Container for weight sharing between cells.
+        Created if `None`.
+    Inputs:
+        - **data**: input tensor with shape `(batch_size, input_size)`.
+        - **states**: a list of two initial recurrent state tensors, with shape
+          `(batch_size, projection_size)` and `(batch_size, hidden_size)` respectively.
+    Outputs:
+        - **out**: output tensor with shape `(batch_size, projection_size)`.
+        - **next_states**: a list of two output recurrent state tensors. Each has
+          the same shape as `states`.
+    """
+    def __init__(self, hidden_size, projection_size,
+                 i2h_weight_initializer=None, h2h_weight_initializer=None,
+                 h2r_weight_initializer=None,
+                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
+                 input_size=0, prefix=None, params=None):
+        super(LSTMPCell, self).__init__(prefix=prefix, params=params)
+
+        self._hidden_size = hidden_size
+        self._input_size = input_size
+        self._projection_size = projection_size
+        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
+                                          init=i2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, projection_size),
+                                          init=h2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.h2r_weight = self.params.get('h2r_weight', shape=(projection_size, hidden_size),
+                                          init=h2r_weight_initializer,
+                                          allow_deferred_init=True)
+        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
+                                        init=i2h_bias_initializer,
+                                        allow_deferred_init=True)
+        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
+                                        init=h2h_bias_initializer,
+                                        allow_deferred_init=True)
+
+    def state_info(self, batch_size=0):
+        return [{'shape': (batch_size, self._projection_size), '__layout__': 'NC'},
+                {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]
+
+    def _alias(self):
+        return 'lstmp'
+
+    def __repr__(self):
+        s = '{name}({mapping})'
+        shape = self.i2h_weight.shape
+        proj_shape = self.h2r_weight.shape
+        mapping = '{0} -> {1} -> {2}'.format(shape[1] if shape[1] else None,
+                                             shape[0], proj_shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
+    # pylint: disable= arguments-differ
+    def hybrid_forward(self, F, inputs, states, i2h_weight,
+                       h2h_weight, h2r_weight, i2h_bias, h2h_bias):
+        prefix = 't%d_'%self._counter
+        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
+                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
+        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
+                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
+        gates = i2h + h2h
+        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
+        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
+        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
+        in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
+        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
+        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+                                   name=prefix+'state')
+        hidden = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
+                                  name=prefix+'hidden')
+        next_r = F.FullyConnected(data=hidden, num_hidden=self._projection_size,
+                                  weight=h2r_weight, no_bias=True, name=prefix+'out')
+
+        return next_r, [next_r, next_c]
+    # pylint: enable= arguments-differ
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7d31a31..ee97649 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -802,17 +802,17 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
           CHECK_EQ(inferred_stype, arg_nd_stype)
             << "Inferred stype does not match shared_exec.arg_array's stype"
                " Therefore, the allocated memory for shared_exec.arg_array cannot"
-               " be resued for creating NDArray of the argument"
+               " be resued for creating NDArray of the argument "
             << arg_name << " for the current executor";
           CHECK_EQ(inferred_shape, in_arg_nd.shape())
             << "Inferred shape does not match shared_exec.arg_array's shape"
                " Therefore, the allocated memory for shared_exec.arg_array cannot"
-               " be resued for creating NDArray of the argument"
+               " be resued for creating NDArray of the argument "
             << arg_name << " for the current executor";
           CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
             << "Inferred dtype does not match shared_exec.arg_array's dtype"
                " Therefore, the allocated memory for shared_exec.arg_array cannot"
-               " be resued for creating NDArray of the argument"
+               " be resued for creating NDArray of the argument "
             << arg_name << " for the current executor";
           in_arg_vec->emplace_back(in_arg_nd);
         } else {
diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index e8e9564..7eba2e2 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -95,11 +95,20 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
         Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
   }
 
+  CHECK_EQ(data.shape_[1], wmat.shape_[1])
+    << "Incomplete weight tensor detected: weight.data().shape[1] != prod(data.data().shape[1:])."
+       " This is not supported by FCForward. If weight is in row_sparse format,"
+       " please make sure all row ids are present.";
   // Legacy approach shown here for comparison:
   //   out = dot(data, wmat.T());
   linalg_gemm(data, wmat, out, false, true, s);
   if (!param.no_bias) {
-    Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get_with_shape<xpu, 1, DType>(
+      Shape1(wmat.shape_[0]), s);
+    CHECK_EQ(bias.shape_[0], wmat.shape_[0])
+      << "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
+         " This is not supported by FCForward. If bias is in row_sparse format, please"
+         " make sure all row ids are present.";
     out += repmat(bias, data.size(0));
   }
 }
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 4362408..75d594f 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -56,7 +56,10 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
   }
   SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param.num_hidden, num_input));
   if (!param.no_bias) {
-    SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param.num_hidden));
+    if (!shape_assign(&(*in_shape)[fullc::kBias], Shape1(param.num_hidden)) &&
+        !shape_assign(&(*in_shape)[fullc::kBias], Shape2(param.num_hidden, 1))) {
+      LOG(FATAL) << "Unexpected shape for bias " << (*in_shape)[fullc::kBias];
+    }
   }
 
   if (!param.flatten) {
@@ -73,22 +76,67 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-#if MXNET_USE_MKLDNN == 1
 void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const OpContext &ctx,
                                 const std::vector<NDArray> &inputs,
                                 const std::vector<OpReqType> &req,
                                 const std::vector<NDArray> &outputs) {
-  if (SupportMKLDNN(inputs[0])) {
-    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
-    MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
-    MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
-                       outputs);
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  const bool valid_data = inputs[0].storage_type() == kDefaultStorage;
+  const bool valid_weight = inputs[1].storage_type() == kDefaultStorage ||
+                            inputs[1].storage_type() == kRowSparseStorage;
+  const bool valid_out = outputs[0].storage_type() == kDefaultStorage;
+  bool valid_bias = true;
+  if (!param.no_bias) {
+    valid_bias = inputs[2].storage_type() == kDefaultStorage ||
+                 inputs[2].storage_type() == kRowSparseStorage;
+  }
+#if MXNET_USE_MKLDNN == 1
+  if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
+      common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
+    if (SupportMKLDNN(inputs[0])) {
+      MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+      MKLDNNFCForward(attrs, ctx, inputs, req, outputs);
+      MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
+                         outputs);
+    } else {
+      FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    }
     return;
+  } else if (valid_data && valid_weight && valid_bias && valid_out) {
+    // inputs
+    std::vector<NDArray> temp_ndarrays;
+    std::vector<TBlob> in_blobs;
+    for (const NDArray& in : inputs) {
+      // if ndarray is in default storage and MKLDNN is available,
+      // need to make sure cpu layout data is used, instead of MKL layout
+      if (in.storage_type() == kDefaultStorage) {
+        temp_ndarrays.push_back(in.Reorder2Default());
+        in_blobs.emplace_back(temp_ndarrays.back().data());
+      } else {
+        in_blobs.emplace_back(in.data());
+      }
+    }
+    // output
+    if (req[0] == kWriteTo) const_cast<NDArray &>(outputs[0]).InvalidateMKLDNNData();
+    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, {outputs[0].data()});
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+#else
+  if (valid_data && valid_weight && valid_bias && valid_out) {
+    std::vector<TBlob> in_blobs(inputs.size());
+    for (size_t i = 0; i < in_blobs.size(); i++) in_blobs[i] = inputs[i].data();
+    std::vector<TBlob> out_blobs(outputs.size());
+    for (size_t i = 0; i < out_blobs.size(); i++) out_blobs[i] = outputs[i].data();
+    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
-  FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
+#endif
 }
 
+#if MXNET_USE_MKLDNN == 1
 void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const OpContext &ctx,
                                     const std::vector<NDArray> &inputs,
@@ -129,19 +177,27 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
   const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
-  uint32_t in_expected = param.no_bias ? 2 : 3;
+  const bool valid_data = in_attrs->at(0) == kDefaultStorage;
+  const bool valid_weight = in_attrs->at(1) == kDefaultStorage ||
+                            in_attrs->at(1) == kRowSparseStorage;
+  bool valid_bias = true;
+  uint32_t in_expected = 2;
+  if (!param.no_bias) {
+    in_expected = 3;
+    valid_bias = in_attrs->at(2) == kDefaultStorage || in_attrs->at(2) == kRowSparseStorage;
+  }
   CHECK_EQ(in_attrs->size(), in_expected);
   CHECK_EQ(out_attrs->size(), 1);
-
-  DispatchMode wanted_mode;
-#if MXNET_USE_MKLDNN == 1
-  if (dev_mask == mshadow::cpu::kDevMask)
-    wanted_mode = DispatchMode::kFComputeEx;
-  else
-#endif
-    wanted_mode = DispatchMode::kFCompute;
-  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                             dispatch_mode, wanted_mode);
+  // dispatch to kFComputeEx is fine even if all inputs are dense and no MKL is present
+  bool dispatched = false;
+  if (!dispatched && valid_data && valid_weight && valid_bias) {
+    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
 }
 
 inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
@@ -170,6 +226,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
 DMLC_REGISTER_PARAMETER(FullyConnectedParam);
 
 NNVM_REGISTER_OP(FullyConnected)
+MXNET_ADD_SPARSE_OP_ALIAS(FullyConnected)
 .describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`.
 
 If ``flatten`` is set to be true, then the shapes are:
@@ -190,6 +247,10 @@ The learnable parameters include both ``weight`` and ``bias``.
 
 If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
 
+Note that the operator also supports forward computation with `row_sparse` weight and bias,
+where the length of `weight.indices` and `bias.indices` must be equal to `num_hidden`.
+This could be used for model inference with `row_sparse` weights trained with `SparseEmbedding`.
+
 )code" ADD_FILELINE)
 .set_num_inputs([](const NodeAttrs& attrs) {
   const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
@@ -214,9 +275,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
 .set_attr<nnvm::FInferShape>("FInferShape", FullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", FullyConnectedType)
 .set_attr<FCompute>("FCompute<cpu>", FullyConnectedCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
 .set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedComputeExCPU)
-#endif
 .set_attr<nnvm::FGradient>("FGradient", FullyConnectedGrad{"_backward_FullyConnected"})
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "Weight matrix.")
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 29850dc..729ec84 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -108,6 +108,22 @@ def test_conv_fill_shape():
     check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
     assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
 
+@with_seed()
+def test_lstmp():
+    nhid = 100
+    nproj = 64
+    cell = contrib.rnn.LSTMPCell(nhid, nproj, prefix='rnn_')
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)
+    outputs = mx.sym.Group(outputs)
+    expected_params = ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_h2r_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    expected_outputs = ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
+    assert sorted(cell.collect_params().keys()) == expected_params
+    assert outputs.list_outputs() == expected_outputs, outputs.list_outputs()
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
+    assert outs == [(10, nproj), (10, nproj), (10, nproj)]
+
 
 @with_seed()
 def test_vardrop():
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 3d6f9d0..182e70c 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -872,6 +872,7 @@ def test_sparse_nd_check_format():
     a = mx.nd.sparse.row_sparse_array((data_list, indices_list), shape=shape)
     assertRaises(mx.base.MXNetError, a.check_format)
 
+@with_seed()
 def test_sparse_nd_norm():
     def check_sparse_nd_norm(stype, shape, density):
         data, _ = rand_sparse_ndarray(shape, stype, density)
@@ -886,6 +887,23 @@ def test_sparse_nd_norm():
         for density in densities:
             check_sparse_nd_norm(stype, shape, density)
 
+@with_seed()
+def test_sparse_fc():
+    def check_sparse_fc(batch_size, dim_in, dim_out, stype):
+        data = rand_ndarray((batch_size, dim_in), stype, density=0.5)
+        weight = rand_ndarray((dim_out, dim_in), 'row_sparse', density=1)
+        bias = rand_ndarray((dim_out, 1), 'row_sparse', density=1)
+        out = mx.nd.sparse.FullyConnected(data, weight, num_hidden=dim_out, bias=bias)
+        data_dns = data.tostype('default')
+        weight_dns = weight.tostype('default')
+        out_dns = mx.nd.FullyConnected(data_dns, weight_dns, num_hidden=dim_out, bias=bias)
+        assert_almost_equal(out.asnumpy(), out_dns.asnumpy())
+
+    # test FC with row_sparse weight w/ density=1, dense data
+    check_sparse_fc(5, 10, 8, 'default')
+    # test FC with row_sparse weight w/ density=1, csr data (fallback)
+    check_sparse_fc(5, 10, 8, 'csr')
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.