You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by ns...@apache.org on 2019/03/25 18:18:03 UTC

[incubator-mxnet] branch fit-api updated: [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/fit-api by this push:
     new 8186772  [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442)
8186772 is described below

commit 8186772f78c1a4e5c9412d1e860f31f05e31fcfd
Author: Abhinav Sharma <ab...@gmail.com>
AuthorDate: Mon Mar 25 11:17:31 2019 -0700

    [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442)
    
    * added estimator unittests
    
    * add more tests for estimator
    
    * added validation logic
    
    * added error handlers, unittests
    
    * improve val stats
    
    * fix pylint
    
    * fix pylint
    
    * update unit test
    
    * fix tests
    
    * fix tests
    
    * updated metrics, val logic
    
    * trigger ci
    
    * trigger ci
    
    * update metric, batch_fn error handler
    
    * update context logic, add default metric
---
 python/mxnet/gluon/estimator/estimator.py     | 116 ++++++++---
 python/mxnet/gluon/estimator/event_handler.py |   2 +-
 tests/python/unittest/test_gluon_estimator.py | 277 ++++++++++++++++++++++++++
 3 files changed, 370 insertions(+), 25 deletions(-)

diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
index c160115..e759fa7 100644
--- a/python/mxnet/gluon/estimator/estimator.py
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -19,13 +19,14 @@
 # pylint: disable=wildcard-import
 """Gluon Estimator"""
 
+import copy
 import warnings
 
 from .event_handler import LoggingHandler
 from ... import gluon, autograd
 from ...context import Context, cpu, gpu, num_gpus
 from ...io import DataIter
-from ...metric import EvalMetric, Loss
+from ...metric import EvalMetric, Loss, Accuracy
 
 __all__ = ['Estimator']
 
@@ -62,44 +63,57 @@ class Estimator(object):
 
         if isinstance(loss, gluon.loss.Loss):
             self.loss = [loss]
+        elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
+            self.loss = loss
         else:
-            self.loss = loss or []
-            for l in self.loss:
-                if not isinstance(loss, gluon.loss.Loss):
-                    raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
+            raise ValueError("loss must be a Loss or a list of Loss, "
+                             "refer to gluon.loss.Loss:{}".format(loss))
 
         if isinstance(metrics, EvalMetric):
-            self.metrics = [metrics]
+            self.train_metrics = [metrics]
         else:
-            self.metrics = metrics or []
-            for metric in self.metrics:
-                if not isinstance(metric, EvalMetric):
-                    raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
+            self.train_metrics = metrics or []
+            if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]):
+                raise ValueError("metrics must be a Metric or a list of Metric, "
+                                 "refer to mxnet.metric.EvalMetric:{}".format(metrics))
+
+        # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
+        if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
+            self.train_metrics = [Accuracy()]
+
+        # Use same metrics for validation
+        self.val_metrics = copy.deepcopy(self.train_metrics)
 
-        self.initializer = initializer
         # store training statistics
         self.train_stats = {}
         self.train_stats['epochs'] = []
         self.train_stats['learning_rate'] = []
         # current step of the epoch
         self.train_stats['step'] = ''
-        for metric in self.metrics:
+        for metric in self.train_metrics:
             # record a history of metrics over each epoch
             self.train_stats['train_' + metric.name] = []
             # only record the latest metric numbers after each batch
             self.train_stats['batch_' + metric.name] = 0.
-        self.loss_metrics = []
+        for metric in self.val_metrics:
+            self.train_stats['val_' + metric.name] = []
+        self.train_loss_metrics = []
+        self.val_loss_metrics = []
         # using the metric wrapper for loss to record loss value
         for l in self.loss:
-            self.loss_metrics.append(Loss(l.name))
+            self.train_loss_metrics.append(Loss(l.name))
+            self.val_loss_metrics.append(Loss(l.name))
             self.train_stats['train_' + l.name] = []
