You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2019/03/25 18:18:03 UTC
[incubator-mxnet] branch fit-api updated: [MXNet-1349][Fit API] Add
validation support and unit tests for fit() API (#14442)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/fit-api by this push:
new 8186772 [MXNet-1349][Fit API] Add validation support and unit tests for fit() API (#14442)
8186772 is described below
commit 8186772f78c1a4e5c9412d1e860f31f05e31fcfd
Author: Abhinav Sharma <ab...@gmail.com>
AuthorDate: Mon Mar 25 11:17:31 2019 -0700
[MXNet-1349][Fit API] Add validation support and unit tests for fit() API (#14442)
* added estimator unittests
* add more tests for estimator
* added validation logic
* added error handlers, unittests
* improve val stats
* fix pylint
* fix pylint
* update unit test
* fix tests
* fix tests
* updated metrics, val logic
* trigger ci
* trigger ci
* update metric, batch_fn error handler
* update context logic, add default metric
---
python/mxnet/gluon/estimator/estimator.py | 116 ++++++++---
python/mxnet/gluon/estimator/event_handler.py | 2 +-
tests/python/unittest/test_gluon_estimator.py | 277 ++++++++++++++++++++++++++
3 files changed, 370 insertions(+), 25 deletions(-)
diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
index c160115..e759fa7 100644
--- a/python/mxnet/gluon/estimator/estimator.py
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -19,13 +19,14 @@
# pylint: disable=wildcard-import
"""Gluon Estimator"""
+import copy
import warnings
from .event_handler import LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
-from ...metric import EvalMetric, Loss
+from ...metric import EvalMetric, Loss, Accuracy
__all__ = ['Estimator']
@@ -62,44 +63,57 @@ class Estimator(object):
if isinstance(loss, gluon.loss.Loss):
self.loss = [loss]
+ elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
+ self.loss = loss
else:
- self.loss = loss or []
- for l in self.loss:
- if not isinstance(loss, gluon.loss.Loss):
- raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
+ raise ValueError("loss must be a Loss or a list of Loss, "
+ "refer to gluon.loss.Loss:{}".format(loss))
if isinstance(metrics, EvalMetric):
- self.metrics = [metrics]
+ self.train_metrics = [metrics]
else:
- self.metrics = metrics or []
- for metric in self.metrics:
- if not isinstance(metric, EvalMetric):
- raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
+ self.train_metrics = metrics or []
+ if not all([isinstance(metric, EvalMetric) for metric in self.train_metrics]):
+ raise ValueError("metrics must be a Metric or a list of Metric, "
+ "refer to mxnet.metric.EvalMetric:{}".format(metrics))
+
+ # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
+ if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
+ self.train_metrics = [Accuracy()]
+
+ # Use same metrics for validation
+ self.val_metrics = copy.deepcopy(self.train_metrics)
- self.initializer = initializer
# store training statistics
self.train_stats = {}
self.train_stats['epochs'] = []
self.train_stats['learning_rate'] = []
# current step of the epoch
self.train_stats['step'] = ''
- for metric in self.metrics:
+ for metric in self.train_metrics:
# record a history of metrics over each epoch
self.train_stats['train_' + metric.name] = []
# only record the latest metric numbers after each batch
self.train_stats['batch_' + metric.name] = 0.
- self.loss_metrics = []
+ for metric in self.val_metrics:
+ self.train_stats['val_' + metric.name] = []
+ self.train_loss_metrics = []
+ self.val_loss_metrics = []
# using the metric wrapper for loss to record loss value
for l in self.loss:
- self.loss_metrics.append(Loss(l.name))
+ self.train_loss_metrics.append(Loss(l.name))
+ self.val_loss_metrics.append(Loss(l.name))
self.train_stats['train_' + l.name] = []
+ self.train_stats['val_' + l.name] = []
# only record the latest loss numbers after each batch
self.train_stats['batch_' + l.name] = 0.
# handle context
if isinstance(context, Context):
self.context = [context]
- if not context:
+ elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
+ self.context = context
+ elif not context:
if num_gpus() > 0:
# only use 1 GPU by default
if num_gpus() > 1:
@@ -109,8 +123,13 @@ class Estimator(object):
self.context = [gpu(0)]
else:
self.context = [cpu()]
+ else:
+ raise ValueError("context must be a Context or a list of Context, "
+ "refer to mxnet.Context:{}".format(context))
+
# initialize the network
+ self.initializer = initializer
if self.initializer:
if self._is_initialized():
# if already initialized, re-init with user specified initializer
@@ -128,13 +147,13 @@ class Estimator(object):
# handle trainers
if isinstance(trainers, gluon.Trainer):
self.trainers = [trainers]
- else:
- self.trainers = trainers or []
- if not self.trainers:
+ elif not trainers:
warnings.warn("No trainer specified, default SGD optimizer "
"with learning rate 0.001 is used.")
self.trainers = [gluon.Trainer(self.net.collect_params(),
'sgd', {'learning_rate': 0.001})]
+ else:
+ raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer")
def _is_initialized(self):
param_dict = self.net.collect_params()
@@ -156,7 +175,48 @@ class Estimator(object):
label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
return data, label
+ def evaluate(self,
+ val_data,
+ batch_fn=None):
+ """Evaluate model on validation data
+
+ Parameters
+ ----------
+ val_data : DataLoader or DataIter
+ validation data with data and labels
+ batch_fn : function
+ custom batch function to extract data and label
+ from a data batch and load into contexts(devices)
+ """
+
+ for metric in self.val_metrics + self.val_loss_metrics:
+ metric.reset()
+
+ for _, batch in enumerate(val_data):
+ if not batch_fn:
+ if isinstance(val_data, gluon.data.DataLoader):
+ data, label = self._batch_fn(batch, self.context)
+ elif isinstance(val_data, DataIter):
+ data, label = self._batch_fn(batch, self.context, is_iterator=True)
+ else:
+ raise ValueError("You are using a custom iteration, please also provide "
+ "batch_fn to extract data and label. Alternatively, you "
+ "can provide the data as gluon.data.DataLoader or "
+ "mx.io.DataIter")
+ else:
+ data, label = batch_fn(batch, self.context)
+ pred = [self.net(x) for x in data]
+ losses = []
+ for loss in self.loss:
+ losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+ # update metrics
+ for metric in self.val_metrics:
+ metric.update(label, pred)
+ for loss, loss_metric, in zip(losses, self.val_loss_metrics):
+ loss_metric.update(0, [l for l in loss])
+
def fit(self, train_data,
+ val_data=None,
epochs=1,
batch_size=None,
event_handlers=None,
@@ -204,7 +264,7 @@ class Estimator(object):
for handler in event_handlers:
handler.epoch_begin()
- for metric in self.metrics + self.loss_metrics:
+ for metric in self.train_metrics + self.train_loss_metrics:
metric.reset()
for i, batch in enumerate(train_data):
@@ -215,7 +275,9 @@ class Estimator(object):
data, label = self._batch_fn(batch, self.context, is_iterator=True)
else:
raise ValueError("You are using a custom iteration, please also provide "
- "batch_fn to extract data and label")
+ "batch_fn to extract data and label. Alternatively, you "
+ "can provide the data as gluon.data.DataLoader or "
+ "mx.io.DataIter")
else:
data, label = batch_fn(batch, self.context)
@@ -233,11 +295,11 @@ class Estimator(object):
for l in loss:
l.backward()
- # update metrics
- for metric in self.metrics:
+ # update train metrics
+ for metric in self.train_metrics:
metric.update(label, pred)
self.train_stats['batch_' + metric.name] = metric.get()[1]
- for loss, loss_metric, in zip(losses, self.loss_metrics):
+ for loss, loss_metric, in zip(losses, self.train_loss_metrics):
loss_metric.update(0, [l for l in loss])
self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]
@@ -257,8 +319,14 @@ class Estimator(object):
for handler in event_handlers:
handler.batch_end()
- for metric in self.metrics + self.loss_metrics:
+ if val_data:
+ self.evaluate(val_data, batch_fn)
+
+ for metric in self.train_metrics + self.train_loss_metrics:
self.train_stats['train_' + metric.name].append(metric.get()[1])
+ for metric in self.val_metrics + self.val_loss_metrics:
+ self.train_stats['val_' + metric.name].append(metric.get()[1])
+
# epoch end
for handler in event_handlers:
handler.epoch_end()
diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py
index 0162c36..c59644e 100644
--- a/python/mxnet/gluon/estimator/event_handler.py
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -118,7 +118,7 @@ class LoggingHandler(EventHandler):
epoch = self._estimator.train_stats['epochs'][-1]
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
for key in self._estimator.train_stats.keys():
- if key.startswith('train_') or key.startswith('test_'):
+ if key.startswith('train_') or key.startswith('val_'):
msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
self.logger.info(msg)
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
new file mode 100644
index 0000000..85e61ce
--- /dev/null
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -0,0 +1,277 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+''' Unit tests for Gluon Estimator '''
+
+import unittest
+import sys
+import warnings
+from nose.tools import assert_raises
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon.estimator import estimator
+
+
+def get_model():
+ net = nn.Sequential()
+ net.add(nn.Dense(4, activation='relu', flatten=False))
+ return net
+
+
+def test_fit():
+ ''' test estimator with different train data types '''
+ net = get_model()
+ num_epochs = 1
+ batch_size = 4
+ ctx = mx.cpu()
+ loss = gluon.loss.L2Loss()
+ acc = mx.metric.Accuracy()
+ net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ trainers=trainer,
+ context=ctx)
+ in_data = mx.nd.random.uniform(shape=(10, 3))
+ out_data = mx.nd.random.uniform(shape=(10, 4))
+ # Input dataloader
+ dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+ train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+ est.fit(train_data=train_dataloader,
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+ # Input dataiter
+ train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+ est.fit(train_data=train_dataiter,
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+ # Input NDArray
+ with assert_raises(ValueError):
+ est.fit(train_data=[in_data, out_data],
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+
+def test_validation():
+ ''' test different validation data types'''
+ net = get_model()
+ num_epochs = 1
+ batch_size = 4
+ ctx = mx.cpu()
+ loss = gluon.loss.L2Loss()
+ acc = mx.metric.Accuracy()
+ net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ trainers=trainer,
+ context=ctx)
+ in_data = mx.nd.random.uniform(shape=(10, 3))
+ out_data = mx.nd.random.uniform(shape=(10, 4))
+ # Input dataloader
+ dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+ train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+ val_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+ est.fit(train_data=train_dataloader,
+ val_data=val_dataloader,
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+ # Input dataiter
+ train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+ val_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+ est.fit(train_data=train_dataiter,
+ val_data=val_dataiter,
+ epochs=num_epochs,
+ batch_size=batch_size)
+ # Input NDArray
+ with assert_raises(ValueError):
+ est.fit(train_data=[in_data, out_data],
+ val_data=[in_data, out_data],
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_initializer():
+ ''' test with no initializer, inconsistent initializer '''
+ net = get_model()
+ num_epochs = 1
+ batch_size = 4
+ ctx = mx.cpu()
+ in_data = mx.nd.random.uniform(shape=(10, 3))
+ out_data = mx.nd.random.uniform(shape=(10, 4))
+ dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+ train_data = gluon.data.DataLoader(dataset, batch_size=batch_size)
+ loss = gluon.loss.L2Loss()
+ acc = mx.metric.Accuracy()
+ # no initializer
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ context=ctx)
+ est.fit(train_data=train_data,
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+ # different initializer for net and estimator
+ net = get_model()
+ net.initialize(mx.init.Xavier(), ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ # catch reinit warning
+ with warnings.catch_warnings(record=True) as w:
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ initializer=mx.init.MSRAPrelu(),
+ trainers=trainer,
+ context=ctx)
+ assert 'Network already initialized' in str(w[-1].message)
+ est.fit(train_data=train_data,
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_trainer():
+ ''' test with no trainer and invalid trainer '''
+ net = get_model()
+ num_epochs = 1
+ batch_size = 4
+ ctx = mx.cpu()
+ in_data = mx.nd.random.uniform(shape=(10, 3))
+ out_data = mx.nd.random.uniform(shape=(10, 4))
+ dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+ train_data = gluon.data.DataLoader(dataset, batch_size=batch_size)
+ loss = gluon.loss.L2Loss()
+ acc = mx.metric.Accuracy()
+ net.initialize(ctx=ctx)
+ # input no trainer
+ with warnings.catch_warnings(record=True) as w:
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ context=ctx)
+ assert 'No trainer specified' in str(w[-1].message)
+ est.fit(train_data=train_data,
+ epochs=num_epochs,
+ batch_size=batch_size)
+
+ # input invalid trainer
+ trainer = 'sgd'
+ with assert_raises(ValueError):
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=acc,
+ trainers=trainer,
+ context=ctx)
+
+
+def test_metric():
+ ''' test with no metric, list of metrics, invalid metric '''
+ net = get_model()
+ num_epochs = 1
+ batch_size = 4
+ ctx = mx.cpu()
+ in_data = mx.nd.random.uniform(shape=(10, 3))
+ out_data = mx.nd.random.uniform(shape=(10, 4))
+ dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+ train_data = gluon.data.DataLoader(dataset, batch_size=batch_size)
+ loss = gluon.loss.L2Loss()
+ net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ # input no metric
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ trainers=trainer,
+ context=ctx)
+ est.fit(train_data=train_data,
+ epochs=num_epochs,
+ batch_size=batch_size)
+ # input list of metrics
+ metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=metrics,
+ trainers=trainer,
+ context=ctx)
+ est.fit(train_data=train_data,
+ epochs=num_epochs,
+ batch_size=batch_size)
+ # input invalid metric
+ with assert_raises(ValueError):
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics='acc',
+ trainers=trainer,
+ context=ctx)
+ # test default metric
+ loss = gluon.loss.SoftmaxCrossEntropyLoss()
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ trainers=trainer,
+ context=ctx)
+ assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
+
+
+def test_loss():
+ ''' test with no loss, invalid loss '''
+ net = get_model()
+ ctx = mx.cpu()
+ acc = mx.metric.Accuracy()
+ net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ # input no loss
+ with assert_raises(ValueError):
+ est = estimator.Estimator(net=net,
+ trainers=trainer,
+ metrics=acc,
+ context=ctx)
+ # input invalid loss
+ with assert_raises(ValueError):
+ est = estimator.Estimator(net=net,
+ loss='mse',
+ metrics=acc,
+ trainers=trainer,
+ context=ctx)
+
+def test_context():
+ ''' test with no context, list of context, invalid context '''
+ net = get_model()
+ loss = gluon.loss.L2Loss()
+ metrics = mx.metric.Accuracy()
+ # input no context
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=metrics)
+ # input list of context
+ ctx = [mx.gpu(0), mx.gpu(1)]
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=metrics,
+ context=ctx)
+ # input invalid context
+ with assert_raises(ValueError):
+ est = estimator.Estimator(net=net,
+ loss=loss,
+ metrics=metrics,
+ context='cpu')