You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/08/24 17:02:59 UTC

[incubator-mxnet] branch v1.x updated: Get rid of monkey patching in LossScaler overflow handling (#18959) (#18973)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new dfefe87  Get rid of monkey patching in LossScaler overflow handling (#18959) (#18973)
dfefe87 is described below

commit dfefe879997a1208f1f01e58635c61e86eed1d6a
Author: mk-61 <56...@users.noreply.github.com>
AuthorDate: Mon Aug 24 10:00:52 2020 -0700

    Get rid of monkey patching in LossScaler overflow handling (#18959) (#18973)
    
    Co-authored-by: Vladimir Cherepanov <vc...@nvidia.com>
---
 python/mxnet/contrib/amp/amp.py         | 16 --------------
 python/mxnet/contrib/amp/loss_scaler.py | 39 +++++++++++++--------------------
 python/mxnet/gluon/trainer.py           |  5 +++++
 3 files changed, 20 insertions(+), 40 deletions(-)

diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py
index 688d73e..fa30855 100755
--- a/python/mxnet/contrib/amp/amp.py
+++ b/python/mxnet/contrib/amp/amp.py
@@ -23,7 +23,6 @@ __all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 'convert_model',
            'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
            'convert_symbol']
 
-from types import MethodType
 from array import array
 import ctypes
 import logging
@@ -341,21 +340,6 @@ def init_trainer(optimizer_or_trainer):
     if isinstance(optimizer_or_trainer, trainer.Trainer):
         optimizer_or_trainer._amp_loss_scaler = loss_scaler
         optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
-        skip_update = optimizer_or_trainer._amp_loss_scaler.wait_and_update
-        optimizer_or_trainer._optimizer.old_update_multi_precision = \
-                optimizer_or_trainer._optimizer.update_multi_precision
-        def new_update_multi_precision(self, index, weight, grad, state):
-            if not skip_update():
-                self.old_update_multi_precision(index, weight, grad, state)
-        optimizer_or_trainer._optimizer.update_multi_precision = \
-            MethodType(new_update_multi_precision, optimizer_or_trainer._optimizer)
-        launch_check_overflow = optimizer_or_trainer._amp_loss_scaler.launch_check_overflow
-        optimizer_or_trainer._old_update = optimizer_or_trainer._update
-        def new_update(self, ignore_stale_grad=False):
-            launch_check_overflow(self._params)
-            self._old_update(ignore_stale_grad)
-        optimizer_or_trainer._update = MethodType(new_update, optimizer_or_trainer)
-
     elif isinstance(optimizer_or_trainer, opt.Optimizer):
         # TODO(ptredak): make it work with the optimizer
         raise TypeError("AMP is currently only compatible with Gluon Trainer")
diff --git a/python/mxnet/contrib/amp/loss_scaler.py b/python/mxnet/contrib/amp/loss_scaler.py
index a2600bc..3a177ce 100755
--- a/python/mxnet/contrib/amp/loss_scaler.py
+++ b/python/mxnet/contrib/amp/loss_scaler.py
@@ -37,16 +37,13 @@ class LossScaler(object):
         self._max_loss_scale = 2.**24
         self._scale_seq_len = 2000
         self._unskipped = 0
-        self._has_overflow = False
 
     @property
     def loss_scale(self):
         return self._loss_scale
 
-    def launch_check_overflow(self, params):
-        """Launch overflow checking for gradients."""
-        self._wait_for_outputs = True
-        self._has_overflow = False
+    def has_overflow(self, params):
+        """Check gradients for overflow."""
         with ag.pause():
             chunk_size = 200
             valid_params = [p._grad[0] for p in params if p._grad is not None]
@@ -56,22 +53,16 @@ class LossScaler(object):
                 multi_all_finite(*valid_params[idx:idx+chunk_size],
                                  num_arrays=len(valid_params[idx:idx+chunk_size]),
                                  init_output=False, out=gpu_output)
-            self.output = gpu_output
-
-    def wait_and_update(self):
-        """Wait for the results of overflow checking and update the loss scale."""
-        if self._wait_for_outputs:
-            self._has_overflow = not bool(self.output.asnumpy())
-            self._loss_scale = self._next_loss_scale
-            if self._has_overflow:
-                self._next_loss_scale = self._loss_scale / 2.
-                self._unskipped = 0
-                logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
-            else:
-                self._unskipped += 1
-            if self._unskipped == self._scale_seq_len:
-                self._unskipped = 0
-                self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
-                logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
-            self._wait_for_outputs = False
-        return self._has_overflow
+        has_overflow = not bool(gpu_output.asnumpy())
+        self._loss_scale = self._next_loss_scale
+        if has_overflow:
+            self._next_loss_scale = self._loss_scale / 2.
+            self._unskipped = 0
+            logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
+        else:
+            self._unskipped += 1
+        if self._unskipped == self._scale_seq_len:
+            self._unskipped = 0
+            self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
+            logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
+        return has_overflow
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index fed3c44..dd8551d 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -428,6 +428,11 @@ class Trainer(object):
         self._update(ignore_stale_grad)
 
     def _update(self, ignore_stale_grad=False):
+        loss_scaler = getattr(self, '_amp_loss_scaler', None)
+        if loss_scaler is not None:
+            if loss_scaler.has_overflow(self._params):
+                return  # skip on overflow
+
         updates = [[] for _ in self._updaters]
 
         for i, param in enumerate(self._params):