You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) is not preserved in this version.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/16 00:10:01 UTC

[GitHub] piiswrong closed pull request #10924: [Sparse-Gluon] embedding with sparse grad

piiswrong closed pull request #10924: [Sparse-Gluon] embedding with sparse grad
URL: https://github.com/apache/incubator-mxnet/pull/10924
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index d86c3e6ce4f..abde51b433a 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -381,7 +381,8 @@ class Embedding(HybridBlock):
         Data type of output embeddings.
     weight_initializer : Initializer
         Initializer for the `embeddings` matrix.
-
+    sparse_grad: bool
+        If True, gradient w.r.t. weight will be a 'row_sparse' NDArray.
 
     Inputs:
         - **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`.
@@ -390,13 +391,14 @@ class Embedding(HybridBlock):
         - **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`.
     """
     def __init__(self, input_dim, output_dim, dtype='float32',
-                 weight_initializer=None, **kwargs):
+                 weight_initializer=None, sparse_grad=False, **kwargs):
         super(Embedding, self).__init__(**kwargs)
+        grad_stype = 'row_sparse' if sparse_grad else 'default'
         self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
-                        'dtype': dtype}
+                        'dtype': dtype, 'sparse_grad': sparse_grad}
         self.weight = self.params.get('weight', shape=(input_dim, output_dim),
                                       init=weight_initializer, dtype=dtype,
-                                      allow_deferred_init=True)
+                                      allow_deferred_init=True, grad_stype=grad_stype)
 
     def hybrid_forward(self, F, x, weight):
         return F.Embedding(x, weight, name='fwd', **self._kwargs)
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 320b376fe0b..c7cbcccc95e 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -81,6 +81,8 @@ class Parameter(object):
         Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
     init : Initializer, default None
         Initializer of this parameter. Will use the global initializer by default.
+    grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
+        The storage type of the parameter's gradient.
 
     Attributes
     ----------
@@ -97,7 +99,7 @@ class Parameter(object):
     """
     def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
                  lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
-                 differentiable=True):
+                 differentiable=True, grad_stype='default'):
         self._var = None
         self._data = None
         self._grad = None
@@ -114,6 +116,11 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
         self.wd_mult = wd_mult
         self.grad_req = grad_req
         self.init = init
+        assert grad_stype in ['default', 'row_sparse', 'csr'], \
+            "grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
+            " but got '%s'" % (name, grad_stype)
+        self._grad_stype = grad_stype
+
 
     def __repr__(self):
         s = 'Parameter {name} (shape={shape}, dtype={dtype})'
@@ -261,7 +268,9 @@ def _init_grad(self):
             self._grad = None
             return
 
-        self._grad = [ndarray.zeros_like(i) for i in self._data]
+        self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
+                                    stype=self._grad_stype) for i in self._data]
+
         autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
 
     def _reduce(self):
@@ -431,7 +440,7 @@ def zero_grad(self):
         if self._grad is None:
             return
         for i in self._grad:
-            i[:] = 0
+            ndarray.zeros_like(i, out=i)
 
     def var(self):
         """Returns a symbol representing this parameter."""
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index da67fc0b1d9..39c4a1fd610 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -110,7 +110,17 @@ def _init_optimizer(self, optimizer, optimizer_params):
                             for _ in self._contexts]
 
     def _init_kvstore(self):
-        arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
+        arg_arrays = {}
+        contains_sparse = False
+        for param in self._params:
+            arg_arrays[param.name] = param.data(self._contexts[0])
+            if param._grad_stype != 'default':
+                contains_sparse = True
+                # update_on_kvstore is set to False by the user
+                if self._update_on_kvstore is False:
+                    raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
+                                       "gradients and/or sparse weights are present for "
+                                       "Parameter %s." % param.name)
         kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
                                                      arg_arrays)
         update_on_kvstore = self._update_on_kvstore if self._update_on_kvstore is not None \
