You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/09/19 18:30:05 UTC

[incubator-mxnet] branch master updated: Track epoch metric separately (#12182)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ce6525a  Track epoch metric separately (#12182)
ce6525a is described below

commit ce6525ac9084a7df5789ca62c5228e3a3f55bc07
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Wed Sep 19 11:29:51 2018 -0700

    Track epoch metric separately (#12182)
---
 python/mxnet/callback.py           | 10 +++++++---
 python/mxnet/module/base_module.py |  9 ++++++++-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py
index 5c76280..e1c1714 100644
--- a/python/mxnet/callback.py
+++ b/python/mxnet/callback.py
@@ -165,9 +165,13 @@ class Speedometer(object):
                     name_value = param.eval_metric.get_name_value()
                     if self.auto_reset:
                         param.eval_metric.reset()
-                    msg = 'Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec'
-                    msg += '\t%s=%f'*len(name_value)
-                    logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
+                        msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec'
+                        msg += '\t%s=%f'*len(name_value)
+                        logging.info(msg, param.epoch, count-self.frequent, count, speed, *sum(name_value, ()))
+                    else:
+                        msg = 'Epoch[%d] Batch [0-%d]\tSpeed: %.2f samples/sec'
+                        msg += '\t%s=%f'*len(name_value)
+                        logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
                 else:
                     logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                                  param.epoch, count, speed)
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index 08ab8fa..c534261 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -22,6 +22,7 @@
 import time
 import logging
 import warnings
+import copy
 import numpy as np
 
 from .. import metric
@@ -507,6 +508,7 @@ class BaseModule(object):
             validation_metric = eval_metric
         if not isinstance(eval_metric, metric.EvalMetric):
             eval_metric = metric.create(eval_metric)
+        epoch_eval_metric = copy.deepcopy(eval_metric)
 
         ################################################################################
         # training loop
@@ -514,6 +516,7 @@ class BaseModule(object):
         for epoch in range(begin_epoch, num_epoch):
             tic = time.time()
             eval_metric.reset()
+            epoch_eval_metric.reset()
             nbatch = 0
             data_iter = iter(train_data)
             end_of_batch = False
@@ -529,8 +532,12 @@ class BaseModule(object):
                     self.update_metric(eval_metric,
                                        [db.label for db in data_batch],
                                        pre_sliced=True)
+                    self.update_metric(epoch_eval_metric,
+                                       [db.label for db in data_batch],
+                                       pre_sliced=True)
                 else:
                     self.update_metric(eval_metric, data_batch.label)
+                    self.update_metric(epoch_eval_metric, data_batch.label)
 
                 try:
                     # pre fetch next batch
@@ -543,7 +550,7 @@ class BaseModule(object):
                     monitor.toc_print()
 
                 if end_of_batch:
-                    eval_name_vals = eval_metric.get_name_value()
+                    eval_name_vals = epoch_eval_metric.get_name_value()
 
                 if batch_end_callback is not None:
                     batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,