You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/05/02 23:02:22 UTC

[GitHub] [incubator-mxnet] pinaraws commented on a change in pull request #14629: [MXNET-1333] Estimator and Fit API

pinaraws commented on a change in pull request #14629: [MXNET-1333] Estimator and Fit API
URL: https://github.com/apache/incubator-mxnet/pull/14629#discussion_r280625345
 
 

 ##########
 File path: python/mxnet/gluon/contrib/estimator/event_handler.py
 ##########
 @@ -0,0 +1,459 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-argument
+"""Gluon EventHandlers for Estimators"""
+
+import logging
+import os
+import time
+import warnings
+
+import numpy as np
+
+from ....metric import EvalMetric, Loss
+
+
+class TrainBegin(object):
+    def train_begin(self, estimator, *args, **kwargs):
+        pass
+
+
+class TrainEnd(object):
+    def train_end(self, estimator, *args, **kwargs):
+        pass
+
+
+class EpochBegin(object):
+    def epoch_begin(self, estimator, *args, **kwargs):
+        pass
+
+
+class EpochEnd(object):
+    def epoch_end(self, estimator, *args, **kwargs):
+        return False
+
+
+class BatchBegin(object):
+    def batch_begin(self, estimator, *args, **kwargs):
+        pass
+
+
+class BatchEnd(object):
+    def batch_end(self, estimator, *args, **kwargs):
+        return False
+
+
+class MetricHandler(EpochBegin, BatchEnd):
+    """Metric Handler that update metric values at batch end
+
+    :py:class:`MetricHandler` takes model predictions and true labels
+    and update the metrics, it also update metric wrapper for loss with loss values
+    Validation loss and metrics will be handled by :py:class:`ValidationHandler`
+
+    Parameters
+    ----------
+    train_metrics : List of EvalMetrics
+        training metrics to be updated at batch end
+    """
+
+    def __init__(self, train_metrics):
+        self.train_metrics = train_metrics or []
+        # order to be called among all callbacks
+        # metrics need to be calculated before other callbacks can access them
+        self.priority = -np.Inf
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        for metric in self.train_metrics:
+            metric.reset()
+
+    def batch_end(self, estimator, *args, **kwargs):
+        pred = kwargs['pred']
+        label = kwargs['label']
+        loss = kwargs['loss']
+        for metric in self.train_metrics:
+            if isinstance(metric, Loss):
+                # metric wrapper for loss values
+                metric.update(0, loss)
+            else:
+                metric.update(label, pred)
+
+
+class ValidationHandler(BatchEnd, EpochEnd):
+    """"Validation Handler that evaluate model on validation dataset
+
+    :py:class:`ValidationHandler` takes validation dataset, an evaluation function,
+    metrics to be evaluated, and how often to run the validation. You can provide custom
+    evaluation function or use the one provided my :py:class:`Estimator`
+
+    Parameters
+    ----------
+    val_data : DataLoader
+        validation data set to run evaluation
+    eval_fn : function
+        a function defines how to run evaluation and
+        calculate loss and metrics
+    val_metrics : List of EvalMetrics
+        validation metrics to be updated
+    epoch_period : int, default 1
+        how often to run validation at epoch end, by default
+        validate every epoch
+    batch_period : int, default None
+        how often to run validation at batch end, by default
+        does not validate at batch end
+    """
+
+    def __init__(self,
+                 val_data,
+                 eval_fn,
+                 val_metrics=None,
+                 epoch_period=1,
+                 batch_period=None):
+        self.val_data = val_data
+        self.eval_fn = eval_fn
+        self.epoch_period = epoch_period
+        self.batch_period = batch_period
+        self.val_metrics = val_metrics
+        self.num_batches = 0
+        self.num_epochs = 0
+        # order to be called among all callbacks
+        # validation metrics need to be calculated before other callbacks can access them
+        self.priority = -np.Inf
+
+    def batch_end(self, estimator, *args, **kwargs):
+        if self.batch_period and self.num_batches % self.batch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+        self.num_batches += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        if self.num_epochs % self.epoch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+
+        self.num_epochs += 1
+
+
+class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
+    """Basic Logging Handler that applies to every Gluon estimator by default.
+
+    :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
+    and other useful information during training
+
+    Parameters
+    ----------
+    file_name : str
+        file name to save the logs
+    file_location : str
+        file location to save the logs
+    verbose : int, default LOG_VERBOSITY_PER_EPOCH
+        Limit the granularity of metrics displayed during training process
+        verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch
+        verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch
+    train_metrics : list of EvalMetrics
+        training metrics to be logged, logged at batch end, epoch end, train end
+    val_metrics : list of EvalMetrics
+        validation metrics to be logged, logged at epoch end, train end
+    """
+
+    LOG_VERBOSITY_PER_EPOCH = 1
+    LOG_VERBOSITY_PER_BATCH = 2
+
+    def __init__(self, file_name=None,
+                 file_location=None,
+                 verbose=LOG_VERBOSITY_PER_EPOCH,
+                 train_metrics=None,
+                 val_metrics=None):
+        super(LoggingHandler, self).__init__()
+        self.logger = logging.getLogger(__name__)
+        self.logger.setLevel(logging.INFO)
+        stream_handler = logging.StreamHandler()
+        self.logger.addHandler(stream_handler)
+        if verbose not in [self.LOG_VERBOSITY_PER_EPOCH, self.LOG_VERBOSITY_PER_BATCH]:
+            raise ValueError("verbose level must be either LOG_VERBOSITY_PER_EPOCH or "
+                             "LOG_VERBOSITY_PER_BATCH, received %s. "
+                             "E.g: LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)"
+                             % verbose)
+        self.verbose = verbose
+        # save logger to file only if file name or location is specified
+        if file_name or file_location:
+            file_name = file_name or 'estimator_log'
+            file_location = file_location or './'
+            file_handler = logging.FileHandler(os.path.join(file_location, file_name))
+            self.logger.addHandler(file_handler)
+        self.train_metrics = train_metrics or []
+        self.val_metrics = val_metrics or []
+        self.batch_index = 0
+        self.current_epoch = 0
+        self.processed_samples = 0
+        # logging handler need to be called at last to make sure all states are updated
+        # it will also shut down logging at train end
+        self.priority = np.Inf
+
+    def train_begin(self, estimator, *args, **kwargs):
+        self.train_start = time.time()
+        trainer = estimator.trainer
+        optimizer = trainer.optimizer.__class__.__name__
+        lr = trainer.learning_rate
+        self.logger.info("Training begin: using optimizer %s "
+                         "with current learning rate %.4f ",
 
 Review comment:
   Print the learning rate at every epoch, not only once at train begin.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services