+            self.train_stats['val_' + l.name] = []
             # only record the latest loss numbers after each batch
             self.train_stats['batch_' + l.name] = 0.
 
         # handle context
         if isinstance(context, Context):
             self.context = [context]
-        if not context:
+        elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
+            self.context = context
+        elif not context:
             if num_gpus() > 0:
                 # only use 1 GPU by default
                 if num_gpus() > 1:
@@ -109,8 +123,13 @@ class Estimator(object):
                 self.context = [gpu(0)]
             else:
                 self.context = [cpu()]
+        else:
+            raise ValueError("context must be a Context or a list of Context, "
+                             "refer to mxnet.Context:{}".format(context))
+
 
         # initialize the network
+        self.initializer = initializer
         if self.initializer:
             if self._is_initialized():
                 # if already initialized, re-init with user specified initializer
@@ -128,13 +147,13 @@ class Estimator(object):
         # handle trainers
         if isinstance(trainers, gluon.Trainer):
             self.trainers = [trainers]
-        else:
-            self.trainers = trainers or []
-        if not self.trainers:
+        elif not trainers:
             warnings.warn("No trainer specified, default SGD optimizer "
                           "with learning rate 0.001 is used.")
             self.trainers = [gluon.Trainer(self.net.collect_params(),
                                            'sgd', {'learning_rate': 0.001})]
+        else:
+            raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")
 
     def _is_initialized(self):
         param_dict = self.net.collect_params()
@@ -156,7 +175,48 @@ class Estimator(object):
         label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
         return data, label
 
+    def evaluate(self,
+                 val_data,
+                 batch_fn=None):
+        """Evaluate model on validation data
+
+         Parameters
+         ----------
+         val_data : DataLoader or DataIter
+             validation data with data and labels
+         batch_fn : function
+             custom batch function to extract data and label
+             from a data batch and load into contexts(devices)
+         """
+
+        for metric in self.val_metrics + self.val_loss_metrics:
+            metric.reset()
+
+        for _, batch in enumerate(val_data):
+            if not batch_fn:
+                if isinstance(val_data, gluon.data.DataLoader):
+                    data, label = self._batch_fn(batch, self.context)
+                elif isinstance(val_data, DataIter):
+                    data, label = self._batch_fn(batch, self.context, is_iterator=True)
+                else:
+                    raise ValueError("You are using a custom iteration, please also provide "
+                                     "batch_fn to extract data and label. Alternatively, you "
+                                     "can provide the data as gluon.data.DataLoader or "
+                                     "mx.io.DataIter")
+            else:
+                data, label = batch_fn(batch, self.context)
+            pred = [self.net(x) for x in data]
+            losses = []
+            for loss in self.loss:
+                losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+            # update metrics
+            for metric in self.val_metrics:
+                metric.update(label, pred)
+            for loss, loss_metric, in zip(losses, self.val_loss_metrics):
+                loss_metric.update(0, [l for l in loss])
+
     def fit(self, train_data,
+            val_data=None,
             epochs=1,
             batch_size=None,
             event_handlers=None,
@@ -204,7 +264,7 @@ class Estimator(object):
             for handler in event_handlers:
                 handler.epoch_begin()
 
-            for metric in self.metrics + self.loss_metrics:
+            for metric in self.train_metrics + self.train_loss_metrics:
                 metric.reset()
 
             for i, batch in enumerate(train_data):
@@ -215,7 +275,9 @@ class Estimator(object):
                         data, label = self._batch_fn(batch, self.context, is_iterator=True)
                     else:
                         raise ValueError("You are using a custom iteration, please also provide "
-                                         "batch_fn to extract data and label")
+                                         "batch_fn to extract data and label. Alternatively, you "
+                                         "can provide the data as gluon.data.DataLoader or "
+                                         "mx.io.DataIter")
                 else:
                     data, label = batch_fn(batch, self.context)
 
@@ -233,11 +295,11 @@ class Estimator(object):
                     for l in loss:
                         l.backward()
 