@@ -118,8 +128,12 @@ def _init_kvstore(self):
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
-            if 'dist' in kvstore.type:
-                update_on_kvstore = False
+            # kv.pull(row_sparse_grad) is not supported
+            if contains_sparse:
+                update_on_kvstore = True
+            else:
+                if 'dist' in kvstore.type:
+                    update_on_kvstore = False
             if update_on_kvstore:
                 kvstore.set_optimizer(self._optimizer)
             # optimizer preferably needs to be set before init for multiprecision
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 2ac6c11a167..38ecf121dfe 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -276,7 +276,7 @@ class KVStoreLocal : public KVStore {
       // invalid, print warning messages once
       if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
         LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
-                     "This call has been ignored. Please make sure to use"
+                     "This call has been ignored. Please make sure to use "
                      "kv.row_sparse_pull() or module.prepare() with row_ids.";
         this->warnings_printed_.insert(key);
       }
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 0c74cac2dca..6c966055637 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -383,8 +383,8 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(outputs.size(), 1);
   auto stype = outputs[0].storage_type();
-  if (req[0] == kNullOp) return;
-  CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for FillComputeZerosEx";
+  // x + 0 == x
+  if (req[0] == kNullOp || req[0] == kAddTo) return;
   if (stype == kRowSparseStorage) {
     FillZerosRspImpl(s, outputs[0]);
   } else if (stype == kCSRStorage) {
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index b054aa6555f..946b1406e78 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -38,6 +38,21 @@ def test_parameter():
     assert p.data(mx.cpu(1)).context == mx.cpu(1)
     assert p.data(mx.cpu(0)).shape == (10, 10)
     assert p.var().name == 'weight'
+    assert p.grad(mx.cpu(0)).stype == 'default'
+
+    p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
+    assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
+
+@with_seed()
+def test_sparse_parameter():
+    p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
+    p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
+    assert len(p.list_data()) == 2
+    assert len(p.list_grad()) == 2
+    assert p.data(mx.cpu(1)).context == mx.cpu(1)
+    assert p.data(mx.cpu(0)).shape == (10, 10)
+    assert p.var().name == 'weight'
+    assert p.grad(mx.cpu(0)).stype == 'row_sparse'
 
     p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
     assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
@@ -676,15 +691,17 @@ def test_global_norm_clip():
 
 @with_seed()
 def test_embedding():
-    layer = gluon.nn.Embedding(10, 100)
-    layer.initialize()
-    x = mx.nd.array([3,4,2,0,1])
-    with mx.autograd.record():
-        y = layer(x)
-        y.backward()
-    assert (layer.weight.grad()[:5] == 1).asnumpy().all()
-    assert (layer.weight.grad()[5:] == 0).asnumpy().all()
-
+    def check_embedding(sparse_grad):
+        layer = gluon.nn.Embedding(10, 100, sparse_grad=sparse_grad)
+        layer.initialize()
+        x = mx.nd.array([3,4,2,0,1])
+        with mx.autograd.record():
+            y = layer(x)
+            y.backward()
+        assert (layer.weight.grad().asnumpy()[:5] == 1).all()
+        assert (layer.weight.grad().asnumpy()[5:] == 0).all()
+    check_embedding(True)
+    check_embedding(False)
 
 @with_seed()
 def test_export():
@@ -977,6 +994,7 @@ def test_req():
     assert_almost_equal(grad * 2, grad_double)
 
 
+@with_seed()
 def test_save_load():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
     net.save_params('test.params')
@@ -987,6 +1005,7 @@ def test_save_load():
     net.load_params('test.params')
 
 
+@with_seed()
 def test_hybrid_multi_context():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
     net.initialize(ctx=[mx.cpu(0), mx.cpu(1)])
@@ -994,6 +1013,19 @@ def test_hybrid_multi_context():
     net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
 
 
+@with_seed()
+def test_zero_grad():
+    data = mx.nd.random.uniform(shape=(3,3))
+    net = nn.Embedding(3, 4, sparse_grad=True, prefix='test_zero_grad_')
+    net.initialize()
+    with mx.autograd.record():
+        l = net(data)
+        l.backward()
+    net.collect_params().zero_grad()
+    grad = net.collect_params()['test_zero_grad_weight'].grad()
+    assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services