You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/10 17:19:09 UTC

[incubator-mxnet] branch master updated: reduce a copy for rowsparse parameter.reduce (#12039)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f7dee0  reduce a copy for rowsparse parameter.reduce (#12039)
6f7dee0 is described below

commit 6f7dee02cb670e514b5ed77c98b78b0a2caf8272
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Fri Aug 10 10:18:58 2018 -0700

    reduce a copy for rowsparse parameter.reduce (#12039)
---
 python/mxnet/gluon/parameter.py             |  2 +-
 python/mxnet/gluon/trainer.py               | 11 +++++++--
 tests/python/unittest/test_gluon_trainer.py | 35 +++++++++++++++--------------
 3 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 0c6aae9..1f6b86c 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -319,7 +319,7 @@ class Parameter(object):
             # fetch all rows for 'row_sparse' param
             all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
             data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
-            self._trainer._row_sparse_pull(self, data, all_row_ids)
+            self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
         return data
 
     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 98a6878..028e660 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -235,14 +235,21 @@ class Trainer(object):
         else:
             self._optimizer.set_learning_rate(lr)
 
-    def _row_sparse_pull(self, parameter, out, row_id):
+    def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
+        """Internal method to invoke pull operations on KVStore. If `full_idx` is set to True,
+        `kv.pull` is used instead of `kv.row_sparse_pull`, unless the KVStore type contains 'dist'.
+        """
         # initialize kv and params if not already
         if not self._kv_initialized:
             self._init_kvstore()
         if self._params_to_init:
             self._init_params()
         idx = self._param2idx[parameter.name]
-        self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
+        if full_idx and 'dist' not in self._kvstore.type:
+            assert row_id.size == out.shape[0]
+            self._kvstore.pull(idx, out=out, priority=-idx, ignore_sparse=False)
+        else:
+            self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
 
     def step(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update. Should be called after
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 2a34400..72c01ac 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -114,6 +114,24 @@ def test_trainer_save_load():
     assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
 
 @with_seed()
+def test_trainer_sparse_save_load():
+    x = gluon.Parameter('x', shape=(10, 1), lr_mult=1.0, stype='row_sparse')
+    x.initialize(ctx=[mx.cpu(0)], init='zeros')
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
+    all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
+    with mx.autograd.record():
+        for w in x.list_row_sparse_data(all_rows):
+            y = w * 1
+            y.backward()
+    trainer.step(1)
+    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
+    trainer.save_states('test_trainer_sparse_save_load.states')
+    trainer.load_states('test_trainer_sparse_save_load.states')
+    x.lr_mult = 2.0
+    # check if parameter dict is correctly associated with optimizer after load_state
+    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
+
+@with_seed()
 def test_trainer_multi_layer_init():
     class Net(gluon.Block):
         def __init__(self, **kwargs):
@@ -159,23 +177,6 @@ def test_trainer_multi_layer_init():
     check_init([mx.cpu(1)])
 
 @with_seed()
-def test_trainer_save_load():
-    x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
-    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
-    trainer.save_states('test_trainer_save_load.states')
-    trainer.load_states('test_trainer_save_load.states')
-    x.lr_mult = 2.0
-    # check if parameter dict is correctly associated with optimizer after load_state
-    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
-
-@with_seed()
 def test_trainer_reset_kv():
     def check_trainer_reset_kv(kv):
         params = gluon.ParameterDict()