-                # update metrics
-                for metric in self.metrics:
+                # update train metrics
+                for metric in self.train_metrics:
                     metric.update(label, pred)
                     self.train_stats['batch_' + metric.name] = metric.get()[1]
-                for loss, loss_metric, in zip(losses, self.loss_metrics):
+                for loss, loss_metric, in zip(losses, self.train_loss_metrics):
                     loss_metric.update(0, [l for l in loss])
                     self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]
 
@@ -257,8 +319,14 @@ class Estimator(object):
                 for handler in event_handlers:
                     handler.batch_end()
 
-            for metric in self.metrics + self.loss_metrics:
+            if val_data:
+                self.evaluate(val_data, batch_fn)
+
+            for metric in self.train_metrics + self.train_loss_metrics:
                 self.train_stats['train_' + metric.name].append(metric.get()[1])
+            for metric in self.val_metrics + self.val_loss_metrics:
+                self.train_stats['val_' + metric.name].append(metric.get()[1])
+
             # epoch end
             for handler in event_handlers:
                 handler.epoch_end()
diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py
index 0162c36..c59644e 100644
--- a/python/mxnet/gluon/estimator/event_handler.py
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -118,7 +118,7 @@ class LoggingHandler(EventHandler):
         epoch = self._estimator.train_stats['epochs'][-1]
         msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
         for key in self._estimator.train_stats.keys():
