You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2019/12/11 21:58:13 UTC
[incubator-mxnet] branch v1.6.x updated: Backport Gluon estimator
changes to 1.6 (#17048)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.6.x by this push:
new c7dab17 Backport Gluon estimator changes to 1.6 (#17048)
c7dab17 is described below
commit c7dab1704d2604f26a9b058dec712987c9ac3801
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Wed Dec 11 13:57:33 2019 -0800
Backport Gluon estimator changes to 1.6 (#17048)
* Include eval_net the validation model in the gluon estimator api (#16957)
* Include eval_net the validation model in the estimator api
* fix small issue
* Extend estimator.evaluate() to support event handlers (#16971)
* fix unittest failures for the new api interface
* Add comments in the code for readability
* Remove unused argument val_metrics
* merge changes with the master branch
* fix some regression errors
* fix bugs introduced in the merging phase
* Add support of plug and play fit_batch and evaluate_batch (#16982)
* Add support of plug and play fit_batch and evaluate_batch
* Add check for the validity of the estimator model
* Rename estimator model as batch processor
* Remove unused import
* Add documentation of the batch processor class
* refine the documentation of the batch processor
* Fix merge bugs
* fix bugs introduced during merge
* fix sanity check failures
* fix CI bugs
* Fix Gluon Estimator nightly test (#17042)
---
python/mxnet/gluon/contrib/estimator/__init__.py | 2 +
.../gluon/contrib/estimator/batch_processor.py | 105 +++++++++++
python/mxnet/gluon/contrib/estimator/estimator.py | 199 ++++++++++++---------
.../mxnet/gluon/contrib/estimator/event_handler.py | 59 +++---
tests/nightly/estimator/test_estimator_cnn.py | 4 +-
tests/nightly/estimator/test_sentiment_rnn.py | 2 +-
.../python/unittest/test_gluon_batch_processor.py | 117 ++++++++++++
tests/python/unittest/test_gluon_estimator.py | 142 ++++++++++++---
tests/python/unittest/test_gluon_event_handler.py | 23 ++-
9 files changed, 498 insertions(+), 155 deletions(-)
diff --git a/python/mxnet/gluon/contrib/estimator/__init__.py b/python/mxnet/gluon/contrib/estimator/__init__.py
index bb0a091..5ffd603 100644
--- a/python/mxnet/gluon/contrib/estimator/__init__.py
+++ b/python/mxnet/gluon/contrib/estimator/__init__.py
@@ -19,5 +19,7 @@
"""Gluon Estimator Module"""
from . import estimator
from . import event_handler
+from . import batch_processor
from .estimator import *
from .event_handler import *
+from .batch_processor import *
diff --git a/python/mxnet/gluon/contrib/estimator/batch_processor.py b/python/mxnet/gluon/contrib/estimator/batch_processor.py
new file mode 100644
index 0000000..4985f8c
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/batch_processor.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
+"""Gluon Batch Processor for Estimators"""
+
+from ...utils import split_and_load
+from .... import autograd
+
+__all__ = ['BatchProcessor']
+
+class BatchProcessor(object):
+ """BatchProcessor Class for plug and play fit_batch & evaluate_batch
+
+ During training or validation, data are divided into minibatches for processing. This
+ class aims at providing hooks for training or validating on a minibatch of data. Users
+ may provide customized fit_batch() and evaluate_batch() methods by inheriting from
+ this class and overriding class methods.
+
+ :py:class:`BatchProcessor` can be used to replace fit_batch() and evaluate_batch()
+ in the base estimator class
+ """
+
+ def __init__(self):
+ pass
+
+ def _get_data_and_label(self, batch, ctx, batch_axis=0):
+ data = batch[0]
+ label = batch[1]
+ data = split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
+ label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
+ return data, label
+
+ def evaluate_batch(self, estimator,
+ val_batch,
+ batch_axis=0):
+ """Evaluate the estimator model on a batch of validation data.
+
+ Parameters
+ ----------
+ estimator : Estimator
+ Reference to the estimator
+ val_batch : tuple
+ Data and label of a batch from the validation data loader.
+ batch_axis : int, default 0
+ Batch axis to split the validation data into devices.
+ """
+ data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)
+ pred = [estimator.eval_net(x) for x in data]
+ loss = [estimator.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
+
+ return data, label, pred, loss
+
+ def fit_batch(self, estimator,
+ train_batch,
+ batch_axis=0):
+ """Trains the estimator model on a batch of training data.
+
+ Parameters
+ ----------
+ estimator : Estimator
+ Reference to the estimator
+ train_batch : tuple
+ Data and label of a batch from the training data loader.
+ batch_axis : int, default 0
+ Batch axis to split the training data into devices.
+
+ Returns
+ -------
+ data: List of NDArray
+ Sharded data from the batch. Data is sharded with
+ `gluon.split_and_load`.
+ label: List of NDArray
+ Sharded label from the batch. Labels are sharded with
+ `gluon.split_and_load`.
+ pred: List of NDArray
+ Prediction on each of the sharded inputs.
+ loss: List of NDArray
+ Loss on each of the sharded inputs.
+ """
+ data, label = self._get_data_and_label(train_batch, estimator.context, batch_axis)
+
+ with autograd.record():
+ pred = [estimator.net(x) for x in data]
+ loss = [estimator.loss(y_hat, y) for y_hat, y in zip(pred, label)]
+
+ for l in loss:
+ l.backward()
+
+ return data, label, pred, loss
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index ab7018f..09f4315 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -32,9 +32,9 @@ from ...data import DataLoader
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
-from .... import autograd
from ....context import Context, cpu, gpu, num_gpus
from ....metric import Loss as metric_loss
+from .batch_processor import BatchProcessor
__all__ = ['Estimator']
@@ -51,18 +51,41 @@ class Estimator(object):
The model used for training.
loss : gluon.loss.Loss
Loss (objective) function to calculate during training.
- metrics : EvalMetric or list of EvalMetric
- Metrics for evaluating models.
+ train_metrics : EvalMetric or list of EvalMetric
+ Training metrics for evaluating models on training dataset.
+ val_metrics : EvalMetric or list of EvalMetric
+ Validation metrics for evaluating models on validation dataset.
initializer : Initializer
Initializer to initialize the network.
trainer : Trainer
Trainer to apply optimizer on network parameters.
context : Context or list of Context
Device(s) to run the training on.
- evaluation_loss: gluon.loss.loss
- Loss (objective) function to calculate during evaluation. If set evaluation_loss
+ evaluation_loss : gluon.loss.loss
+ Loss (objective) function to calculate during validation. If set evaluation_loss
None, it will use the same loss function as self.loss
-
+ eval_net : gluon.Block
+ The model used for validation. The validation model does not necessarily belong to
+ the same model class as the training model. But the two models typically share the
+ same architecture. Therefore the validation model can reuse parameters of the
+ training model.
+
+ The code example of construction of eval_net sharing the same network parameters as
+ the training net is given below:
+
+ >>> net = _get_train_network()
+ >>> eval_net = _get_test_network(params=net.collect_params())
+ >>> net.initialize(ctx=ctx)
+ >>> est = Estimator(net, loss, eval_net=eval_net)
+
+ Proper namespace match is required for weight sharing between two networks. Most networks
+ inheriting :py:class:`Block` can share their parameters correctly. An exception is
+ Sequential networks, for which the Block scope must be specified for correct weight
+ sharing. For the naming conventions of the MXNet Gluon API, please refer to the site
+ (https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
+ for further information.
+ batch_processor: BatchProcessor
+ BatchProcessor provides customized fit_batch() and evaluate_batch() methods
"""
logger = None
@@ -85,19 +108,26 @@ class Estimator(object):
def __init__(self, net,
loss,
- metrics=None,
+ train_metrics=None,
+ val_metrics=None,
initializer=None,
trainer=None,
context=None,
- evaluation_loss=None):
+ evaluation_loss=None,
+ eval_net=None,
+ batch_processor=None):
self.net = net
self.loss = self._check_loss(loss)
- self._train_metrics = _check_metrics(metrics)
+ self._train_metrics = _check_metrics(train_metrics)
+ self._val_metrics = _check_metrics(val_metrics)
self._add_default_training_metrics()
self._add_validation_metrics()
self.evaluation_loss = self.loss
if evaluation_loss is not None:
self.evaluation_loss = self._check_loss(evaluation_loss)
+ self.eval_net = self.net
+ if eval_net is not None:
+ self.eval_net = eval_net
self.logger = logging.Logger(name='Estimator', level=logging.INFO)
self.logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -105,6 +135,7 @@ class Estimator(object):
self.context = self._check_context(context)
self._initialize(initializer)
self.trainer = self._check_trainer(trainer)
+ self.batch_processor = self._check_batch_processor(batch_processor)
def _check_loss(self, loss):
if not isinstance(loss, gluon_loss):
@@ -145,6 +176,18 @@ class Estimator(object):
context = [cpu()]
return context
+ def _check_batch_processor(self, batch_processor):
+ # check whether the batch processor contains fit_batch() and evaluate_batch() methods
+ if batch_processor is not None:
+ model_fit = getattr(batch_processor, 'fit_batch', None)
+ model_evaluate = getattr(batch_processor, 'evaluate_batch', None)
+ if not callable(model_fit) or not callable(model_evaluate):
+ raise ValueError('Customized Batch Processor must contain fit_batch()'
+ ' and evaluate_batch() methods')
+ else:
+ batch_processor = BatchProcessor()
+ return batch_processor
+
def _initialize(self, initializer):
# initialize the network
if not self._is_initialized():
@@ -202,13 +245,21 @@ class Estimator(object):
self._train_metrics.append(metric_loss(loss_name))
for metric in self._train_metrics:
- metric.name = "training " + metric.name
+ # add training prefix to the metric name
+ # it is useful for event handlers to distinguish them from validation metrics
+ metric.name = 'training ' + metric.name
def _add_validation_metrics(self):
- self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
+ if not self._val_metrics:
+ self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
for metric in self._val_metrics:
- metric.name = "validation " + metric.name
+ # add validation prefix to the metric name
+ # it is useful for event handlers to distinguish them from training metrics
+ if 'training' in metric.name:
+ metric.name = metric.name.replace('training', 'validation')
+ else:
+ metric.name = 'validation ' + metric.name
@property
def train_metrics(self):
@@ -218,35 +269,10 @@ class Estimator(object):
def val_metrics(self):
return self._val_metrics
- def evaluate_batch(self,
- val_batch,
- val_metrics,
- batch_axis=0):
- """Evaluate model on a batch of validation data.
-
- Parameters
- ----------
- val_batch : tuple
- Data and label of a batch from the validation data loader.
- val_metrics : EvalMetric or list of EvalMetrics
- Metrics to update validation result.
- batch_axis : int, default 0
- Batch axis to split the validation data into devices.
- """
- data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
- pred = [self.net(x) for x in data]
- loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
- # update metrics
- for metric in val_metrics:
- if isinstance(metric, metric_loss):
- metric.update(0, loss)
- else:
- metric.update(label, pred)
-
def evaluate(self,
val_data,
- val_metrics,
- batch_axis=0):
+ batch_axis=0,
+ event_handlers=None):
"""Evaluate model on validation data.
This function calls :py:func:`evaluate_batch` on each of the batches from the
@@ -257,57 +283,45 @@ class Estimator(object):
----------
val_data : DataLoader
Validation data loader with data and labels.
- val_metrics : EvalMetric or list of EvalMetrics
- Metrics to update validation result.
batch_axis : int, default 0
Batch axis to split the validation data into devices.
+ event_handlers : EventHandler or list of EventHandler
+ List of :py:class:`EventHandlers` to apply during validation. Besides
+ event handlers specified here, a default MetricHandler and a LoggingHandler
+ will be added if not specified explicitly.
"""
if not isinstance(val_data, DataLoader):
raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
"can transform your DataIter or any NDArray into Gluon DataLoader. "
"Refer to gluon.data.DataLoader")
- for metric in val_metrics:
+ for metric in self.val_metrics:
metric.reset()
+ estimator_ref = self
- for _, batch in enumerate(val_data):
- self.evaluate_batch(batch, val_metrics, batch_axis)
+ event_handlers = self._prepare_default_validation_handlers(event_handlers)
- def fit_batch(self, train_batch, batch_axis=0):
- """Trains the model on a batch of training data.
+ _, epoch_begin, batch_begin, batch_end, \
+ epoch_end, _ = self._categorize_handlers(event_handlers)
- Parameters
- ----------
- train_batch : tuple
- Data and label of a batch from the training data loader.
- batch_axis : int, default 0
- Batch axis to split the training data into devices.
+ estimator_ref = self
- Returns
- -------
- data: List of NDArray
- Sharded data from the batch. Data is sharded with
- `gluon.split_and_load`.
- label: List of NDArray
- Sharded label from the batch. Labels are sharded with
- `gluon.split_and_load`.
- pred: List of NDArray
- Prediction on each of the sharded inputs.
- loss: List of NDArray
- Loss on each of the sharded inputs.
- """
- data, label = self._get_data_and_label(train_batch, self.context, batch_axis)
+ for handler in epoch_begin:
+ handler.epoch_begin(estimator_ref)
- batch_size = train_batch[0].shape[batch_axis]
+ for _, batch in enumerate(val_data):
+ for handler in batch_begin:
+ handler.batch_begin(estimator_ref, batch=batch)
- with autograd.record():
- pred = [self.net(x) for x in data]
- loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
+ _, label, pred, loss = \
+ self.batch_processor.evaluate_batch(estimator_ref, batch,
+ batch_axis)
- for l in loss:
- l.backward()
+ for handler in batch_end:
+ handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)
- return data, label, pred, loss
+ for handler in epoch_end:
+ handler.epoch_end(estimator_ref)
def fit(self, train_data,
val_data=None,
@@ -382,8 +396,8 @@ class Estimator(object):
for handler in batch_begin:
handler.batch_begin(estimator_ref, batch=batch)
- _, label, pred, loss = self.fit_batch(batch, batch_axis)
-
+ _, label, pred, loss = self.batch_processor.fit_batch(estimator_ref,
+ batch, batch_axis)
# batch end
batch_end_result = []
@@ -417,23 +431,17 @@ class Estimator(object):
added_default_handlers.append(GradientUpdateHandler())
if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
- added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
+ added_default_handlers.append(MetricHandler(metrics=self.train_metrics))
if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
# no validation handler
if val_data:
- val_metrics = self.val_metrics
# add default validation handler if validation data found
added_default_handlers.append(ValidationHandler(val_data=val_data,
- eval_fn=self.evaluate,
- val_metrics=val_metrics))
- else:
- # set validation metrics to None if no validation data and no validation handler
- val_metrics = []
+ eval_fn=self.evaluate))
if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
- added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
- val_metrics=val_metrics))
+ added_default_handlers.append(LoggingHandler(metrics=self.train_metrics))
# if there is a mix of user defined event handlers and default event handlers
# they should have the same set of metrics
@@ -450,6 +458,29 @@ class Estimator(object):
event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
return event_handlers
+ def _prepare_default_validation_handlers(self, event_handlers):
+ event_handlers = _check_event_handlers(event_handlers)
+ added_default_handlers = []
+
+ # add default logging handler and metric handler for validation
+ if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
+ added_default_handlers.append(MetricHandler(metrics=self.val_metrics))
+
+ if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
+ added_default_handlers.append(LoggingHandler(metrics=self.val_metrics))
+
+ mixing_handlers = event_handlers and added_default_handlers
+ event_handlers.extend(added_default_handlers)
+
+ # check if all handlers refer to well-defined validation metrics
+ if mixing_handlers:
+ known_metrics = set(self.val_metrics)
+ for handler in event_handlers:
+ _check_handler_metric_ref(handler, known_metrics)
+
+ event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
+ return event_handlers
+
def _categorize_handlers(self, event_handlers):
"""
categorize handlers into 6 event lists to avoid calling empty methods
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 6477760..c755136 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -128,28 +128,28 @@ class MetricHandler(EpochBegin, BatchEnd):
Parameters
----------
- train_metrics : List of EvalMetrics
- Training metrics to be updated at batch end.
+ metrics : List of EvalMetrics
+ Metrics to be updated at batch end.
priority : scalar
Priority level of the MetricHandler. Priority level is sorted in ascending
order. The lower the number is, the higher priority level the handler is.
"""
- def __init__(self, train_metrics, priority=-1000):
- self.train_metrics = _check_metrics(train_metrics)
+ def __init__(self, metrics, priority=-1000):
+ self.metrics = _check_metrics(metrics)
# order to be called among all callbacks
# metrics need to be calculated before other callbacks can access them
self.priority = priority
def epoch_begin(self, estimator, *args, **kwargs):
- for metric in self.train_metrics:
+ for metric in self.metrics:
metric.reset()
def batch_end(self, estimator, *args, **kwargs):
pred = kwargs['pred']
label = kwargs['label']
loss = kwargs['loss']
- for metric in self.train_metrics:
+ for metric in self.metrics:
if isinstance(metric, metric_loss):
# metric wrapper for loss values
metric.update(0, loss)
@@ -171,8 +171,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
eval_fn : function
A function defines how to run evaluation and
calculate loss and metrics.
- val_metrics : List of EvalMetrics
- Validation metrics to be updated.
epoch_period : int, default 1
How often to run validation at epoch end, by default
:py:class:`ValidationHandler` validate every epoch.
@@ -188,7 +186,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
def __init__(self,
val_data,
eval_fn,
- val_metrics=None,
epoch_period=1,
batch_period=None,
priority=-1000):
@@ -196,7 +193,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
self.eval_fn = eval_fn
self.epoch_period = epoch_period
self.batch_period = batch_period
- self.val_metrics = _check_metrics(val_metrics)
self.current_batch = 0
self.current_epoch = 0
# order to be called among all callbacks
@@ -211,20 +207,12 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
def batch_end(self, estimator, *args, **kwargs):
self.current_batch += 1
if self.batch_period and self.current_batch % self.batch_period == 0:
- self.eval_fn(val_data=self.val_data,
- val_metrics=self.val_metrics)
- msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \
- % (self.current_epoch, self.current_batch)
- for monitor in self.val_metrics:
- name, value = monitor.get()
- msg += '%s: %.4f, ' % (name, value)
- estimator.logger.info(msg.rstrip(','))
+ self.eval_fn(val_data=self.val_data)
def epoch_end(self, estimator, *args, **kwargs):
self.current_epoch += 1
if self.epoch_period and self.current_epoch % self.epoch_period == 0:
- self.eval_fn(val_data=self.val_data,
- val_metrics=self.val_metrics)
+ self.eval_fn(val_data=self.val_data)
class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
@@ -239,10 +227,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
Logging interval during training.
log_interval='epoch': display metrics every epoch
log_interval=integer k: display metrics every interval of k batches
- train_metrics : list of EvalMetrics
- Training metrics to be logged, logged at batch end, epoch end, train end.
- val_metrics : list of EvalMetrics
- Validation metrics to be logged, logged at epoch end, train end.
+ metrics : list of EvalMetrics
+ Metrics to be logged, logged at batch end, epoch end, train end.
priority : scalar, default np.Inf
Priority level of the LoggingHandler. Priority level is sorted in
ascending order. The lower the number is, the higher priority level the
@@ -250,14 +236,12 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
"""
def __init__(self, log_interval='epoch',
- train_metrics=None,
- val_metrics=None,
+ metrics=None,
priority=np.Inf):
super(LoggingHandler, self).__init__()
if not isinstance(log_interval, int) and log_interval != 'epoch':
raise ValueError("log_interval must be either an integer or string 'epoch'")
- self.train_metrics = _check_metrics(train_metrics)
- self.val_metrics = _check_metrics(val_metrics)
+ self.metrics = _check_metrics(metrics)
self.batch_index = 0
self.current_epoch = 0
self.processed_samples = 0
@@ -265,6 +249,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
# it will also shut down logging at train end
self.priority = priority
self.log_interval = log_interval
+ self.log_interval_time = 0
def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
@@ -288,7 +273,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
train_time = time.time() - self.train_start
msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
# log every result in train stats including train/validation loss & metrics
- for metric in self.train_metrics + self.val_metrics:
+ for metric in self.metrics:
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))
@@ -307,7 +292,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
if self.batch_index % self.log_interval == 0:
msg += 'time/interval: %.3fs ' % self.log_interval_time
self.log_interval_time = 0
- for metric in self.train_metrics:
+ for metric in self.metrics:
# only log current training loss & metric after each interval
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
@@ -316,15 +301,23 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
def epoch_begin(self, estimator, *args, **kwargs):
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
+ is_training = False
+ # use the name hack defined in __init__() of estimator class
+ for metric in self.metrics:
+ if 'training' in metric.name:
+ is_training = True
self.epoch_start = time.time()
- estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
- self.current_epoch, estimator.trainer.learning_rate)
+ if is_training:
+ estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
+ self.current_epoch, estimator.trainer.learning_rate)
+ else:
+ estimator.logger.info("Validation Begin")
def epoch_end(self, estimator, *args, **kwargs):
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
epoch_time = time.time() - self.epoch_start
msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
- for monitor in self.train_metrics + self.val_metrics:
+ for monitor in self.metrics:
name, value = monitor.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))
diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py
index 4a3bb20..af51953 100644
--- a/tests/nightly/estimator/test_estimator_cnn.py
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -116,7 +116,7 @@ def test_estimator_cpu():
# Define estimator
est = estimator.Estimator(net=net,
loss=loss,
- metrics=mx.metric.Accuracy(),
+ train_metrics=mx.metric.Accuracy(),
trainer=trainer,
context=context)
# Call fit()
@@ -145,7 +145,7 @@ def test_estimator_gpu():
# Define estimator
est = estimator.Estimator(net=net,
loss=loss,
- metrics=acc,
+ train_metrics=acc,
trainer=trainer,
context=context)
# Call fit()
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py
index 233355b..ab124ba 100644
--- a/tests/nightly/estimator/test_sentiment_rnn.py
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -197,7 +197,7 @@ def run(net, train_dataloader, test_dataloader, num_epochs, ctx, lr):
nested_metrics.add([metrics, mx.metric.Accuracy()])
# Define estimator
- est = estimator.Estimator(net=net, loss=loss, metrics=nested_metrics,
+ est = estimator.Estimator(net=net, loss=loss, train_metrics=nested_metrics,
trainer=trainer, context=ctx)
# Begin training
est.fit(train_data=train_dataloader, val_data=test_dataloader,
diff --git a/tests/python/unittest/test_gluon_batch_processor.py b/tests/python/unittest/test_gluon_batch_processor.py
new file mode 100644
index 0000000..4bd6f76
--- /dev/null
+++ b/tests/python/unittest/test_gluon_batch_processor.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+''' Unit tests for Gluon Batch Processor '''
+
+import sys
+import unittest
+import warnings
+
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon.contrib.estimator import *
+from mxnet.gluon.contrib.estimator.event_handler import *
+from mxnet.gluon.contrib.estimator.batch_processor import BatchProcessor
+from nose.tools import assert_raises
+
+def _get_test_network():
+ net = nn.Sequential()
+ net.add(nn.Dense(4, activation='relu', flatten=False))
+ return net
+
+
+def _get_test_data():
+ batch_size = 4
+ in_data = mx.nd.random.uniform(shape=(10, 3))
+ out_data = mx.nd.random.uniform(shape=(10, 4))
+ # Input dataloader
+ dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+ dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+ dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+ return dataloader, dataiter
+
+def test_batch_processor_fit():
+ ''' test estimator with different train data types '''
+ net = _get_test_network()
+ dataloader, dataiter = _get_test_data()
+ num_epochs = 1
+ ctx = mx.cpu()
+ loss = gluon.loss.L2Loss()
+ acc = mx.metric.Accuracy()
+ net.initialize(ctx=ctx)
+ processor = BatchProcessor()
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ est = Estimator(net=net,
+ loss=loss,
+ train_metrics=acc,
+ trainer=trainer,
+ context=ctx,
+ batch_processor=processor)
+
+ est.fit(train_data=dataloader,
+ epochs=num_epochs)
+
+ with assert_raises(ValueError):
+ est.fit(train_data=dataiter,
+ epochs=num_epochs)
+
+ # Input NDArray
+ with assert_raises(ValueError):
+ est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
+ epochs=num_epochs)
+
+
+def test_batch_processor_validation():
+ ''' test different validation data types'''
+ net = _get_test_network()
+ dataloader, dataiter = _get_test_data()
+ num_epochs = 1
+ ctx = mx.cpu()
+ loss = gluon.loss.L2Loss()
+ acc = mx.metric.Accuracy()
+ evaluation_loss = gluon.loss.L1Loss()
+ net.initialize(ctx=ctx)
+ processor = BatchProcessor()
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ est = Estimator(net=net,
+ loss=loss,
+ train_metrics=acc,
+ trainer=trainer,
+ context=ctx,
+ evaluation_loss=evaluation_loss,
+ batch_processor=processor)
+ # Input dataloader
+ est.fit(train_data=dataloader,
+ val_data=dataloader,
+ epochs=num_epochs)
+
+ # using validation handler
+ train_metrics = est.train_metrics
+ val_metrics = est.val_metrics
+ validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate)
+
+ with assert_raises(ValueError):
+ est.fit(train_data=dataiter,
+ val_data=dataiter,
+ epochs=num_epochs)
+ # Input NDArray
+ with assert_raises(ValueError):
+ est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
+ val_data=[mx.nd.ones(shape=(10, 3))],
+ epochs=num_epochs)
+
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index 21f949a..924dd08 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -29,11 +29,16 @@ from mxnet.gluon.contrib.estimator.event_handler import *
from nose.tools import assert_raises
-def _get_test_network():
- net = nn.Sequential()
+def _get_test_network(params=None):
+ net = nn.Sequential(params=params)
net.add(nn.Dense(4, activation='relu', flatten=False))
return net
+def _get_test_network_with_namescope(params=None):
+ net = nn.Sequential(params=params)
+ with net.name_scope():
+ net.add(nn.Dense(4, activation='relu', flatten=False))
+ return net
def _get_test_data():
batch_size = 4
@@ -58,7 +63,7 @@ def test_fit():
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
- metrics=acc,
+ train_metrics=acc,
trainer=trainer,
context=ctx)
@@ -88,7 +93,7 @@ def test_validation():
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
- metrics=acc,
+ train_metrics=acc,
trainer=trainer,
context=ctx,
evaluation_loss=evaluation_loss)
@@ -100,8 +105,7 @@ def test_validation():
# using validation handler
train_metrics = est.train_metrics
val_metrics = est.val_metrics
- validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate,
- val_metrics=val_metrics)
+ validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate)
with assert_raises(ValueError):
est.fit(train_data=dataiter,
@@ -127,7 +131,7 @@ def test_initializer():
# no initializer
est = Estimator(net=net,
loss=loss,
- metrics=acc,
+ train_metrics=acc,
context=ctx)
est.fit(train_data=train_data,
epochs=num_epochs)
@@ -140,7 +144,7 @@ def test_initializer():
with warnings.catch_warnings(record=True) as w:
est = Estimator(net=net,
loss=loss,
- metrics=acc,
+ train_metrics=acc,
initializer=mx.init.MSRAPrelu(),
trainer=trainer,
context=ctx)
@@ -148,7 +152,7 @@ def test_initializer():
# net partially initialized, fine tuning use case
net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx)
net.output = gluon.nn.Dense(10) #last layer not initialized
- est = Estimator(net, loss=loss, metrics=acc, context=ctx)
+ est = Estimator(net, loss=loss, train_metrics=acc, context=ctx)
dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5)
est.fit(train_data=train_data,
@@ -170,7 +174,7 @@ def test_trainer():
with warnings.catch_warnings(record=True) as w:
est = Estimator(net=net,
loss=loss,
- metrics=acc,
+ train_metrics=acc,
context=ctx)
assert 'No trainer specified' in str(w[-1].message)
est.fit(train_data=train_data,
@@ -181,7 +185,7 @@ def test_trainer():
with assert_raises(ValueError):
est = Estimator(net=net,
loss=loss,
- metrics=acc,
+ train_metrics=acc,
trainer=trainer,
context=ctx)
@@ -207,7 +211,7 @@ def test_metric():
metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
est = Estimator(net=net,
loss=loss,
- metrics=metrics,
+ train_metrics=metrics,
trainer=trainer,
context=ctx)
est.fit(train_data=train_data,
@@ -216,7 +220,7 @@ def test_metric():
with assert_raises(ValueError):
est = Estimator(net=net,
loss=loss,
- metrics='acc',
+ train_metrics='acc',
trainer=trainer,
context=ctx)
# test default metric
@@ -239,7 +243,7 @@ def test_loss():
with assert_raises(ValueError):
est = Estimator(net=net,
loss='mse',
- metrics=acc,
+ train_metrics=acc,
trainer=trainer,
context=ctx)
@@ -252,26 +256,26 @@ def test_context():
# input no context
est = Estimator(net=net,
loss=loss,
- metrics=metrics)
+ train_metrics=metrics)
# input list of context
gpus = mx.context.num_gpus()
ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
net = _get_test_network()
est = Estimator(net=net,
loss=loss,
- metrics=metrics,
+ train_metrics=metrics,
context=ctx)
# input invalid context
with assert_raises(ValueError):
est = Estimator(net=net,
loss=loss,
- metrics=metrics,
+ train_metrics=metrics,
context='cpu')
with assert_raises(AssertionError):
est = Estimator(net=net,
loss=loss,
- metrics=metrics,
+ train_metrics=metrics,
context=[mx.gpu(0), mx.gpu(100)])
@@ -336,7 +340,7 @@ def test_default_handlers():
est = Estimator(net=net,
loss=loss,
- metrics=train_acc,
+ train_metrics=train_acc,
trainer=trainer,
context=ctx)
# no handler(all default handlers), no warning
@@ -347,18 +351,18 @@ def test_default_handlers():
# use mix of default and user defined handlers
train_metrics = est.train_metrics
val_metrics = est.val_metrics
- logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
+ logging = LoggingHandler(metrics=train_metrics)
est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
# handler with all user defined metrics
# use mix of default and user defined handlers
- metric = MetricHandler(train_metrics=[train_acc])
- logging = LoggingHandler(train_metrics=[train_acc])
+ metric = MetricHandler(metrics=[train_acc])
+ logging = LoggingHandler(metrics=[train_acc])
est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging])
# handler with mixed metrics, some handler use metrics prepared by estimator
# some handler use metrics user prepared
- logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")])
+ logging = LoggingHandler(metrics=[mx.metric.RMSE("val acc")])
with assert_raises(ValueError):
est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
@@ -371,3 +375,95 @@ def test_default_handlers():
assert isinstance(handlers[0], GradientUpdateHandler)
assert isinstance(handlers[1], MetricHandler)
assert isinstance(handlers[4], LoggingHandler)
+
+def test_eval_net():
+ ''' test estimator with a different evaluation net '''
+ ''' test weight sharing of sequential networks without namescope '''
+ net = _get_test_network()
+ eval_net = _get_test_network(params=net.collect_params())
+ dataloader, dataiter = _get_test_data()
+ num_epochs = 1
+ ctx = mx.cpu()
+ loss = gluon.loss.L2Loss()
+ evaluation_loss = gluon.loss.L2Loss()
+ acc = mx.metric.Accuracy()
+ net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ est = Estimator(net=net,
+ loss=loss,
+ train_metrics=acc,
+ trainer=trainer,
+ context=ctx,
+ evaluation_loss=evaluation_loss,
+ eval_net=eval_net)
+
+ with assert_raises(RuntimeError):
+ est.fit(train_data=dataloader,
+ val_data=dataloader,
+ epochs=num_epochs)
+
+ ''' test weight sharing of sequential networks with namescope '''
+ net = _get_test_network_with_namescope()
+ eval_net = _get_test_network_with_namescope(params=net.collect_params())
+ net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ est = Estimator(net=net,
+ loss=loss,
+ train_metrics=acc,
+ trainer=trainer,
+ context=ctx,
+ evaluation_loss=evaluation_loss,
+ eval_net=eval_net)
+
+ est.fit(train_data=dataloader,
+ val_data=dataloader,
+ epochs=num_epochs)
+
+ ''' test weight sharing of two resnets '''
+ net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
+ net.output = gluon.nn.Dense(10)
+ eval_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
+ eval_net.output = gluon.nn.Dense(10, params=net.collect_params())
+ dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
+ dataloader = gluon.data.DataLoader(dataset=dataset, batch_size=5)
+ net.initialize(ctx=ctx)
+ eval_net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+ est = Estimator(net=net,
+ loss=loss,
+ train_metrics=acc,
+ trainer=trainer,
+ context=ctx,
+ evaluation_loss=evaluation_loss,
+ eval_net=eval_net)
+
+ est.fit(train_data=dataloader,
+ val_data=dataloader,
+ epochs=num_epochs)
+
+def test_val_handlers():
+ net = _get_test_network()
+ train_data, _ = _get_test_data()
+ val_data, _ = _get_test_data()
+
+ num_epochs = 1
+ ctx = mx.cpu()
+ net.initialize(ctx=ctx)
+ trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+
+ train_acc = mx.metric.RMSE()
+ loss = gluon.loss.L2Loss()
+
+ est = Estimator(net=net,
+ loss=loss,
+ train_metrics=train_acc,
+ trainer=trainer,
+ context=ctx)
+
+ with warnings.catch_warnings(record=True) as w:
+ est.fit(train_data=train_data, epochs=num_epochs)
+ est.evaluate(val_data=val_data)
+
+ logging = LoggingHandler(log_interval=1, metrics=est.val_metrics)
+ est.evaluate(val_data=val_data, event_handlers=[logging])
+
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index 658fb88..41b7901 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -54,7 +54,7 @@ def test_checkpoint_handler():
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
- est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+ est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
model_prefix=model_prefix,
monitor=acc,
@@ -72,7 +72,7 @@ def test_checkpoint_handler():
file_path = os.path.join(tmpdir, model_prefix)
net = _get_test_network(nn.HybridSequential())
net.hybridize()
- est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+ est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
model_prefix=model_prefix,
epoch_period=None,
@@ -100,7 +100,7 @@ def test_resume_checkpoint():
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
- est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+ est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
model_prefix=model_prefix,
monitor=acc,
@@ -125,7 +125,7 @@ def test_early_stopping():
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
- est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+ est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
early_stopping = event_handler.EarlyStoppingHandler(monitor=acc,
patience=0,
mode='min')
@@ -149,14 +149,13 @@ def test_logging():
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
- est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+ est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
est.logger.addHandler(logging.FileHandler(output_dir))
train_metrics = est.train_metrics
val_metrics = est.val_metrics
- logging_handler = event_handler.LoggingHandler(train_metrics=train_metrics,
- val_metrics=val_metrics)
+ logging_handler = event_handler.LoggingHandler(metrics=train_metrics)
est.fit(test_data, event_handlers=[logging_handler], epochs=3)
assert logging_handler.batch_index == 0
assert logging_handler.current_epoch == 3
@@ -197,7 +196,7 @@ def test_custom_handler():
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
- est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+ est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
custom_handler = CustomStopHandler(3, 2)
est.fit(test_data, event_handlers=[custom_handler], epochs=3)
assert custom_handler.num_batch == 3
@@ -220,10 +219,10 @@ def test_logging_interval():
num_epochs = 1
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
- logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+ logging = LoggingHandler(metrics=[acc], log_interval=log_interval)
est = estimator.Estimator(net=net,
loss=ce_loss,
- metrics=acc)
+ train_metrics=acc)
est.fit(train_data=dataloader,
epochs=num_epochs,
@@ -245,10 +244,10 @@ def test_logging_interval():
sys.stdout = mystdout = StringIO()
acc = mx.metric.Accuracy()
log_interval = 5
- logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+ logging = LoggingHandler(metrics=[acc], log_interval=log_interval)
est = estimator.Estimator(net=net,
loss=ce_loss,
- metrics=acc)
+ train_metrics=acc)
est.fit(train_data=dataloader,
epochs=num_epochs,
event_handlers=[logging])