Posted to commits@mxnet.apache.org by wk...@apache.org on 2019/03/21 01:44:18 UTC

[incubator-mxnet] branch fit-api updated: Fixed issue where the estimator was printing beyond the dataset size … (#14464)

This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/fit-api by this push:
     new e6c63b1  Fixed issue where the estimator was printing beyond the dataset size … (#14464)
e6c63b1 is described below

commit e6c63b112a8f741f63b287c33caa29e13bc1d87e
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Wed Mar 20 18:44:00 2019 -0700

    Fixed issue where the estimator was printing beyond the dataset size … (#14464)
    
    * Fixed issue where the estimator was printing beyond the dataset size for the last batch
    
    * Added comments
    
    * Nudge to CI
---
 python/mxnet/gluon/estimator/estimator.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
index 159f7e2..c160115 100644
--- a/python/mxnet/gluon/estimator/estimator.py
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -242,7 +242,11 @@ class Estimator(object):
                     self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]
 
                 try:
-                    self.train_stats['step'] = "{}/{}".format(batch_size * (i + 1), len(train_data._dataset))
+                    completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \
+                                        else batch_size * (i + 1)
+                    # We need to check if this is the last batch in the current epoch and select
+                    # the value to print appropriately
+                    self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset))
                 except AttributeError:
                     self.train_stats['step'] = i
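The goal of the change is to stop the progress string from overshooting on the last batch of an epoch. With 10 samples and a batch size of 4, the old expression batch_size * (i + 1) reports 12/10 on the third batch. The sketch below shows that clamping as a standalone function. It uses min() rather than a last-batch test, and the name format_step is hypothetical. It also assumes the loader keeps a smaller final batch, which is Gluon DataLoader's default last_batch='keep'. It illustrates the idea and is not the committed code. As committed, the condition compares the batch index i with len(train_data._dataset) - 1, which is the index of the last sample rather than the last batch. The two coincide only when batch_size is 1.

```python
def format_step(batch_index, batch_size, dataset_size):
    """Return a 'completed/total' progress string for one training batch.

    batch_index is 0-based. The final batch may hold fewer than batch_size
    samples, so the completed-sample count is clamped to dataset_size.
    """
    completed_samples = min(batch_size * (batch_index + 1), dataset_size)
    return "{}/{}".format(completed_samples, dataset_size)


# Hypothetical setup: 10 samples in batches of 4 -> 3 batches of sizes 4, 4, 2.
dataset_size, batch_size = 10, 4
num_batches = (dataset_size + batch_size - 1) // batch_size  # ceiling division
for i in range(num_batches):
    print(format_step(i, batch_size, dataset_size))
# prints 4/10, 8/10, 10/10 (the unclamped expression would give 12/10)
```

Because min() works for any batch size, the function needs neither the loader's batch count nor a special case for the final iteration.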