-            if key.startswith('train_') or key.startswith('test_'):
+            if key.startswith('train_') or key.startswith('val_'):
                 msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
         self.logger.info(msg)
 
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
new file mode 100644
index 0000000..85e61ce
--- /dev/null
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -0,0 +1,277 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+''' Unit tests for Gluon Estimator '''
+
+import unittest
+import sys
+import warnings
+from nose.tools import assert_raises
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon.estimator import estimator
+
+
+def get_model():
+    net = nn.Sequential()
+    net.add(nn.Dense(4, activation='relu', flatten=False))
+    return net
+
+
+def test_fit():
+    ''' test estimator with different train data types '''
+    net = get_model()
+    num_epochs = 1
+    batch_size = 4
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              metrics=acc,
+                              trainers=trainer,
+                              context=ctx)
+    in_data = mx.nd.random.uniform(shape=(10, 3))
+    out_data = mx.nd.random.uniform(shape=(10, 4))
+    # Input dataloader
+    dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+    train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    est.fit(train_data=train_dataloader,
+            epochs=num_epochs,
+            batch_size=batch_size)
+
+    # Input dataiter
+    train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+    est.fit(train_data=train_dataiter,
+            epochs=num_epochs,
+            batch_size=batch_size)
+
+    # Input NDArray
+    with assert_raises(ValueError):
+        est.fit(train_data=[in_data, out_data],
+                epochs=num_epochs,
+                batch_size=batch_size)
+
+
+def test_validation():
+    ''' test different validation data types'''
+    net = get_model()
+    num_epochs = 1
+    batch_size = 4
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              metrics=acc,
+                              trainers=trainer,
+                              context=ctx)
+    in_data = mx.nd.random.uniform(shape=(10, 3))
+    out_data = mx.nd.random.uniform(shape=(10, 4))
+    # Input dataloader
+    dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+    train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    val_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    est.fit(train_data=train_dataloader,
+            val_data=val_dataloader,
+            epochs=num_epochs,
+            batch_size=batch_size)
+
+    # Input dataiter
+    train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+    val_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+    est.fit(train_data=train_dataiter,
+            val_data=val_dataiter,
+            epochs=num_epochs,
+            batch_size=batch_size)
+    # Input NDArray
+    with assert_raises(ValueError):
+        est.fit(train_data=[in_data, out_data],
+                val_data=[in_data, out_data],
+                epochs=num_epochs,
+                batch_size=batch_size)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_initializer():
+    ''' test with no initializer, inconsistent initializer '''
+    net = get_model()
+    num_epochs = 1
+    batch_size = 4
+    ctx = mx.cpu()
+    in_data = mx.nd.random.uniform(shape=(10, 3))
+    out_data = mx.nd.random.uniform(shape=(10, 4))
+    dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+    train_data = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    # no initializer
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              metrics=acc,
+                              context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs,
+            batch_size=batch_size)
+
+    # different initializer for net and estimator
+    net = get_model()
+    net.initialize(mx.init.Xavier(), ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # catch reinit warning
+    with warnings.catch_warnings(record=True) as w:
+        est = estimator.Estimator(net=net,
+                                  loss=loss,
+                                  metrics=acc,
+                                  initializer=mx.init.MSRAPrelu(),
+                                  trainers=trainer,
+                                  context=ctx)
+        assert 'Network already initialized' in str(w[-1].message)
+    est.fit(train_data=train_data,
+            epochs=num_epochs,
+            batch_size=batch_size)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_trainer():
+    ''' test with no trainer and invalid trainer '''
+    net = get_model()
+    num_epochs = 1
+    batch_size = 4
+    ctx = mx.cpu()
+    in_data = mx.nd.random.uniform(shape=(10, 3))
+    out_data = mx.nd.random.uniform(shape=(10, 4))
+    dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+    train_data = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    # input no trainer
+    with warnings.catch_warnings(record=True) as w:
+        est = estimator.Estimator(net=net,
+                                  loss=loss,
+                                  metrics=acc,
+                                  context=ctx)
+        assert 'No trainer specified' in str(w[-1].message)
+    est.fit(train_data=train_data,
+            epochs=num_epochs,
+            batch_size=batch_size)
+
+    # input invalid trainer
+    trainer = 'sgd'
+    with assert_raises(ValueError):
+        est = estimator.Estimator(net=net,
+                                  loss=loss,
+                                  metrics=acc,
+                                  trainers=trainer,
+                                  context=ctx)
+
+
+def test_metric():
+    ''' test with no metric, list of metrics, invalid metric '''
+    net = get_model()
+    num_epochs = 1
+    batch_size = 4
+    ctx = mx.cpu()
+    in_data = mx.nd.random.uniform(shape=(10, 3))
+    out_data = mx.nd.random.uniform(shape=(10, 4))
+    dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+    train_data = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    loss = gluon.loss.L2Loss()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # input no metric
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              trainers=trainer,
+                              context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs,
+            batch_size=batch_size)
+    # input list of metrics
+    metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              metrics=metrics,
+                              trainers=trainer,
+                              context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs,
+            batch_size=batch_size)
+    # input invalid metric
+    with assert_raises(ValueError):
+        est = estimator.Estimator(net=net,
+                                  loss=loss,
+                                  metrics='acc',
+                                  trainers=trainer,
+                                  context=ctx)
+    # test default metric
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              trainers=trainer,
+                              context=ctx)
+    assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
+
+
+def test_loss():
+    ''' test with no loss, invalid loss '''
+    net = get_model()
+    ctx = mx.cpu()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # input no loss
+    with assert_raises(ValueError):
+        est = estimator.Estimator(net=net,
+                                  trainers=trainer,
+                                  metrics=acc,
+                                  context=ctx)
+    # input invalid loss
+    with assert_raises(ValueError):
+        est = estimator.Estimator(net=net,
+                                  loss='mse',
+                                  metrics=acc,
+                                  trainers=trainer,
+                                  context=ctx)
+
+def test_context():
+    ''' test with no context, list of context, invalid context '''
+    net = get_model()
+    loss = gluon.loss.L2Loss()
+    metrics = mx.metric.Accuracy()
+    # input no context
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              metrics=metrics)
+    # input list of context
+    ctx = [mx.gpu(0), mx.gpu(1)]
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              metrics=metrics,
+                              context=ctx)
+    # input invalid context
+    with assert_raises(ValueError):
+        est = estimator.Estimator(net=net,
+                                  loss=loss,
+                                  metrics=metrics,
+                                  context='cpu')