Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/08/24 17:02:59 UTC
[incubator-mxnet] branch v1.x updated: Get rid of monkey patching in LossScaler overflow handling (#18959) (#18973)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new dfefe87 Get rid of monkey patching in LossScaler overflow handling (#18959) (#18973)
dfefe87 is described below
commit dfefe879997a1208f1f01e58635c61e86eed1d6a
Author: mk-61 <56...@users.noreply.github.com>
AuthorDate: Mon Aug 24 10:00:52 2020 -0700
Get rid of monkey patching in LossScaler overflow handling (#18959) (#18973)
Co-authored-by: Vladimir Cherepanov <vc...@nvidia.com>
---
python/mxnet/contrib/amp/amp.py         | 16 --------------
python/mxnet/contrib/amp/loss_scaler.py | 39 +++++++++++++--------------------
python/mxnet/gluon/trainer.py           |  5 +++++
3 files changed, 20 insertions(+), 40 deletions(-)
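The pattern being removed here is classic monkey patching: the old init_trainer used types.MethodType to swap new bound methods onto live Optimizer and Trainer instances at runtime. A minimal, self-contained sketch of that pattern (the class and method names below are illustrative, not MXNet's real API):

    from types import MethodType

    class Optimizer:
        def update(self, index, weight, grad):
            print("real update for index", index)

    opt = Optimizer()

    # Keep a handle to the original bound method, then shadow it on the instance.
    opt.old_update = opt.update

    def new_update(self, index, weight, grad):
        # Behavior injected at runtime; the original stays reachable.
        print("patched: checking for overflow first")
        self.old_update(index, weight, grad)

    # MethodType binds the plain function to this one instance only.
    opt.update = MethodType(new_update, opt)

    opt.update(0, None, None)

Patching of this kind is invisible to readers of the patched classes and breaks whenever the wrapped method's signature changes, which is why the diffs below replace it with an explicit hook.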
diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py
index 688d73e..fa30855 100755
--- a/python/mxnet/contrib/amp/amp.py
+++ b/python/mxnet/contrib/amp/amp.py
@@ -23,7 +23,6 @@ __all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 'convert_model',
            'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
            'convert_symbol']
 
-from types import MethodType
 from array import array
 import ctypes
 import logging
@@ -341,21 +340,6 @@ def init_trainer(optimizer_or_trainer):
     if isinstance(optimizer_or_trainer, trainer.Trainer):
         optimizer_or_trainer._amp_loss_scaler = loss_scaler
         optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
-        skip_update = optimizer_or_trainer._amp_loss_scaler.wait_and_update
-        optimizer_or_trainer._optimizer.old_update_multi_precision = \
-            optimizer_or_trainer._optimizer.update_multi_precision
-        def new_update_multi_precision(self, index, weight, grad, state):
-            if not skip_update():
-                self.old_update_multi_precision(index, weight, grad, state)
-        optimizer_or_trainer._optimizer.update_multi_precision = \
-            MethodType(new_update_multi_precision, optimizer_or_trainer._optimizer)
-        launch_check_overflow = optimizer_or_trainer._amp_loss_scaler.launch_check_overflow
-        optimizer_or_trainer._old_update = optimizer_or_trainer._update
-        def new_update(self, ignore_stale_grad=False):
-            launch_check_overflow(self._params)
-            self._old_update(ignore_stale_grad)
-        optimizer_or_trainer._update = MethodType(new_update, optimizer_or_trainer)
-
     elif isinstance(optimizer_or_trainer, opt.Optimizer):
         # TODO(ptredak): make it work with the optimizer
         raise TypeError("AMP is currently only compatible with Gluon Trainer")
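With the patching gone, init_trainer only attaches the LossScaler to the Trainer; the Trainer itself decides whether to skip an update. A simplified stand-in for the new control flow (toy reductions of the real Trainer and LossScaler, not MXNet's classes):

    class LossScaler:
        def has_overflow(self, params):
            # The real implementation inspects gradients; always "finite" here.
            return False

    class Trainer:
        def __init__(self, params):
            self._params = params

        def _update(self):
            # Optional-collaborator lookup, as in the trainer.py hunk below.
            scaler = getattr(self, '_amp_loss_scaler', None)
            if scaler is not None and scaler.has_overflow(self._params):
                return  # skip the optimizer step on overflow
            print("applying optimizer update")

    trainer = Trainer(params=[])
    trainer._amp_loss_scaler = LossScaler()  # what init_trainer now does
    trainer._update()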
diff --git a/python/mxnet/contrib/amp/loss_scaler.py b/python/mxnet/contrib/amp/loss_scaler.py
index a2600bc..3a177ce 100755
--- a/python/mxnet/contrib/amp/loss_scaler.py
+++ b/python/mxnet/contrib/amp/loss_scaler.py
@@ -37,16 +37,13 @@ class LossScaler(object):
         self._max_loss_scale = 2.**24
         self._scale_seq_len = 2000
         self._unskipped = 0
-        self._has_overflow = False
 
     @property
     def loss_scale(self):
         return self._loss_scale
 
-    def launch_check_overflow(self, params):
-        """Launch overflow checking for gradients."""
-        self._wait_for_outputs = True
-        self._has_overflow = False
+    def has_overflow(self, params):
+        """Check gradients for overflow."""
         with ag.pause():
             chunk_size = 200
             valid_params = [p._grad[0] for p in params if p._grad is not None]
@@ -56,22 +53,16 @@ class LossScaler(object):
                 multi_all_finite(*valid_params[idx:idx+chunk_size],
                                  num_arrays=len(valid_params[idx:idx+chunk_size]),
                                  init_output=False, out=gpu_output)
-        self.output = gpu_output
-
-    def wait_and_update(self):
-        """Wait for the results of overflow checking and update the loss scale."""
-        if self._wait_for_outputs:
-            self._has_overflow = not bool(self.output.asnumpy())
-            self._loss_scale = self._next_loss_scale
-            if self._has_overflow:
-                self._next_loss_scale = self._loss_scale / 2.
-                self._unskipped = 0
-                logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
-            else:
-                self._unskipped += 1
-                if self._unskipped == self._scale_seq_len:
-                    self._unskipped = 0
-                    self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
-                    logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
-            self._wait_for_outputs = False
-        return self._has_overflow
+        has_overflow = not bool(gpu_output.asnumpy())
+        self._loss_scale = self._next_loss_scale
+        if has_overflow:
+            self._next_loss_scale = self._loss_scale / 2.
+            self._unskipped = 0
+            logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
+        else:
+            self._unskipped += 1
+            if self._unskipped == self._scale_seq_len:
+                self._unskipped = 0
+                self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
+                logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
+        return has_overflow
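The rewritten has_overflow folds the old two-phase launch_check_overflow / wait_and_update protocol into a single synchronous call, so no state has to be parked on the instance between phases. The scaling policy itself is unchanged: halve the scale after an overflow, double it (capped at _max_loss_scale) after _scale_seq_len consecutive clean steps. A self-contained sketch of just that policy; the constants match the diff, while the starting scale of 2.**16 and the overflow flags are assumed inputs for illustration:

    class ScalePolicy:
        def __init__(self):
            self._loss_scale = 2.**16        # assumed initial scale
            self._next_loss_scale = self._loss_scale
            self._max_loss_scale = 2.**24
            self._scale_seq_len = 2000
            self._unskipped = 0

        def step(self, overflow):
            # A newly chosen scale always takes effect one step after the
            # decision, exactly as in the diff above.
            self._loss_scale = self._next_loss_scale
            if overflow:
                self._next_loss_scale = self._loss_scale / 2.
                self._unskipped = 0
            else:
                self._unskipped += 1
                if self._unskipped == self._scale_seq_len:
                    self._unskipped = 0
                    self._next_loss_scale = min(self._max_loss_scale,
                                                self._loss_scale * 2.)
            return self._loss_scale

    policy = ScalePolicy()
    print([policy.step(f) for f in (False, True, False)])
    # [65536.0, 65536.0, 32768.0] -- the halved scale lands one step later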
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index fed3c44..dd8551d 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -428,6 +428,11 @@ class Trainer(object):
         self._update(ignore_stale_grad)
 
     def _update(self, ignore_stale_grad=False):
+        loss_scaler = getattr(self, '_amp_loss_scaler', None)
+        if loss_scaler is not None:
+            if loss_scaler.has_overflow(self._params):
+                return  # skip on overflow
+
         updates = [[] for _ in self._updaters]
 
         for i, param in enumerate(self._params):
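For reference, a hedged sketch of how this new code path is exercised from user code through the MXNet 1.x contrib AMP API; the network and data are toy placeholders, and in practice AMP targets GPU contexts:

    import mxnet as mx
    from mxnet import autograd, gluon
    from mxnet.contrib import amp

    amp.init()  # enable mixed-precision casts before building the network

    net = gluon.nn.Dense(1)
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    amp.init_trainer(trainer)  # attaches the LossScaler, per the amp.py hunk above

    x = mx.nd.random.uniform(shape=(4, 2))
    y = mx.nd.random.uniform(shape=(4, 1))
    loss_fn = gluon.loss.L2Loss()

    with autograd.record():
        loss = loss_fn(net(x), y)
        with amp.scale_loss(loss, trainer) as scaled_loss:
            autograd.backward(scaled_loss)
    trainer.step(4)  # Trainer._update now skips the step itself on overflow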