You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/05/08 16:29:06 UTC

[GitHub] [incubator-mxnet] eric-haibin-lin commented on a change in pull request #14885: [Fit-API] Adress PR comments

eric-haibin-lin commented on a change in pull request #14885: [Fit-API] Adress PR comments
URL: https://github.com/apache/incubator-mxnet/pull/14885#discussion_r282142125
 
 

 ##########
 File path: python/mxnet/gluon/contrib/estimator/event_handler.py
 ##########
 @@ -263,108 +268,171 @@ def epoch_end(self, estimator, *args, **kwargs):
 
 
 class CheckpointHandler(BatchEnd, EpochEnd):
-    """Save the model after every epoch.
+    """Save the model after user define period
 
-    :py:class:`CheckpointHandler` save the network parameters every epoch
+    :py:class:`CheckpointHandler` saves the network architecture after first batch if the model
+    can be hybridized(), saves model parameters and trainer states after user defined period,
+    default saves every epoch.
 
     Parameters
     ----------
-    filepath : str
-        file name to save the parameters, it can contain directories,
-        for example: ./saved_model/resnet.params
+    model_dir : str
+        file directory to save all the model related files including model architecture,
+        model parameters, and trainer states.
+    model_prefix : str default 'model'
+        prefix to add for all checkpoint file names
     monitor: EvalMetric
-        the metrics to monitor
+        the metrics to monitor and determine if model has improved
     verbose: int, default 0
         verbosity mode
-    save_best_only: bool
-        if True, only save the parameters if monitored value improved
+    save_best: bool
+        if True, save the model parameters and trainer states with the best monitored value
     mode: str, default 'auto'
         one of {auto, min, max}, if `save_best_only=True`, the comparison to make
-        and determine if the monitored value has improved
-    period: int, default 1
-        intervals between saving the network
+        and determine if the monitored value has improved. if 'auto' mode, checkpoint
+        handler will try to use min or max based on the monitored metric name
+    epoch_period: int, default 1
+        epoch intervals between saving the network
+    batch_period: int, default None
+        batch intervals between saving the network,
+        by default don't save any checkpoint based on number of batches
+    max_checkpoints : int, default 5
+        maximum number of checkpoint files to keep in the model_dir, older checkpoints
+        will be removed
     """
 
     def __init__(self,
-                 filepath,
+                 model_dir,
+                 model_prefix='model',
                  monitor=None,
                  verbose=0,
-                 save_best_only=False,
+                 save_best=False,
                  mode='auto',
                  epoch_period=1,
-                 batch_period=None):
+                 batch_period=None,
+                 max_checkpoints=5):
         self.monitor = monitor
         self.verbose = verbose
-        self.filepath = filepath
-        self.save_best_only = save_best_only
-        if self.save_best_only and not isinstance(self.monitor, EvalMetric):
+        if not os.path.exists(model_dir):
+            os.makedirs(model_dir)
+        self.model_dir = model_dir
+        self.model_prefix = model_prefix
+        self.save_best = save_best
+        if self.save_best and not isinstance(self.monitor, EvalMetric):
             raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
                              "You can get these objects using estimator.prepare_loss_and_metric()")
         self.epoch_period = epoch_period
         self.batch_period = batch_period
         self.num_batches = 0
         self.num_epochs = 0
+        self.max_checkpoints = max_checkpoints
+        self.saved_checkpoints = []
         self.logger = logging.getLogger(__name__)
-
-        if mode not in ['auto', 'min', 'max']:
-            warnings.warn('ModelCheckpoint mode %s is unknown, '
-                          'fallback to auto mode.' % (mode),
-                          RuntimeWarning)
-            mode = 'auto'
-
-        if mode == 'min':
-            self.monitor_op = np.less
-            self.best = np.Inf
-        elif mode == 'max':
-            self.monitor_op = np.greater
-            self.best = -np.Inf
-        else:
-            # use greater for accuracy and less otherwise
-            if 'acc' in self.monitor.get()[0].lower():
+        if self.save_best:
+            if mode not in ['auto', 'min', 'max']:
+                warnings.warn('ModelCheckpoint mode %s is unknown, '
+                              'fallback to auto mode. CheckpointHandler will use'
+                              'max mode for f1 and accuracy metric comparison and '
+                              'use min mode other wise' % (mode),
+                              RuntimeWarning)
+                mode = 'auto'
+
+            if mode == 'min':
+                self.monitor_op = np.less
+                self.best = np.Inf
+            elif mode == 'max':
                 self.monitor_op = np.greater
                 self.best = -np.Inf
             else:
-                self.monitor_op = np.less
-                self.best = np.Inf
+                # use greater for accuracy and f1 and less otherwise
+                if 'acc' or 'f1' in self.monitor.get()[0].lower():
+                    self.logger.info("`greater` operator will be used to determine "
+                                     "if %s has improved, please use `min` for mode "
+                                     "if you want otherwise", self.monitor.get()[0])
+                    self.monitor_op = np.greater
+                    self.best = -np.Inf
+                else:
+                    self.logger.info("`less` operator will be used to determine "
+                                     "if %s has improved, please use `max` for mode "
+                                     "if you want otherwise", self.monitor.get()[0])
+                    self.monitor_op = np.less
+                    self.best = np.Inf
 
     def batch_end(self, estimator, *args, **kwargs):
-        self._save_checkpoint(estimator.net, "Batch", self.num_batches)
+        # only save symbol once after first batch
+        if self.num_batches == 0:
+            self._save_symbol(estimator)
+        if self.batch_period:
+            self._save_checkpoint(estimator, "Batch", self.batch_period, self.num_batches)
         self.num_batches += 1
 
     def epoch_end(self, estimator, *args, **kwargs):
-        self._save_checkpoint(estimator.net, "Epoch", self.num_epochs)
+        if self.epoch_period:
+            self._save_checkpoint(estimator, "Epoch", self.epoch_period, self.num_epochs)
         self.num_epochs += 1
 
-    def _save_checkpoint(self, net, period_name, period_value):
+    def _save_checkpoint(self, estimator, period_name, period_value, num_of_periods):
+        # period name can be batch or epoch
+        # period value determine how often a checkpoint is saved
+        # num_of_periods records the number of batch or epoch
         # add extension for weights
-        if '.params' not in self.filepath:
-            self.filepath += '.params'
-        if self.num_epochs % self.epoch_period == 0:
-            if self.save_best_only:
+        if num_of_periods % period_value == 0:
+            if self.verbose > 0:
+                self.logger.info('[%s %d] saving model to %s', period_name, num_of_periods, self.model_dir)
+            prefix = "%s-%s%d" % (self.model_prefix, period_name.lower(), num_of_periods)
+            self._save_params_and_trainer(estimator, prefix)
+
+            if self.save_best:
                 monitor_name, monitor_value = self.monitor.get()
                 # check if monitor exists in train stats
                 if np.isnan(monitor_value):
-                    warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'
-                                                 'as monitor, you can use estimator.prepare_loss_and_metrics to'
+                    warnings.warn(RuntimeWarning('SKipping save best because %s is not updated, make sure you '
 
 Review comment:
   Typo: `SKipping` should be `Skipping`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services