You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/11 17:44:07 UTC

[incubator-mxnet] branch master updated: [WIP] Gluon sparse block and sparse embedding (#11197)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 715457d  [WIP] Gluon sparse block and sparse embedding (#11197)
715457d is described below

commit 715457d94ebf8935e34dd6bd445b3ba3950fe9d4
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Mon Jun 11 10:43:40 2018 -0700

    [WIP] Gluon sparse block and sparse embedding (#11197)
    
    * add sparse block
    
    * add sparse embedding
    
    * add doc
    
    * lint
    
    * remove sparseblock
---
 docs/api/python/gluon/contrib.md              |  2 ++
 python/mxnet/gluon/contrib/nn/basic_layers.py | 45 +++++++++++++++++++++++++--
 tests/python/unittest/test_gluon_contrib.py   | 16 ++++++++--
 3 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md
index bc3089f..877a294 100644
--- a/docs/api/python/gluon/contrib.md
+++ b/docs/api/python/gluon/contrib.md
@@ -35,6 +35,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
     Concurrent
     HybridConcurrent
     Identity
+    SparseEmbedding
 ```
 
 ### Recurrent neural network
@@ -55,6 +56,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
     Conv1DGRUCell
     Conv2DGRUCell
     Conv3DGRUCell
+    LSTMPCell
 ```
 
 ### Data
diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index eccdf18..1edef14 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,10 +18,10 @@
 # coding: utf-8
 # pylint: disable= arguments-differ
 """Custom neural network layers in model_zoo."""
-__all__ = ['Concurrent', 'HybridConcurrent', 'Identity']
+__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding']
 
 from .... import nd
-from ...block import HybridBlock
+from ...block import HybridBlock, Block
 from ...nn import Sequential, HybridSequential
 
 class Concurrent(Sequential):
@@ -110,3 +110,44 @@ class Identity(HybridBlock):
 
     def hybrid_forward(self, F, x):
         return x
+
+class SparseEmbedding(Block):
+    r"""Turns non-negative integers (indexes/tokens) into dense vectors
+    of fixed size, e.g. [4, 20] -> [[0.25, 0.1], [0.6, -0.2]]
+
+    This block is designed for distributed training with extremely large
+    input dimension. Both weight and gradient w.r.t. weight are `RowSparseNDArray`.
+
+    Parameters
+    ----------
+    input_dim : int
+        Size of the vocabulary, i.e. maximum integer index + 1.
+    output_dim : int
+        Dimension of the dense embedding.
+    dtype : str or np.dtype, default 'float32'
+        Data type of the weight and of the output embeddings.
+    weight_initializer : Initializer, default None
+        Initializer for the `weight` (embedding) matrix.
+
+    Inputs:
+        - **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`.
+    Output:
+        - **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`.
+    """
+    def __init__(self, input_dim, output_dim, dtype='float32',
+                 weight_initializer=None, **kwargs):
+        super(SparseEmbedding, self).__init__(**kwargs)
+        self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
+                        'dtype': dtype, 'sparse_grad': True}  # row_sparse weight gradient
+        self.weight = self.params.get('weight', shape=(input_dim, output_dim),
+                                      init=weight_initializer, dtype=dtype,
+                                      grad_stype='row_sparse', stype='row_sparse')
+
+    def forward(self, x):
+        weight = self.weight.row_sparse_data(x)  # presumably only rows indexed by x (confirm)
+        return nd.Embedding(x, weight, name='fwd', **self._kwargs)
+
+    def __repr__(self):
+        s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
+        return s.format(block_name=self.__class__.__name__,
+                        **self._kwargs)
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 729ec84..264ff1f 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -19,7 +19,7 @@ from __future__ import print_function
 import mxnet as mx
 from mxnet.gluon import contrib
 from mxnet.gluon import nn
-from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity
+from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
 from mxnet.test_utils import almost_equal
 from common import setup_module, with_seed
 import numpy as np
@@ -185,13 +185,25 @@ def test_concurrent():
     x.wait_to_read()
     x2.wait_to_read()
 
-
+@with_seed()  # NOTE(review): presumably seeds the RNG for the random input below
 def test_identity():
     model = Identity()
     x = mx.nd.random.uniform(shape=(128, 33, 64))
     mx.test_utils.assert_almost_equal(model(x).asnumpy(),
                                       x.asnumpy())
 
+@with_seed()
+def test_sparse_embedding():
+    layer = SparseEmbedding(10, 100)
+    layer.initialize()
+    trainer = mx.gluon.Trainer(layer.collect_params(), 'sgd')  # unused below; confirm needed
+    x = mx.nd.array([3,4,2,0,1])  # rows 0-4 are each looked up exactly once
+    with mx.autograd.record():
+        y = layer(x)
+        y.backward()
+    assert (layer.weight.grad().asnumpy()[:5] == 1).all()  # rows indexed by x
+    assert (layer.weight.grad().asnumpy()[5:] == 0).all()  # rows 5-9 never indexed
+
 def test_datasets():
     wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')
     wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.