You are viewing a plain-text version of this content. The canonical link for it ("here" in the original page) was not preserved in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/11 17:43:41 UTC

[GitHub] piiswrong closed pull request #11197: Gluon sparse block and sparse embedding

piiswrong closed pull request #11197: Gluon sparse block and sparse embedding
URL: https://github.com/apache/incubator-mxnet/pull/11197
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it once the pull request is merged:

diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md
index bc3089fa878..877a294d9a1 100644
--- a/docs/api/python/gluon/contrib.md
+++ b/docs/api/python/gluon/contrib.md
@@ -35,6 +35,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
     Concurrent
     HybridConcurrent
     Identity
+    SparseEmbedding
 ```
 
 ### Recurrent neural network
@@ -55,6 +56,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
     Conv1DGRUCell
     Conv2DGRUCell
     Conv3DGRUCell
+    LSTMPCell
 ```
 
 ### Data
diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index eccdf18c1bb..1edef1476ee 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,10 +18,10 @@
 # coding: utf-8
 # pylint: disable= arguments-differ
 """Custom neural network layers in model_zoo."""
-__all__ = ['Concurrent', 'HybridConcurrent', 'Identity']
+__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding']
 
 from .... import nd
-from ...block import HybridBlock
+from ...block import HybridBlock, Block
 from ...nn import Sequential, HybridSequential
 
 class Concurrent(Sequential):
@@ -110,3 +110,44 @@ def __init__(self, prefix=None, params=None):
 
     def hybrid_forward(self, F, x):
         return x
+
+class SparseEmbedding(Block):
+    r"""Turns non-negative integers (indexes/tokens) into dense vectors
+    of fixed size. eg. [4, 20] -> [[0.25, 0.1], [0.6, -0.2]]
+
+    This SparseBlock is designed for distributed training with extremely large
+    input dimension. Both weight and gradient w.r.t. weight are `RowSparseNDArray`.
+
+    Parameters
+    ----------
+    input_dim : int
+        Size of the vocabulary, i.e. maximum integer index + 1.
+    output_dim : int
+        Dimension of the dense embedding.
+    dtype : str or np.dtype, default 'float32'
+        Data type of output embeddings.
+    weight_initializer : Initializer
+        Initializer for the `embeddings` matrix.
+
+    Inputs:
+        - **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`.
+    Output:
+        - **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`.
+    """
+    def __init__(self, input_dim, output_dim, dtype='float32',
+                 weight_initializer=None, **kwargs):
+        super(SparseEmbedding, self).__init__(**kwargs)
+        self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
+                        'dtype': dtype, 'sparse_grad': True}
+        self.weight = self.params.get('weight', shape=(input_dim, output_dim),
+                                      init=weight_initializer, dtype=dtype,
+                                      grad_stype='row_sparse', stype='row_sparse')
+
+    def forward(self, x):
+        weight = self.weight.row_sparse_data(x)
+        return nd.Embedding(x, weight, name='fwd', **self._kwargs)
+
+    def __repr__(self):
+        s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
+        return s.format(block_name=self.__class__.__name__,
+                        **self._kwargs)
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 729ec8407f2..264ff1f5e53 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -19,7 +19,7 @@
 import mxnet as mx
 from mxnet.gluon import contrib
 from mxnet.gluon import nn
-from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity
+from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
 from mxnet.test_utils import almost_equal
 from common import setup_module, with_seed
 import numpy as np
@@ -185,13 +185,25 @@ def test_concurrent():
     x.wait_to_read()
     x2.wait_to_read()
 
-
+@with_seed()
 def test_identity():
     model = Identity()
     x = mx.nd.random.uniform(shape=(128, 33, 64))
     mx.test_utils.assert_almost_equal(model(x).asnumpy(),
                                       x.asnumpy())
 
+@with_seed()
+def test_sparse_embedding():
+    layer = SparseEmbedding(10, 100)
+    layer.initialize()
+    trainer = mx.gluon.Trainer(layer.collect_params(), 'sgd')
+    x = mx.nd.array([3,4,2,0,1])
+    with mx.autograd.record():
+        y = layer(x)
+        y.backward()
+    assert (layer.weight.grad().asnumpy()[:5] == 1).all()
+    assert (layer.weight.grad().asnumpy()[5:] == 0).all()
+
 def test_datasets():
     wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')
     wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services