Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/03/21 18:48:02 UTC

[GitHub] [incubator-mxnet] karan6181 commented on a change in pull request #14442: [MXNet-1349][Fit API]Add validation support and unit tests for fit() API

karan6181 commented on a change in pull request #14442: [MXNet-1349][Fit API]Add validation support and unit tests for fit() API
URL: https://github.com/apache/incubator-mxnet/pull/14442#discussion_r267907770
 
 

 ##########
 File path: python/mxnet/gluon/estimator/estimator.py
 ##########
 @@ -62,37 +63,45 @@ def __init__(self, net,
 
         if isinstance(loss, gluon.loss.Loss):
             self.loss = [loss]
+        elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
+            self.loss = loss
         else:
-            self.loss = loss or []
-            for l in self.loss:
-                if not isinstance(loss, gluon.loss.Loss):
-                    raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
+            raise ValueError("loss must be a Loss or a list of Loss, "
+                             "refer to gluon.loss.Loss:{}".format(loss))
 
         if isinstance(metrics, EvalMetric):
-            self.metrics = [metrics]
+            self.train_metrics = [metrics]
         else:
-            self.metrics = metrics or []
-            for metric in self.metrics:
+            self.train_metrics = metrics or []
+            for metric in self.train_metrics:
                 if not isinstance(metric, EvalMetric):
-                    raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
+                    raise ValueError("metrics must be a Metric or a list of Metric, "
+                                     "refer to mxnet.metric.EvalMetric:{}".format(metric))
+
+        # Use same metrics for validation
+        self.val_metrics = copy.deepcopy(self.train_metrics)
 
-        self.initializer = initializer
         # store training statistics
         self.train_stats = {}
         self.train_stats['epochs'] = []
         self.train_stats['learning_rate'] = []
         # current step of the epoch
         self.train_stats['step'] = ''
-        for metric in self.metrics:
+        for metric in self.train_metrics:
             # record a history of metrics over each epoch
             self.train_stats['train_' + metric.name] = []
             # only record the latest metric numbers after each batch
             self.train_stats['batch_' + metric.name] = 0.
-        self.loss_metrics = []
+        for metric in self.val_metrics:
 
 Review comment:
   Can we use a single for loop for `self.train_metrics` and `self.val_metrics`, since both hold the same values according to line 82? Something like `for train_m, val_m in zip(self.train_metrics, self.val_metrics)`. Although `zip()` stops after exhausting the shorter sequence, both lists have the same length here, so `zip()` is safe to use.
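
   A minimal sketch of what the suggested combined loop might look like, assuming `self.train_metrics` and `self.val_metrics` are equal-length lists of `EvalMetric` objects as in the diff above; the `'val_' + ...` key name is an assumption, since the diff is truncated before the body of the second loop:

   ```python
   # Hypothetical sketch of the single combined loop suggested above.
   # zip() pairs each training metric with its deep-copied validation metric.
   for train_metric, val_metric in zip(self.train_metrics, self.val_metrics):
       # record a history of training metrics over each epoch
       self.train_stats['train_' + train_metric.name] = []
       # record a history of validation metrics over each epoch
       # (key name assumed; the actual key is not shown in the diff)
       self.train_stats['val_' + val_metric.name] = []
       # only record the latest metric numbers after each batch
       self.train_stats['batch_' + train_metric.name] = 0.
   ```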

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services