You are viewing a plain-text version of this content. The canonical version is available at the original link (the hyperlink was not preserved in this plain-text rendering).
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/09/19 18:30:05 UTC
[incubator-mxnet] branch master updated: Track epoch metric separately (#12182)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ce6525a Track epoch metric separately (#12182)
ce6525a is described below
commit ce6525ac9084a7df5789ca62c5228e3a3f55bc07
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Wed Sep 19 11:29:51 2018 -0700
Track epoch metric separately (#12182)
---
python/mxnet/callback.py | 10 +++++++---
python/mxnet/module/base_module.py | 9 ++++++++-
2 files changed, 15 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py
index 5c76280..e1c1714 100644
--- a/python/mxnet/callback.py
+++ b/python/mxnet/callback.py
@@ -165,9 +165,13 @@ class Speedometer(object):
name_value = param.eval_metric.get_name_value()
if self.auto_reset:
param.eval_metric.reset()
- msg = 'Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec'
- msg += '\t%s=%f'*len(name_value)
- logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
+ msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec'
+ msg += '\t%s=%f'*len(name_value)
+ logging.info(msg, param.epoch, count-self.frequent, count, speed, *sum(name_value, ()))
+ else:
+ msg = 'Epoch[%d] Batch [0-%d]\tSpeed: %.2f samples/sec'
+ msg += '\t%s=%f'*len(name_value)
+ logging.info(msg, param.epoch, count, speed, *sum(name_value, ()))
else:
logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
param.epoch, count, speed)
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index 08ab8fa..c534261 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -22,6 +22,7 @@
import time
import logging
import warnings
+import copy
import numpy as np
from .. import metric
@@ -507,6 +508,7 @@ class BaseModule(object):
validation_metric = eval_metric
if not isinstance(eval_metric, metric.EvalMetric):
eval_metric = metric.create(eval_metric)
+ epoch_eval_metric = copy.deepcopy(eval_metric)
################################################################################
# training loop
@@ -514,6 +516,7 @@ class BaseModule(object):
for epoch in range(begin_epoch, num_epoch):
tic = time.time()
eval_metric.reset()
+ epoch_eval_metric.reset()
nbatch = 0
data_iter = iter(train_data)
end_of_batch = False
@@ -529,8 +532,12 @@ class BaseModule(object):
self.update_metric(eval_metric,
[db.label for db in data_batch],
pre_sliced=True)
+ self.update_metric(epoch_eval_metric,
+ [db.label for db in data_batch],
+ pre_sliced=True)
else:
self.update_metric(eval_metric, data_batch.label)
+ self.update_metric(epoch_eval_metric, data_batch.label)
try:
# pre fetch next batch
@@ -543,7 +550,7 @@ class BaseModule(object):
monitor.toc_print()
if end_of_batch:
- eval_name_vals = eval_metric.get_name_value()
+ eval_name_vals = epoch_eval_metric.get_name_value()
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,