Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/14 18:18:55 UTC

[GitHub] piiswrong closed pull request #10861: split trainer.step into allreduce and update

URL: https://github.com/apache/incubator-mxnet/pull/10861
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff once a foreign pull request is merged, it is reproduced below
for the sake of provenance:

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 5ae0e46b7dc..da67fc0b1d9 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -49,6 +49,9 @@ class Trainer(object):
         on the type of compression being used. For example, 2bit compression requires a threshold.
         Arguments would then be {'type':'2bit', 'threshold':0.5}
         See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
+    update_on_kvstore : bool, default None
+        Whether to perform parameter updates on the kvstore. If None, the trainer will choose
+        the more suitable option depending on the type of kvstore.
 
     Properties
     ----------
@@ -57,7 +60,7 @@ class Trainer(object):
         optimizer, its learning rate can be accessed as optimizer.learning_rate.
     """
     def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
-                 compression_params=None):
+                 compression_params=None, update_on_kvstore=None):
         if isinstance(params, (dict, ParameterDict)):
             params = list(params.values())
         if not isinstance(params, (list, tuple)):
@@ -73,11 +76,12 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
             self._params.append(param)
         self._compression_params = compression_params
         optimizer_params = optimizer_params if optimizer_params else {}
-        self._scale = optimizer_params.get('rescale_grad', 1.0)
+        self._scale = float(optimizer_params.get('rescale_grad', 1.0))
         self._contexts = self._check_contexts()
         self._init_optimizer(optimizer, optimizer_params)
         self._kv_initialized = False
         self._kvstore = kvstore
+        self._update_on_kvstore = update_on_kvstore
 
     def _check_contexts(self):
         contexts = None
@@ -109,6 +113,8 @@ def _init_kvstore(self):
         arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
         kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
                                                      arg_arrays)
+        update_on_kvstore = self._update_on_kvstore if self._update_on_kvstore is not None \
+                            else update_on_kvstore
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
@@ -129,7 +135,6 @@ def _init_kvstore(self):
 
         self._kv_initialized = True
 
-
     @property
     def learning_rate(self):
         if not isinstance(self._optimizer, opt.Optimizer):
@@ -138,7 +143,6 @@ def learning_rate(self):
         else:
             return self._optimizer.learning_rate
 
-
     def set_learning_rate(self, lr):
         """Sets a new learning rate of the optimizer.
 
@@ -153,10 +157,73 @@ def set_learning_rate(self, lr):
         else:
             self._optimizer.set_learning_rate(lr)
 
-
     def step(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update. Should be called after
-        `autograd.compute_gradient` and outside of `record()` scope.
+        `autograd.backward()` and outside of `record()` scope.
+
+        For normal parameter updates, `step()` should be used, which internally calls
+        `allreduce_grads()` and then `update()`. However, if you need to access the reduced
+        gradients to perform certain transformations, such as gradient clipping, you may
+        want to call `allreduce_grads()` and `update()` manually instead.
+
+        Parameters
+        ----------
+        batch_size : int
+            Batch size of data processed. Gradient will be normalized by `1/batch_size`.
+            Set this to 1 if you normalized loss manually with `loss = mean(loss)`.
+        ignore_stale_grad : bool, optional, default=False
+            If true, ignores Parameters with stale gradient (gradient that has not
+            been updated by `backward` after the last step) and skips updating them.
+        """
+        if not self._kv_initialized:
+            self._init_kvstore()
+
+        self._optimizer.rescale_grad = self._scale / batch_size
+
+        self._allreduce_grads()
+        self._update(ignore_stale_grad)
+
+    def allreduce_grads(self):
+        """For each parameter, reduce the gradients from different contexts.
+
+        Should be called after `autograd.backward()`, outside of `record()` scope,
+        and before `trainer.update()`.
+
+        For normal parameter updates, `step()` should be used, which internally calls
+        `allreduce_grads()` and then `update()`. However, if you need to access the reduced
+        gradients to perform certain transformations, such as gradient clipping, you may
+        want to call `allreduce_grads()` and `update()` manually instead.
+        """
+        if not self._kv_initialized:
+            self._init_kvstore()
+        assert not (self._kvstore and self._update_on_kvstore), \
+                'allreduce_grads() when parameters are updated on kvstore ' \
+                'is not supported. Try setting `update_on_kvstore` ' \
+                'to False when creating trainer.'
+
+        self._allreduce_grads()
+
+    def _allreduce_grads(self):
+        if self._kvstore:
+            for i, param in enumerate(self._params):
+                if param.grad_req != 'null':
+                    self._kvstore.push(i, param.list_grad(), priority=-i)
+                    if not self._update_on_kvstore:
+                        self._kvstore.pull(i, param.list_grad(), priority=-i)
+
+    def update(self, batch_size, ignore_stale_grad=False):
+        """Makes one step of parameter update.
+
+        Should be called after `autograd.backward()`, outside of `record()` scope,
+        and after `trainer.allreduce_grads()`.
+
+        For normal parameter updates, `step()` should be used, which internally calls
+        `allreduce_grads()` and then `update()`. However, if you need to access the reduced
+        gradients to perform certain transformations, such as gradient clipping, you may
+        want to call `allreduce_grads()` and `update()` manually instead.
 
         Parameters
         ----------
@@ -169,12 +236,19 @@ def step(self, batch_size, ignore_stale_grad=False):
         """
         if not self._kv_initialized:
             self._init_kvstore()
+        assert not (self._kvstore and self._update_on_kvstore), \
+                'update() when parameters are updated on kvstore ' \
+                'is not supported. Try setting `update_on_kvstore` ' \
+                'to False when creating trainer.'
 
         self._optimizer.rescale_grad = self._scale / batch_size
+        self._update(ignore_stale_grad)
 
+    def _update(self, ignore_stale_grad=False):
         for i, param in enumerate(self._params):
             if param.grad_req == 'null':
                 continue
+
             if not ignore_stale_grad:
                 for data in param.list_data():
                     if not data._fresh_grad:
@@ -187,13 +261,9 @@ def step(self, batch_size, ignore_stale_grad=False):
                             "warning and skip updating of Parameters with stale gradient" \
                             %(param.name, str(data.context)))
 
-            if self._kvstore:
-                self._kvstore.push(i, param.list_grad(), priority=-i)
-                if self._update_on_kvstore:
-                    self._kvstore.pull(i, param.list_data(), priority=-i)
-                    continue
-                else:
-                    self._kvstore.pull(i, param.list_grad(), priority=-i)
+            if self._kvstore and self._update_on_kvstore:
+                self._kvstore.pull(i, param.list_data(), priority=-i)
+                continue
 
             for upd, arr, grad in zip(self._updaters, param.list_data(), param.list_grad()):
                 if not ignore_stale_grad or arr._fresh_grad:
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 350f8856436..b054aa6555f 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -21,7 +21,7 @@
 from mxnet.test_utils import assert_almost_equal
 from common import setup_module, with_seed
 import numpy as np
-from nose.tools import raises
+from nose.tools import raises, assert_raises
 from copy import deepcopy
 import warnings
 import json
@@ -520,6 +520,23 @@ def dict_equ(a, b):
         for updater in trainer._updaters:
             dict_equ(updater.states, states)
         assert trainer._optimizer == trainer._updaters[0].optimizer
+    assert_raises(AssertionError, trainer.update, 1)
+    assert_raises(AssertionError, trainer.allreduce_grads)
+
+    x = gluon.Parameter('x', shape=(10,))
+    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    trainer2 = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5},
+                             update_on_kvstore=False)
+    with mx.autograd.record():
+        for i, w in enumerate(x.list_data()):
+            y = i*w
+            y.backward()
+    assert (x.grad(mx.cpu(0)).asnumpy() != x.grad(mx.cpu(1)).asnumpy()).all()
+    trainer2.allreduce_grads()
+    assert (x.grad(mx.cpu(0)).asnumpy() == x.grad(mx.cpu(1)).asnumpy()).all()
+    trainer2.update(1)
+
+    assert (x.data(mx.cpu(1)).asnumpy() == -1).all(), x.data(mx.cpu(1)).asnumpy()
 
 
 @with_seed()
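
Editor's note: the docstrings above describe the intended workflow, namely calling
allreduce_grads() to aggregate gradients across contexts, transforming the reduced
gradients if desired (for example, gradient clipping), and then calling update().
The following is a minimal, hedged sketch of that workflow. The network, loss,
synthetic data, and clipping range are illustrative placeholders and are not part
of this PR; only Trainer(update_on_kvstore=False), allreduce_grads(), and update()
come from the diff itself.

    import mxnet as mx
    from mxnet import autograd, gluon

    ctx = [mx.cpu(0), mx.cpu(1)]
    net = gluon.nn.Dense(1)
    net.initialize(ctx=ctx)
    loss_fn = gluon.loss.L2Loss()

    # Synthetic data, only so the sketch is self-contained.
    X = mx.nd.random.uniform(shape=(8, 4))
    Y = mx.nd.random.uniform(shape=(8, 1))
    train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(X, Y), batch_size=4)

    # update_on_kvstore=False is required for the manual allreduce/update split.
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1},
                            update_on_kvstore=False)

    for data, label in train_data:
        data_parts = gluon.utils.split_and_load(data, ctx)
        label_parts = gluon.utils.split_and_load(label, ctx)
        with autograd.record():
            losses = [loss_fn(net(x), y) for x, y in zip(data_parts, label_parts)]
        for l in losses:
            l.backward()

        trainer.allreduce_grads()          # reduce gradients across contexts
        # Transform the reduced gradients in place, e.g. clip them.
        for param in net.collect_params().values():
            if param.grad_req != 'null':
                for grad in param.list_grad():
                    grad[:] = mx.nd.clip(grad, -1.0, 1.0)
        trainer.update(data.shape[0])      # apply the optimizer update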


 

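The expected values in the new unit test can be traced by hand: the gradients of
y = i*w are 0 on cpu(0) and 1 on cpu(1); allreduce_grads() pushes both to the
kvstore, which aggregates by summation, and pulls the sum back, so both contexts
end up holding a gradient of 1; on the first SGD step the momentum state is still
zero, so the zero-initialized weights move by -lr * grad = -1, which is what the
final assertion checks. A small NumPy-only sketch of that arithmetic (not part of
the PR):

    import numpy as np

    grad_cpu0 = np.zeros(10)            # gradient of y = 0 * w
    grad_cpu1 = np.ones(10)             # gradient of y = 1 * w

    # allreduce_grads(): push sums across devices, pull broadcasts the sum back.
    reduced = grad_cpu0 + grad_cpu1     # == 1 everywhere

    # update(1): rescale_grad = 1 / batch_size = 1, and on the first step the
    # momentum term is zero, so SGD reduces to weight -= lr * grad.
    lr = 1.0
    weight = np.zeros(10) - lr * reduced
    assert (weight == -1).all()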
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services