You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/18 18:04:00 UTC
[incubator-mxnet] branch master updated: [MXNET-1333] Estimator and
Fit API (#14629)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9f451fb [MXNET-1333] Estimator and Fit API (#14629)
9f451fb is described below
commit 9f451fb6f4265f7e122ca08a386e85595a5030a2
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Sat May 18 11:03:18 2019 -0700
[MXNET-1333] Estimator and Fit API (#14629)
* [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)
* base class for estimator and eventhandler
* add license
* add event handlers
* fix pylint
* improve arg check
* fix pylint
* add unit tests
* Fixed issue where the estimator was printing beyond the dataset size … (#14464)
* Fixed issue where the estimator was printing beyond the dataset size for the last batch
* Added comments
* Nudge to CI
* [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442)
* added estimator unittests
* add more tests for estimator
* added validation logic
* added error handlers, unittests
* improve val stats
* fix pylint
* fix pylint
* update unit test
* fix tests
* fix tests
* updated metrics, val logic
* trigger ci
* trigger ci
* update metric, batch_fn error handler
* update context logic, add default metric
* [MXNet-1340][Fit API]Update train stats (#14494)
* add train history
* update history
* update test
* avoid calling empty methods
* remove train history object
* fix pylint
* add unit test
* fix test
* update categorize handlers
* [MXNet-1375][Fit API]Added RNN integration test for fit() API (#14547)
* Added RNN integration test for fit() API
* Addressed review comments: change in Jenkinsfile, tmp directory, ctx with condensed if/else, renamed imports
* CPU test doesn't require nvidiadocker container
* Modified the structure by removing the redundant code
* [MXNet-1343][Fit API]Add CNN integration test for fit() API (#14405)
* added cnn intg tests for fit api
* updated cnn intg tests
* added functions for nightly test
* updated runtime_function
* updated intg tests
* updated init, datapath, refs
* added validation data
* update cpu test
* refactor code
* updated context
* [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (#14587)
* Retrieve Batch size and Logging verbose support for Gluon fit() API
* NIT changes
* Addressed review comments: shifted the batch size code to a separate method, sentence correction
* Modified unittest
* removed redundant parameter
* Resolve CI test failure
* only support DataLoader for now, future PRs will include DataIter to DataLoader converter
* Get the number of samples from shape attribute instead of length due to low space complexity
* Simplified batch size retrieval code
* removed batch_size parameter from fit() method and fixed the tests
* Verbose exception handling
* Assigning constant to a verbose
* Modified exception message
* Resolved undefined class reference
* Addressed review comments: Modified verbose level names, docs, variable names
* Update estimator.py
* move estimator to contrib (#14633)
* move to gluon contrib (#14635)
* [Fit API] improve event handlers (#14685)
* improve event handlers
* update tests
* passing weakref of estimator
* fix unit test
* fix test
* fix pylint
* fix test
* fix pylint
* move default metric logic
* combine nightly tests
* [MXNET-1396][Fit-API] Update default handler logic (#14765)
* move to nightly for binaries
* update default handler
* fix pylint
* trigger ci
* trigger ci
* [Fit API] update estimator (#14849)
* address comments
* add comment
* check available context
* fix bug
* change cpu check
* [Fit-API] Address PR comments (#14885)
* address comments
* update checkpoint
* test symbol save
* address comments
* add resume
* update doc and resume checkpoint
* update docs
* trigger ci
* trigger ci
---
ci/docker/runtime_functions.sh | 10 +
python/mxnet/gluon/contrib/estimator/__init__.py | 21 +
python/mxnet/gluon/contrib/estimator/estimator.py | 408 ++++++++++++
.../mxnet/gluon/contrib/estimator/event_handler.py | 705 +++++++++++++++++++++
python/mxnet/gluon/trainer.py | 7 +
tests/nightly/JenkinsfileForBinaries | 8 +
tests/nightly/estimator/test_estimator_cnn.py | 151 +++++
tests/nightly/estimator/test_sentiment_rnn.py | 276 ++++++++
tests/python/unittest/test_gluon_estimator.py | 371 +++++++++++
tests/python/unittest/test_gluon_event_handler.py | 198 ++++++
10 files changed, 2155 insertions(+)
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index e1da222..58e39ef 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1350,6 +1350,16 @@ nightly_scala_demo_test_cpu() {
bash bin/run_im.sh
}
nightly_estimator() {
    # Run the Gluon Estimator integration tests (CNN and sentiment RNN),
    # first on GPU and then on CPU, stopping at the first failure.
    set -ex
    cd /work/mxnet/tests/nightly/estimator
    export PYTHONPATH=/work/mxnet/python/
    local device
    for device in gpu cpu; do
        python test_estimator_cnn.py --type "${device}"
        python test_sentiment_rnn.py --type "${device}"
    done
}
+
# Deploy
deploy_docs() {
diff --git a/python/mxnet/gluon/contrib/estimator/__init__.py b/python/mxnet/gluon/contrib/estimator/__init__.py
new file mode 100644
index 0000000..58600da
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=wildcard-import
+"""Gluon Estimator Module"""
+from .estimator import *
+from .event_handler import *
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
new file mode 100644
index 0000000..da1a391
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -0,0 +1,408 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-variable
+"""Gluon Estimator"""
+
+import copy
+import warnings
+
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
+from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
+from .... import gluon, autograd
+from ....context import Context, cpu, gpu, num_gpus
+from ....metric import EvalMetric, Loss, Accuracy
+
+__all__ = ['Estimator']
+
+
class Estimator(object):
    """Estimator Class for easy model training

    :py:class:`Estimator` can be used to facilitate the training & validation process


    Parameters
    ----------
    net : Block
        The model used for training.
    loss : gluon.loss.Loss or list of gluon.loss.Loss
        Loss (objective functions) to calculate during training. Must not be
        an empty list. Note: :py:meth:`fit` and :py:meth:`evaluate` currently
        only apply the first loss in the list.
    metrics : EvalMetric or list of EvalMetric
        Metrics for evaluating models. A list is copied, so the caller's list
        is not extended when loss metrics are added to the training metrics.
    initializer : Initializer
        Initializer to initialize the network.
    trainer : Trainer
        Trainer to apply optimizer on network parameters.
    context : Context or list of Context
        Device(s) to run the training on.
    """

    def __init__(self, net,
                 loss,
                 metrics=None,
                 initializer=None,
                 trainer=None,
                 context=None):

        self.net = net
        self.loss = self._check_loss(loss)
        self.train_metrics = self._check_metrics(metrics)

        # the context must be resolved first: _initialize() uses it
        self.context = self._check_context(context)
        self._initialize(initializer)
        self.trainer = self._check_trainer(trainer)

    def _check_loss(self, loss):
        """Return ``loss`` as a non-empty list of gluon Loss objects.

        Raises
        ------
        ValueError
            If ``loss`` is neither a Loss nor a non-empty list of Loss.
        """
        if isinstance(loss, gluon.loss.Loss):
            return [loss]
        # an empty list is rejected here; fit()/evaluate() index loss[0]
        if isinstance(loss, list) and loss and \
                all(isinstance(l, gluon.loss.Loss) for l in loss):
            return loss
        raise ValueError("loss must be a Loss or a list of Loss, "
                         "refer to gluon.loss.Loss:{}".format(loss))

    def _check_metrics(self, metrics):
        """Return ``metrics`` as a new list of EvalMetric objects.

        A fresh list is always built because prepare_loss_and_metrics()
        appends loss wrappers to it; that must not leak into a list owned
        by the caller.

        Raises
        ------
        ValueError
            If any element is not an EvalMetric.
        """
        if isinstance(metrics, EvalMetric):
            metrics = [metrics]
        else:
            metrics = list(metrics or [])
        if not all(isinstance(metric, EvalMetric) for metric in metrics):
            raise ValueError("metrics must be a Metric or a list of Metric, "
                             "refer to mxnet.metric.EvalMetric:{}".format(metrics))
        return metrics

    def _check_context(self, context):
        """Validate the user context, or pick a default one.

        Without a user context, gpu(0) is used when at least one GPU is
        available, otherwise cpu().

        Raises
        ------
        ValueError
            If ``context`` is not a Context or a list of Context.
        AssertionError
            If a requested GPU is not available.
        """
        # infer available context
        gpus = num_gpus()
        available_gpus = [gpu(i) for i in range(gpus)]

        if context:
            # check context values, only accept Context or a list of Context
            if isinstance(context, Context):
                context = [context]
            elif isinstance(context, list) and all(isinstance(c, Context) for c in context):
                context = list(context)
            else:
                raise ValueError("context must be a Context or a list of Context, "
                                 "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], "
                                 "refer to mxnet.Context:{}".format(context))
            # NOTE(review): an assert is kept because callers may catch
            # AssertionError; it is skipped when Python runs with -O.
            for ctx in context:
                assert ctx in available_gpus or str(ctx).startswith('cpu'), \
                    "%s is not available, please make sure " \
                    "your context is in one of: mx.cpu(), %s" % \
                    (ctx, ", ".join([str(device) for device in available_gpus]))
        else:
            # provide default context
            if gpus > 0:
                # only use 1 GPU by default
                if gpus > 1:
                    warnings.warn("You have multiple GPUs, gpu(0) will be used by default. "
                                  "To utilize all your GPUs, specify context as a list of gpus, "
                                  "e.g. context=[mx.gpu(0), mx.gpu(1)] ")
                context = [gpu(0)]
            else:
                context = [cpu()]
        return context

    def _initialize(self, initializer):
        """Initialize the network parameters on ``self.context`` if needed."""
        if not self._is_initialized():
            # net is partially or not initialized,
            # initialize with user specified initializer
            # if initializer is None, default initializer will be used
            # do not re-init layers already initialized
            if initializer:
                self.net.initialize(init=initializer, ctx=self.context)
            else:
                self.net.initialize(ctx=self.context)
        elif initializer:
            # net is fully initialized, and user passed not None initializer
            # do not force reinitialize, give warning
            warnings.warn("Network already fully initialized, skipping initialization. "
                          "You don't need to pass initializer if you already "
                          "initialized your net. "
                          "You can use net.initialize(init=your_initializer, force_reinit=True) "
                          "to force re-initialize.")

    def _check_trainer(self, trainer):
        """Return ``trainer``, or a default SGD trainer (lr=0.001) when it is falsy.

        Raises
        ------
        ValueError
            If ``trainer`` is given but is not a gluon.Trainer.
        """
        if not trainer:
            warnings.warn("No trainer specified, default SGD optimizer "
                          "with learning rate 0.001 is used.")
            trainer = gluon.Trainer(self.net.collect_params(),
                                    'sgd', {'learning_rate': 0.001})
        elif not isinstance(trainer, gluon.Trainer):
            raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
                             "gluon.Trainer:{}".format(trainer))
        return trainer

    def _is_initialized(self):
        """Return True only if every parameter of the net reports a context."""
        param_dict = self.net.collect_params()
        for param in param_dict:
            try:
                param_dict[param].list_ctx()
            except RuntimeError:
                return False
        return True

    def _get_data_and_label(self, batch, ctx, batch_axis=0):
        """Split ``batch[0]`` (data) and ``batch[1]`` (label) across ``ctx``."""
        data = batch[0]
        label = batch[1]
        data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
        return data, label

    def prepare_loss_and_metrics(self):
        """
        Based on loss functions and training metrics in estimator
        Create metric wrappers to record loss values,
        Create copies of train loss/metric objects to record validation values
        Returns train_metrics and val_metrics

        The work is done only on the first call; later calls return the same
        metric objects. Training metric names get a "train " prefix and
        their validation copies a "validation " prefix.
        """
        if any(not hasattr(self, attribute) for attribute in
               ['train_metrics', 'val_metrics']):
            # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
            if not self.train_metrics and \
                    any(isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss):
                self.train_metrics = [Accuracy()]
            self.val_metrics = []
            for loss in self.loss:
                # remove trailing numbers from loss name to avoid confusion
                self.train_metrics.append(Loss(loss.name.rstrip('1234567890')))
            for metric in self.train_metrics:
                val_metric = copy.deepcopy(metric)
                metric.name = "train " + metric.name
                val_metric.name = "validation " + val_metric.name
                self.val_metrics.append(val_metric)
        return self.train_metrics, self.val_metrics

    def evaluate(self,
                 val_data,
                 val_metrics,
                 batch_axis=0):
        """Evaluate model on validation data

        Parameters
        ----------
        val_data : DataLoader
            Validation data loader with data and labels.
        val_metrics : EvalMetric or list of EvalMetrics
            Metrics to update validation result. They are reset first.
        batch_axis : int, default 0
            Batch axis to split the validation data into devices.

        Raises
        ------
        ValueError
            If ``val_data`` is not a gluon DataLoader.
        """
        if not isinstance(val_data, gluon.data.DataLoader):
            raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                             "can transform your DataIter or any NDArray into Gluon DataLoader. "
                             "Refer to gluon.data.dataloader")

        for metric in val_metrics:
            metric.reset()

        for batch in val_data:
            data, label = self._get_data_and_label(batch, self.context, batch_axis)
            pred = [self.net(x) for x in data]
            # only the first loss function is applied
            loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
            # update metrics: Loss wrappers record loss values, others compare
            # predictions against labels
            for metric in val_metrics:
                if isinstance(metric, Loss):
                    metric.update(0, loss)
                else:
                    metric.update(label, pred)

    def fit(self, train_data,
            val_data=None,
            epochs=None,
            event_handlers=None,
            batches=None,
            batch_axis=0):
        """Trains the model with a given :py:class:`DataLoader` for a specified
        number of epochs or batches. The batch size of each step is read from
        the data's size along ``batch_axis``.

        Parameters
        ----------
        train_data : DataLoader
            Training data loader with data and labels.
        val_data : DataLoader, default None
            Validation data loader with data and labels.
        epochs : int, default None
            Number of epochs to iterate on the training data.
            You can only specify one and only one type of iteration(epochs or batches).
        event_handlers : EventHandler or list of EventHandler
            List of :py:class:`EventHandlers` to apply during training. A single
            handler is also accepted; the given list is not modified.
        batches : int, default None
            Number of batches to iterate on the training data.
            You can only specify one and only one type of iteration(epochs or batches).
        batch_axis : int, default 0
            Batch axis to split the training data into devices.

        Raises
        ------
        ValueError
            If ``train_data`` is not a DataLoader, or if not exactly one of
            ``epochs`` / ``batches`` is given.
        """
        if not isinstance(train_data, gluon.data.DataLoader):
            raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                             "can transform your DataIter or any NDArray into Gluon DataLoader. "
                             "Refer to gluon.data.dataloader")

        # must specify one and only one of epochs or batches
        if (not epochs) == (not batches):
            raise ValueError(
                "Fit only support exactly one type of iteration, "
                "train by number of epochs or number of batches. "
                "Please specify one and only one of: epochs or batches.")

        self.max_epoch = epochs
        self.max_batch = batches

        # provide default handlers
        event_handlers = self._prepare_default_handlers(val_data, event_handlers)

        train_begin, epoch_begin, batch_begin, \
            batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)

        # pass a reference to all event handlers
        estimator_ref = self
        # training begin
        for handler in train_begin:
            handler.train_begin(estimator_ref)

        while True:
            # epoch begin
            for handler in epoch_begin:
                handler.epoch_begin(estimator_ref)

            for batch in train_data:
                data, label = self._get_data_and_label(batch, self.context, batch_axis)

                # samples in this batch, counted along the batch axis so that
                # trainer.step() normalizes gradients correctly for any axis
                batch_size = batch[0].shape[batch_axis]

                # batch begin
                for handler in batch_begin:
                    handler.batch_begin(estimator_ref, batch=batch)

                with autograd.record():
                    pred = [self.net(x) for x in data]
                    # only the first loss function is applied
                    loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]

                for l in loss:
                    l.backward()

                self.trainer.step(batch_size)

                # batch end
                batch_end_result = []
                for handler in batch_end:
                    batch_end_result.append(handler.batch_end(estimator_ref, batch=batch,
                                                              pred=pred, label=label, loss=loss))
                # if any handler signaled to stop
                if any(batch_end_result):
                    break

            # epoch end
            epoch_end_result = []
            for handler in epoch_end:
                epoch_end_result.append(handler.epoch_end(estimator_ref))
            # if any handler signaled to stop
            if any(epoch_end_result):
                break

        # train end
        for handler in train_end:
            handler.train_end(estimator_ref)

    def _prepare_default_handlers(self, val_data, event_handlers):
        """Build the priority-sorted handler list for one fit() run.

        Always adds a StoppingHandler. Adds a MetricHandler, a
        ValidationHandler (only when ``val_data`` is given) and a
        LoggingHandler when the user did not supply one of that type. When
        any default is added, every metric referenced by a handler must come
        from :py:meth:`prepare_loss_and_metrics`.

        Raises
        ------
        ValueError
            If default handlers are used together with a handler that
            references a metric not prepared by the estimator.
        """
        # Work on a private list, and accept a single handler as well. The
        # defaults appended below must not accumulate in a list the caller
        # may pass again to a later fit() call.
        if event_handlers is None:
            event_handlers = []
        elif isinstance(event_handlers, (list, tuple)):
            event_handlers = list(event_handlers)
        else:
            event_handlers = [event_handlers]
        default_handlers = []
        train_metrics, val_metrics = self.prepare_loss_and_metrics()

        # no need to add to default handler check as StoppingHandler does not use metrics
        event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

        if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
            event_handlers.append(MetricHandler(train_metrics=train_metrics))
            default_handlers.append("MetricHandler")

        if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
            event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
                                                    val_metrics=val_metrics))
            default_handlers.append("ValidationHandler")

        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
                                                 val_metrics=val_metrics))
            default_handlers.append("LoggingHandler")

        # if there is a mix of user defined event handlers and default event handlers
        # they should have the same set of loss and metrics
        if default_handlers:
            msg = "You are training with the following default event handlers: %s. " \
                  "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
                  "Please use the same set of metrics for all your other handlers." % \
                  ", ".join(default_handlers)
            warnings.warn(msg)
            # collect the values of every handler attribute whose name
            # mentions 'metric' or 'monitor' (lists are flattened)
            references = []
            for handler in event_handlers:
                for attribute in dir(handler):
                    if any(keyword in attribute for keyword in ('metric', 'monitor')):
                        reference = getattr(handler, attribute)
                        if isinstance(reference, list):
                            references.extend(reference)
                        else:
                            references.append(reference)
            # only metric objects are compared; None and any non-metric value
            # stored under a matching attribute name are ignored
            known_metrics = train_metrics + val_metrics
            for metric in set(ref for ref in references if isinstance(ref, EvalMetric)):
                if metric not in known_metrics:
                    msg = "We have added following default handlers for you: %s and used " \
                          "estimator.prepare_loss_and_metrics() to pass metrics to " \
                          "those handlers. Please use the same set of metrics " \
                          "for all your handlers." % \
                          ", ".join(default_handlers)
                    raise ValueError(msg)

        # stable sort: handlers with equal priority keep their insertion order
        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
        return event_handlers

    def _categorize_handlers(self, event_handlers):
        """
        categorize handlers into 6 event lists to avoid calling empty methods
        for example, only event handlers with train_begin method
        implemented will be called at train begin
        """

        train_begin = []
        epoch_begin = []
        batch_begin = []
        batch_end = []
        epoch_end = []
        train_end = []
        for handler in event_handlers:
            if isinstance(handler, TrainBegin):
                train_begin.append(handler)
            if isinstance(handler, EpochBegin):
                epoch_begin.append(handler)
            if isinstance(handler, BatchBegin):
                batch_begin.append(handler)
            if isinstance(handler, BatchEnd):
                batch_end.append(handler)
            if isinstance(handler, EpochEnd):
                epoch_end.append(handler)
            if isinstance(handler, TrainEnd):
                train_end.append(handler)
        return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
new file mode 100644
index 0000000..ce5890e
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -0,0 +1,705 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-argument
+"""Gluon EventHandlers for Estimators"""
+
+import logging
+import os
+import time
+import warnings
+
+import numpy as np
+
+from ....metric import EvalMetric, Loss
+
+
class TrainBegin(object):
    """Base for event handlers that react when training starts.

    :py:class:`Estimator` collects handlers deriving from this class and
    calls their ``train_begin`` once, before the first epoch.
    """

    def train_begin(self, estimator, *args, **kwargs):
        """Hook run at the start of training; the default does nothing."""
        return None
+
+
class TrainEnd(object):
    """Base for event handlers that react when training finishes.

    :py:class:`Estimator` calls ``train_end`` on handlers deriving from this
    class once, after the training loop has stopped.
    """

    def train_end(self, estimator, *args, **kwargs):
        """Hook run at the end of training; the default does nothing."""
        return None
+
+
class EpochBegin(object):
    """Base for event handlers that react at the start of every epoch."""

    def epoch_begin(self, estimator, *args, **kwargs):
        """Hook run before each epoch; the default does nothing."""
        return None
+
+
class EpochEnd(object):
    """Base for event handlers that react at the end of every epoch.

    A truthy return value from ``epoch_end`` asks :py:class:`Estimator` to
    stop training after the current epoch.
    """

    def epoch_end(self, estimator, *args, **kwargs):
        """Hook run after each epoch; the default never requests a stop."""
        should_stop = False
        return should_stop
+
+
class BatchBegin(object):
    """Base for event handlers that react before every training batch."""

    def batch_begin(self, estimator, *args, **kwargs):
        """Hook run before each batch; the default does nothing."""
        return None
+
+
class BatchEnd(object):
    """Base for event handlers that react after every training batch.

    A truthy return value from ``batch_end`` makes :py:class:`Estimator`
    leave the current epoch's batch loop; epoch-end handlers still run.
    """

    def batch_end(self, estimator, *args, **kwargs):
        """Hook run after each batch; the default never requests a stop."""
        should_stop = False
        return should_stop
+
+
class StoppingHandler(TrainBegin, BatchEnd, EpochEnd):
    """Stop conditions to stop training
    Stop training if maximum number of batches or epochs
    reached.

    Parameters
    ----------
    max_epoch : int, default None
        Number of maximum epochs to train.
    max_batch : int, default None
        Number of maximum batches to train.

    Notes
    -----
    At train begin, the limits are re-read from the estimator's
    ``max_epoch`` / ``max_batch``, and the counters and the stop flag are
    reset. This means one instance can be reused across ``fit`` calls.
    """

    def __init__(self, max_epoch=None, max_batch=None):
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.current_batch = 0
        self.current_epoch = 0
        self.stop_training = False

    def train_begin(self, estimator, *args, **kwargs):
        self.max_epoch = estimator.max_epoch
        self.max_batch = estimator.max_batch
        self.current_batch = 0
        self.current_epoch = 0
        # clear the flag left over from a previous run; otherwise a reused
        # handler would request a stop right after the first batch
        self.stop_training = False

    def batch_end(self, estimator, *args, **kwargs):
        """Count one batch; return True once ``max_batch`` batches are done."""
        self.current_batch += 1
        if self.current_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, estimator, *args, **kwargs):
        """Count one epoch; return True once ``max_epoch`` epochs are done."""
        self.current_epoch += 1
        if self.current_epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training
+
+
class MetricHandler(EpochBegin, BatchEnd):
    """Metric Handler that updates metric values at batch end

    :py:class:`MetricHandler` takes model predictions and true labels
    and updates the metrics; metric wrappers for loss (:py:class:`Loss`
    instances) are updated with the loss values instead.
    Validation loss and metrics will be handled by :py:class:`ValidationHandler`

    Parameters
    ----------
    train_metrics : List of EvalMetrics
        Training metrics to be updated at batch end. They are reset at the
        beginning of every epoch.
    """

    def __init__(self, train_metrics):
        self.train_metrics = train_metrics or []
        # order to be called among all callbacks
        # metrics need to be calculated before other callbacks can access them
        # (np.inf instead of the np.Inf alias, which NumPy 2.0 removed)
        self.priority = -np.inf

    def epoch_begin(self, estimator, *args, **kwargs):
        for metric in self.train_metrics:
            metric.reset()

    def batch_end(self, estimator, *args, **kwargs):
        # expects the keyword arguments Estimator.fit passes at batch end
        pred = kwargs['pred']
        label = kwargs['label']
        loss = kwargs['loss']
        for metric in self.train_metrics:
            if isinstance(metric, Loss):
                # metric wrapper for loss values
                metric.update(0, loss)
            else:
                metric.update(label, pred)
+
+
class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
    """Validation Handler that evaluates model on validation dataset

    :py:class:`ValidationHandler` takes validation dataset, an evaluation function,
    metrics to be evaluated, and how often to run the validation. You can provide custom
    evaluation function or use the one provided by :py:class:`Estimator`

    Parameters
    ----------
    val_data : DataLoader
        Validation data set to run evaluation.
    eval_fn : function
        A function defines how to run evaluation and
        calculate loss and metrics. It is called as
        ``eval_fn(val_data=..., val_metrics=...)``.
    val_metrics : List of EvalMetrics
        Validation metrics to be updated.
    epoch_period : int, default 1
        How often to run validation at epoch end, by default
        :py:class:`ValidationHandler` validate every epoch.
    batch_period : int, default None
        How often to run validation at batch end, by default
        :py:class:`ValidationHandler` does not validate at batch end.
    """

    def __init__(self,
                 val_data,
                 eval_fn,
                 val_metrics=None,
                 epoch_period=1,
                 batch_period=None):
        self.val_data = val_data
        self.eval_fn = eval_fn
        self.epoch_period = epoch_period
        self.batch_period = batch_period
        self.val_metrics = val_metrics
        self.current_batch = 0
        self.current_epoch = 0
        # order to be called among all callbacks
        # validation metrics need to be calculated before other callbacks can access them
        # (np.inf instead of the np.Inf alias, which NumPy 2.0 removed)
        self.priority = -np.inf
        self.logger = logging.getLogger(__name__)

    def train_begin(self, estimator, *args, **kwargs):
        # reset epoch and batch counter
        self.current_batch = 0
        self.current_epoch = 0

    def batch_end(self, estimator, *args, **kwargs):
        self.current_batch += 1
        if self.batch_period and self.current_batch % self.batch_period == 0:
            self.eval_fn(val_data=self.val_data,
                         val_metrics=self.val_metrics)
            msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \
                  % (self.current_epoch, self.current_batch)
            for monitor in self.val_metrics or []:
                name, value = monitor.get()
                msg += '%s: %.4f, ' % (name, value)
            # strip the trailing ", " separator (comma AND space)
            self.logger.info(msg.rstrip(', '))

    def epoch_end(self, estimator, *args, **kwargs):
        self.current_epoch += 1
        if self.epoch_period and self.current_epoch % self.epoch_period == 0:
            self.eval_fn(val_data=self.val_data,
                         val_metrics=self.val_metrics)
+
+
class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
    """Basic Logging Handler that applies to every Gluon estimator by default.

    :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
    and other useful information during training

    Parameters
    ----------
    file_name : str
        File name to save the logs.
    file_location : str
        File location to save the logs.
    filemode : str, default 'a'
        Logging file mode, default using append mode.
    verbose : int, default LOG_PER_EPOCH
        Limit the granularity of metrics displayed during training process.
        verbose=LOG_PER_EPOCH: display metrics every epoch
        verbose=LOG_PER_BATCH: display metrics every batch
    train_metrics : list of EvalMetrics
        Training metrics to be logged, logged at batch end, epoch end, train end.
    val_metrics : list of EvalMetrics
        Validation metrics to be logged, logged at epoch end, train end.

    Notes
    -----
    Records are written to this module's logger. The stream handler (and the
    file handler, when ``file_name`` or ``file_location`` is given) that this
    instance attaches are detached and closed at train end, and re-attached at
    the next train begin. Handlers attached by other code are left untouched.
    """

    LOG_PER_EPOCH = 1
    LOG_PER_BATCH = 2

    def __init__(self, file_name=None,
                 file_location=None,
                 filemode='a',
                 verbose=LOG_PER_EPOCH,
                 train_metrics=None,
                 val_metrics=None):
        super(LoggingHandler, self).__init__()
        # validate first, so that a bad argument does not leave a handler
        # (or an open log file) attached to the shared module logger
        if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
            raise ValueError("verbose level must be either LOG_PER_EPOCH or "
                             "LOG_PER_BATCH, received %s. "
                             "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
                             % verbose)
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # remember exactly which handlers this instance owns, so that
        # train_end detaches these and nothing else
        self._log_handlers = [logging.StreamHandler()]
        # save logger to file only if file name or location is specified
        if file_name or file_location:
            file_name = file_name or 'estimator_log'
            file_location = file_location or './'
            self._log_handlers.append(
                logging.FileHandler(os.path.join(file_location, file_name), mode=filemode))
        for log_handler in self._log_handlers:
            self.logger.addHandler(log_handler)
        self.verbose = verbose
        self.train_metrics = train_metrics or []
        self.val_metrics = val_metrics or []
        self.batch_index = 0
        self.current_epoch = 0
        self.processed_samples = 0
        # logging handler need to be called at last to make sure all states are updated
        # it will also shut down logging at train end
        # (np.inf instead of the np.Inf alias, which NumPy 2.0 removed)
        self.priority = np.inf

    def train_begin(self, estimator, *args, **kwargs):
        # re-attach handlers detached by a previous train_end, so a reused
        # instance keeps logging; no-op on the first run
        for log_handler in self._log_handlers:
            if log_handler not in self.logger.handlers:
                self.logger.addHandler(log_handler)
        self.train_start = time.time()
        trainer = estimator.trainer
        optimizer = trainer.optimizer.__class__.__name__
        lr = trainer.learning_rate
        self.logger.info("Training begin: using optimizer %s "
                         "with current learning rate %.4f ",
                         optimizer, lr)
        if estimator.max_epoch:
            self.logger.info("Train for %d epochs.", estimator.max_epoch)
        else:
            self.logger.info("Train for %d batches.", estimator.max_batch)
        # reset all counters
        self.current_epoch = 0
        self.batch_index = 0
        self.processed_samples = 0

    def train_end(self, estimator, *args, **kwargs):
        train_time = time.time() - self.train_start
        msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
        # log every result in train stats including train/validation loss & metrics
        for metric in self.train_metrics + self.val_metrics:
            name, value = metric.get()
            msg += '%s: %.4f, ' % (name, value)
        self.logger.info(msg.rstrip(', '))
        # Detach and close only this instance's handlers. A process-wide
        # logging.shutdown() is deliberately not called: it closes every
        # handler in the application and is meant for interpreter exit.
        for log_handler in self._log_handlers:
            self.logger.removeHandler(log_handler)
            log_handler.close()

    def batch_begin(self, estimator, *args, **kwargs):
        if self.verbose == self.LOG_PER_BATCH:
            self.batch_start = time.time()

    def batch_end(self, estimator, *args, **kwargs):
        if self.verbose == self.LOG_PER_BATCH:
            batch_time = time.time() - self.batch_start
            msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
            # NOTE(review): counts samples along axis 0 of the data; this
            # handler does not know the estimator's batch_axis
            self.processed_samples += kwargs['batch'][0].shape[0]
            msg += '[Samples %s] ' % (self.processed_samples)
            msg += 'time/batch: %.3fs ' % batch_time
            for metric in self.train_metrics:
                # only log current training loss & metric after each batch
                name, value = metric.get()
                msg += '%s: %.4f, ' % (name, value)
            self.logger.info(msg.rstrip(', '))
        self.batch_index += 1

    def epoch_begin(self, estimator, *args, **kwargs):
        if self.verbose >= self.LOG_PER_EPOCH:
            self.epoch_start = time.time()
            self.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
                             self.current_epoch, estimator.trainer.learning_rate)

    def epoch_end(self, estimator, *args, **kwargs):
        if self.verbose >= self.LOG_PER_EPOCH:
            epoch_time = time.time() - self.epoch_start
            msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
            for monitor in self.train_metrics + self.val_metrics:
                name, value = monitor.get()
                msg += '%s: %.4f, ' % (name, value)
            self.logger.info(msg.rstrip(', '))
        self.current_epoch += 1
        self.batch_index = 0
+
+
class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd):
    """Save the model after a user-defined period

    :py:class:`CheckpointHandler` saves the network architecture after the first batch if the
    model can be fully hybridized, and saves model parameters and trainer states after a
    user-defined period (by default every epoch).

    Checkpoints are written as ``{model_prefix}-epoch{E}batch{B}.params`` and ``.states``;
    with ``save_best=True`` the best model is written as ``{model_prefix}-best.params`` and
    ``.states``.

    Parameters
    ----------
    model_dir : str
        File directory to save all the model related files including model architecture,
        model parameters, and trainer states.
    model_prefix : str default 'model'
        Prefix to add for all checkpoint file names.
    monitor: EvalMetric, default None
        The metrics to monitor and determine if model has improved
    verbose: int, default 0
        Verbosity mode, 1 means inform user every time a checkpoint is saved
    save_best: bool, default False
        If True, monitor must not be None, :py:class:`CheckpointHandler` will save the
        model parameters and trainer states with the best monitored value.
    mode: str, default 'auto'
        One of {auto, min, max}, if `save_best=True`, the comparison to make
        and determine if the monitored value has improved. In 'auto' mode,
        max is used when the monitored metric name contains 'acc' or 'f1',
        and min otherwise.
    epoch_period: int, default 1
        Epoch intervals between saving the network. By default, checkpoints are
        saved every epoch.
    batch_period: int, default None
        Batch intervals between saving the network.
        By default, checkpoints are not saved based on the number of batches.
    max_checkpoints : int, default 5
        Maximum number of checkpoint files to keep in the model_dir, older checkpoints
        will be removed. Best checkpoint file is not counted. When resuming, checkpoints
        already present in model_dir are counted as well.
    resume_from_checkpoint : bool, default False
        Whether to resume training from checkpoint in model_dir. If True and checkpoints
        found, :py:class:`CheckpointHandler` will load net parameters and trainer states,
        and train the remaining of epochs and batches.
    """

    def __init__(self,
                 model_dir,
                 model_prefix='model',
                 monitor=None,
                 verbose=0,
                 save_best=False,
                 mode='auto',
                 epoch_period=1,
                 batch_period=None,
                 max_checkpoints=5,
                 resume_from_checkpoint=False):
        self.monitor = monitor
        self.verbose = verbose
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        self.model_dir = model_dir
        self.model_prefix = model_prefix
        self.save_best = save_best
        if self.save_best and not isinstance(self.monitor, EvalMetric):
            raise ValueError("To save best model only, please provide one of the metric objects as monitor, "
                             "You can get these objects using estimator.prepare_loss_and_metric()")
        self.epoch_period = epoch_period
        self.batch_period = batch_period
        self.current_batch = 0
        self.current_epoch = 0
        self.max_checkpoints = max_checkpoints
        self.resume_from_checkpoint = resume_from_checkpoint
        # prefixes of rotating (non-best) checkpoints, oldest first
        self.saved_checkpoints = []
        self.logger = logging.getLogger(__name__)
        if self.save_best:
            if mode not in ['auto', 'min', 'max']:
                warnings.warn('CheckpointHandler mode %s is unknown, '
                              'fallback to auto mode. CheckpointHandler will use '
                              'max mode for f1 and accuracy metric comparison and '
                              'use min mode otherwise' % (mode),
                              RuntimeWarning)
                mode = 'auto'

            if mode == 'min':
                self.monitor_op = np.less
                self.best = np.inf
            elif mode == 'max':
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                # use greater for accuracy and f1 and less otherwise
                # (each substring must be tested separately: `'acc' or ...` is always true)
                monitor_name = self.monitor.get()[0]
                lowered = monitor_name.lower()
                if 'acc' in lowered or 'f1' in lowered:
                    self.logger.info("`greater` operator will be used to determine "
                                     "if %s has improved, please use `min` for mode "
                                     "if you want otherwise", monitor_name)
                    self.monitor_op = np.greater
                else:
                    self.logger.info("`less` operator will be used to determine "
                                     "if %s has improved, please use `max` for mode "
                                     "if you want otherwise", monitor_name)
                    self.monitor_op = np.less

    def train_begin(self, estimator, *args, **kwargs):
        # reset all counters
        self.current_epoch = 0
        self.current_batch = 0
        if self.save_best:
            self.best = np.inf if self.monitor_op == np.less else -np.inf
        if self.resume_from_checkpoint:
            error_msg = "To use resume from checkpoint, you must only specify " \
                        "the same type of period you used for training. " \
                        "For example, if you are training based on number of epochs, " \
                        "you must save only based on epochs, and set batch_period to None."
            if estimator.max_batch:
                assert self.batch_period, error_msg
                assert not self.epoch_period, error_msg
            if estimator.max_epoch:
                assert self.epoch_period, error_msg
                assert not self.batch_period, error_msg

            self._resume_from_checkpoint(estimator)

    def batch_end(self, estimator, *args, **kwargs):
        # only save symbol once after first batch
        if self.current_batch == 0:
            self._save_symbol(estimator)
        if self.batch_period and (self.current_batch + 1) % self.batch_period == 0:
            self._save_checkpoint(estimator)
        self.current_batch += 1

    def epoch_end(self, estimator, *args, **kwargs):
        if self.epoch_period and (self.current_epoch + 1) % self.epoch_period == 0:
            self._save_checkpoint(estimator)
        self.current_epoch += 1

    def _save_checkpoint(self, estimator):
        """Save a numbered checkpoint and, if enabled, update the best model."""
        # Offsets from a restored checkpoint apply only if one was actually loaded;
        # trained_epoch == -1 means none was found, so numbering must match a fresh run.
        if self.resume_from_checkpoint and self.trained_epoch >= 0:
            save_epoch_number = self.current_epoch + self.trained_epoch + 1
            if estimator.max_epoch:
                # checkpoint saved at epoch end, batch number already incremented
                save_batch_number = self.current_batch + self.trained_batch
            else:
                save_batch_number = self.current_batch + self.trained_batch + 1
        else:
            save_epoch_number = self.current_epoch
            save_batch_number = self.current_batch
        prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number)
        self._save_params_and_trainer(estimator, prefix)
        if self.verbose > 0:
            self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, '
                             'saving model at %s with prefix: %s',
                             self.current_epoch, self.current_batch + 1, self.model_dir, prefix)

        if self.save_best:
            monitor_name, monitor_value = self.monitor.get()
            # check if monitor exists in train stats
            if np.isnan(monitor_value):
                warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you '
                                             'pass one of the metric objects as monitor, '
                                             'you can use estimator.prepare_loss_and_metrics to '
                                             'create all metric objects' % monitor_name))
            elif self.monitor_op(monitor_value, self.best):
                prefix = self.model_prefix + '-best'
                self._save_params_and_trainer(estimator, prefix)
                if self.verbose > 0:
                    # log before overwriting self.best so the old value is reported
                    self.logger.info('[Epoch %d] CheckpointHandler: '
                                     '%s improved from %0.5f to %0.5f, '
                                     'updating best model at %s with prefix: %s',
                                     self.current_epoch, monitor_name,
                                     self.best, monitor_value, self.model_dir, prefix)
                self.best = monitor_value
            elif self.verbose > 0:
                self.logger.info('[Epoch %d] CheckpointHandler: '
                                 '%s did not improve from %0.5f, '
                                 'skipping updating best model',
                                 self.current_epoch, monitor_name,
                                 self.best)

    def _save_symbol(self, estimator):
        """Save the cached symbolic graph, which exists only for hybridized networks."""
        symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json')
        if hasattr(estimator.net, '_cached_graph'):
            sym = estimator.net._cached_graph[1]
            sym.save(symbol_file)
        else:
            self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
                             "to construct your model, can call net.hybridize() before passing to "
                             "Estimator in order to save model architecture as %s.", symbol_file)

    def _save_params_and_trainer(self, estimator, file_prefix):
        """Write ``file_prefix``.params/.states and rotate out the oldest checkpoints."""
        param_file = os.path.join(self.model_dir, file_prefix + '.params')
        trainer_file = os.path.join(self.model_dir, file_prefix + '.states')
        estimator.net.save_parameters(param_file)
        estimator.trainer.save_states(trainer_file)

        # only count checkpoints with epoch or batch number in file name; compare the
        # exact best prefix so a model_prefix containing 'best' is still counted
        if file_prefix != self.model_prefix + '-best':
            self.saved_checkpoints.append(file_prefix)
        # remove old checkpoints when max number of checkpoints reached
        while len(self.saved_checkpoints) > max(self.max_checkpoints, 0):
            prefix = self.saved_checkpoints.pop(0)
            # delete exactly this checkpoint's files: a startswith(prefix) match would
            # also delete e.g. '...batch50' files when removing '...batch5'
            for suffix in ('.params', '.states'):
                fname = os.path.join(self.model_dir, prefix + suffix)
                if os.path.exists(fname):
                    os.remove(fname)

    def _resume_from_checkpoint(self, estimator):
        """Load the newest checkpoint in model_dir and shorten the remaining training.

        Sets ``self.trained_epoch`` and ``self.trained_batch`` (both -1 when no checkpoint
        is found) and reduces ``estimator.max_epoch`` / ``estimator.max_batch`` accordingly.

        Raises
        ------
        ValueError
            If the found checkpoint already reached the requested number of epochs/batches,
            or a checkpoint file name cannot be parsed.
        """
        prefix = self.model_prefix + '-epoch'
        existing_checkpoints = []
        self.trained_epoch = self._find_max_iteration(
            dir=self.model_dir,
            prefix=prefix,
            start='epoch',
            end='batch',
            saved_checkpoints=existing_checkpoints)
        # Track the checkpoints already on disk, oldest first, so max_checkpoints rotation
        # covers them too. Rebuilt on every resume to avoid duplicates across fit() calls.
        ordered = []
        for checkpoint in existing_checkpoints:
            order = self._checkpoint_order(checkpoint)
            if order is not None:
                ordered.append((order, checkpoint))
        ordered.sort()
        self.saved_checkpoints = [checkpoint for _, checkpoint in ordered]
        # the trailing 'batch' keeps e.g. epoch 1 from also matching epoch 10..19 files
        prefix += str(self.trained_epoch) + 'batch'
        self.trained_batch = self._find_max_iteration(
            dir=self.model_dir,
            prefix=prefix,
            start='batch',
            end='.params')

        if self.trained_epoch == -1:
            msg = "CheckpointHandler: No checkpoint found, training from scratch for "
            if estimator.max_batch:
                msg += "%d batches" % estimator.max_batch
            else:
                msg += "%d epochs" % estimator.max_epoch
            self.logger.info(msg)
        else:
            msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \
                  "continue to train for " % (self.trained_epoch, self.trained_batch)
            # change maximum number of epoch or batch to train if resumed from epoch checkpoint
            if estimator.max_epoch:
                if self.trained_epoch >= estimator.max_epoch - 1:
                    raise ValueError("Found checkpoint with maximum number of epoch %d reached, please specify "
                                     "resume_from_checkpoint=False (default value) if you want to train from scratch."
                                     % estimator.max_epoch)
                estimator.max_epoch = estimator.max_epoch - self.trained_epoch - 1
                msg += "%d epochs " % estimator.max_epoch
            if estimator.max_batch:
                if self.trained_batch >= estimator.max_batch - 1:
                    raise ValueError("Found checkpoint with maximum number of batch %d reached, please specify "
                                     "resume_from_checkpoint=False (default value) if you want to train from scratch."
                                     % estimator.max_batch)
                estimator.max_batch = estimator.max_batch - self.trained_batch - 1
                msg += "%d batches " % estimator.max_batch
            # load checkpoint
            param_file = "%s-epoch%dbatch%d.params" % (self.model_prefix, self.trained_epoch, self.trained_batch)
            param_file = os.path.join(self.model_dir, param_file)
            trainer_file = "%s-epoch%dbatch%d.states" % (self.model_prefix, self.trained_epoch, self.trained_batch)
            trainer_file = os.path.join(self.model_dir, trainer_file)
            assert os.path.exists(param_file), "Failed to load checkpoint, %s does not exist" % param_file
            assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file
            estimator.net.load_parameters(param_file, ctx=estimator.context)
            estimator.trainer.load_states(trainer_file)
            self.logger.warning(msg)

    def _checkpoint_order(self, checkpoint_prefix):
        """Return the (epoch, batch) pair encoded in a '{model_prefix}-epoch{E}batch{B}' prefix.

        Returns None when the numbers cannot be parsed.
        """
        numbers = checkpoint_prefix[len(self.model_prefix) + len('-epoch'):]
        epoch, _, batch = numbers.partition('batch')
        try:
            return int(epoch), int(batch)
        except ValueError:
            return None

    def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None):
        """Return the largest integer found between `start` and `end` in checkpoint names.

        Only ``.params`` files in `dir` whose names begin with `prefix` are considered.
        Returns -1 if there are none.

        Parameters
        ----------
        dir : str
            Directory to scan.
        prefix : str
            Required file-name prefix.
        start, end : str
            Markers around the number to parse, searched after ``self.model_prefix``.
        saved_checkpoints : list, optional
            If given, the name (without ``.params``) of every matching file is appended.

        Raises
        ------
        ValueError
            If a matching file name does not hold an integer between `start` and `end`.
        """
        error_msg = "Error parsing checkpoint file, please check your " \
                    "checkpoints have the format: " \
                    "{model_name}-epoch{epoch_number}batch{batch_number}.params, " \
                    "there should also be a .states file for each .params file "
        max_iter = -1
        for fname in os.listdir(dir):
            if not (fname.startswith(prefix) and fname.endswith('.params')):
                continue
            # `is not None`: an empty list is falsy but must still be filled
            if saved_checkpoints is not None:
                # save prefix of existing checkpoints
                saved_checkpoints.append(fname[:-len('.params')])
            # search after model_prefix so a prefix such as 'batchnet' cannot confuse parsing
            begin = fname.find(start, len(self.model_prefix))
            stop = fname.find(end, begin + len(start)) if begin != -1 else -1
            if stop == -1:
                raise ValueError(error_msg)
            try:
                iteration = int(fname[begin + len(start):stop])
            except ValueError:
                raise ValueError(error_msg)
            max_iter = max(max_iter, iteration)
        return max_iter
+
+
class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd):
    """Early stop training if monitored value is not improving

    Parameters
    ----------
    monitor: EvalMetric
        The metric to monitor, and stop training if this metric does not improve.
    min_delta: float, default 0
        Minimal change in monitored value to be considered as an improvement.
    patience: int, default 0
        Number of epochs to wait for improvement before terminate training.
    mode: str, default 'auto'
        One of {auto, min, max}: the comparison used to determine if the
        monitored value has improved. In 'auto' mode, max is used when the
        monitored metric name contains 'acc' or 'f1', and min otherwise.
    baseline: float
        Baseline value to compare the monitored value with.
    """

    def __init__(self,
                 monitor,
                 min_delta=0,
                 patience=0,
                 mode='auto',
                 baseline=None):
        super(EarlyStoppingHandler, self).__init__()

        if not isinstance(monitor, EvalMetric):
            raise ValueError("Please provide one of the metric objects as monitor, "
                             "You can create these objects using estimator.prepare_loss_and_metric()")
        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0
        self.current_epoch = 0
        self.stop_training = False
        self.logger = logging.getLogger(__name__)

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'fallback to auto mode. EarlyStoppingHandler will use '
                          'max mode for f1 and accuracy metric comparison and '
                          'use min mode otherwise' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # each substring must be tested separately: `'acc' or ...` is always true
            monitor_name = self.monitor.get()[0]
            lowered = monitor_name.lower()
            if 'acc' in lowered or 'f1' in lowered:
                self.logger.info("`greater` operator is used to determine "
                                 "if %s has improved, please use `min` for mode "
                                 "if you want otherwise", monitor_name)
                self.monitor_op = np.greater
            else:
                self.logger.info("`less` operator is used to determine "
                                 "if %s has improved, please use `max` for mode "
                                 "if you want otherwise", monitor_name)
                self.monitor_op = np.less

        # For 'min' comparisons an improvement is a drop of at least min_delta,
        # so the delta is applied with the opposite sign.
        if self.monitor_op == np.less:
            self.min_delta *= -1

    def train_begin(self, estimator, *args, **kwargs):
        self.wait = 0
        self.stopped_epoch = 0
        self.current_epoch = 0
        self.stop_training = False
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.inf if self.monitor_op == np.less else -np.inf

    def epoch_end(self, estimator, *args, **kwargs):
        """Update the patience counter; return True when training should stop."""
        monitor_name, monitor_value = self.monitor.get()
        if np.isnan(monitor_value):
            warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects '
                                         'as monitor, you can use estimator.prepare_loss_and_metrics to '
                                         'create all metric objects' % monitor_name))
        elif self.monitor_op(monitor_value - self.min_delta, self.best):
            self.best = monitor_value
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = self.current_epoch
                self.stop_training = True
        self.current_epoch += 1
        return self.stop_training

    def train_end(self, estimator, *args, **kwargs):
        # Check stop_training rather than stopped_epoch > 0 so that a stop
        # triggered in the very first epoch (epoch 0) is reported as well.
        if self.stop_training:
            self.logger.info('[Epoch %d] EarlyStoppingHandler: early stopping due to %s not improving',
                             self.stopped_epoch, self.monitor.get()[0])
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 6935c27..0939490 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -255,6 +255,13 @@ class Trainer(object):
else:
return self._optimizer.learning_rate
@property
def optimizer(self):
    """The optimizer this Trainer drives.

    Raises
    ------
    UserWarning
        If ``self._optimizer`` is not (yet) an ``opt.Optimizer`` instance.
    """
    if not isinstance(self._optimizer, opt.Optimizer):
        raise UserWarning("Optimizer has not been initialized yet")
    return self._optimizer
+
def set_learning_rate(self, lr):
"""Sets a new learning rate of the optimizer.
diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries
index ea6db1a..e4b9ff1 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -141,6 +141,14 @@ core_logic: {
utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
}
}
+ },
+ 'Gluon estimator: GPU': {
+ node(NODE_LINUX_GPU) {
+ ws('workspace/estimator-test-gpu') {
+ utils.unpack_and_init('gpu', mx_lib)
+ utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator', true)
+ }
+ }
}
}
}
diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py
new file mode 100644
index 0000000..c60dc54
--- /dev/null
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Test gluon estimator on CNN models
+
+import argparse
+import numpy as np
+import mxnet as mx
+from mxnet import gluon, init, nd
+from mxnet.gluon import data
+from mxnet.gluon.contrib.estimator import estimator
+from mxnet.gluon.model_zoo import vision
+
def load_data_mnist(batch_size, resize=None, num_workers=4):
    '''
    Build the MNIST train and test DataLoaders.

    Every image is optionally resized to `resize` and then converted with
    ToTensor(). The train loader shuffles; the test loader keeps file order.
    '''
    transforms = data.vision.transforms
    steps = [transforms.Resize(resize)] if resize else []
    steps.append(transforms.ToTensor())
    to_input = transforms.Compose(steps)

    train_set = data.vision.MNIST(train=True)
    test_set = data.vision.MNIST(train=False)
    train_loader = data.DataLoader(train_set.transform_first(to_input), batch_size,
                                   shuffle=True, num_workers=num_workers)
    test_loader = data.DataLoader(test_set.transform_first(to_input), batch_size,
                                  shuffle=False, num_workers=num_workers)
    return train_loader, test_loader
+
def bilinear_kernel(in_channels, out_channels, kernel_size):
    '''
    Weights that make a transposed convolution perform bilinear upsampling.
    Adapted from
    https://github.com/d2l-ai/d2l-en/blob/master/chapter_computer-vision/fcn.md

    Returns an NDArray of shape (in_channels, out_channels, kernel_size,
    kernel_size) whose entries [i, i, :, :] hold the 2-D bilinear filter and
    are zero elsewhere (the diagonal assignment assumes
    in_channels == out_channels).
    '''
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return nd.array(weight)
+
def get_net(model_name, context):
    '''
    Build and initialize the network under test on `context`.

    Returns a tuple (net, input_shape, label_shape, loss_axis):
    * 'FCN': a 21-class fully convolutional net on top of a pretrained
      resnet18_v2 backbone, with a transposed-convolution head initialized
      by a bilinear upsampling kernel;
    * any other name: the model-zoo network `model_name` with 10 classes.
    '''
    if model_name != 'FCN':
        net = vision.get_model(model_name, classes=10)
        net.initialize(mx.init.Xavier(), ctx=context)
        return net, (1, 1, 224, 224), 1, -1

    num_classes = 21
    backbone = vision.resnet18_v2(pretrained=True, ctx=context)
    net = gluon.nn.HybridSequential()
    # reuse every backbone feature layer except the final two
    for layer in backbone.features[:-2]:
        net.add(layer)
    net.add(gluon.nn.Conv2D(num_classes, kernel_size=1),
            gluon.nn.Conv2DTranspose(num_classes, kernel_size=64, padding=16, strides=32))
    upsample_init = init.Constant(bilinear_kernel(num_classes, num_classes, 64))
    net[-1].initialize(upsample_init, ctx=context)
    net[-2].initialize(init=init.Xavier(), ctx=context)
    return net, (1, 3, 320, 480), (1, 320, 480), 1
+
def test_estimator_cpu():
    '''
    Smoke-test the estimator: one epoch of fit() per model (a classifier
    and an FCN) on a single synthetic sample with an all-zero label.
    '''
    context = mx.cpu()
    for model_name in ['resnet18_v1', 'FCN']:
        net, input_shape, label_shape, loss_axis = get_net(model_name, context)
        # one random sample for training, another for validation
        datasets = [gluon.data.dataset.ArrayDataset(mx.nd.random.uniform(shape=input_shape),
                                                    mx.nd.zeros(shape=label_shape))
                    for _ in range(2)]
        train_data, val_data = [gluon.data.DataLoader(ds, batch_size=1) for ds in datasets]
        loss = gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis)
        net.hybridize()
        trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
        est = estimator.Estimator(net=net, loss=loss, metrics=mx.metric.Accuracy(),
                                  trainer=trainer, context=context)
        est.fit(train_data=train_data, val_data=val_data, epochs=1)
+
def test_estimator_gpu():
    '''
    Train resnet18_v1 on MNIST (resized to 224) for 5 epochs on GPU 0 and
    require the accuracy metric to end above 0.80.
    '''
    batch_size, num_epochs = 128, 5
    context = mx.gpu(0)
    net = get_net('resnet18_v1', context)[0]
    train_data, test_data = load_data_mnist(batch_size, resize=224)
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    net.hybridize()
    acc = mx.metric.Accuracy()
    est = estimator.Estimator(
        net=net,
        loss=loss,
        metrics=acc,
        trainer=gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}),
        context=context)
    est.fit(train_data=train_data, val_data=test_data, epochs=num_epochs)

    assert acc.get()[1] > 0.80
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test gluon estimator')
    parser.add_argument('--type', type=str, default='cpu')
    opt = parser.parse_args()
    runners = {'cpu': test_estimator_cpu, 'gpu': test_estimator_gpu}
    if opt.type not in runners:
        raise RuntimeError("Unknown test type")
    runners[opt.type]()
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py
new file mode 100644
index 0000000..404bf83
--- /dev/null
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -0,0 +1,276 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Gluon Text Sentiment Classification Example using RNN/CNN
+Example modified from below link:
+https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-rnn.md
+https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-cnn.md"""
+
+import argparse
+import os
+import tarfile
+import random
+import collections
+import mxnet as mx
+from mxnet import nd, gluon
+from mxnet.contrib import text
+from mxnet.gluon import nn, rnn
+from mxnet.gluon.contrib.estimator import estimator
+
+
class TextCNN(nn.Block):
    '''Text classifier: two parallel embeddings, one Conv1D branch per kernel
    size with max-over-time pooling, dropout and a 2-way dense output.'''

    def __init__(self, vocab, embed_size, kernel_sizes, num_channels,
                 **kwargs):
        super(TextCNN, self).__init__(**kwargs)
        vocab_size = len(vocab)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Meant to be a fixed embedding (as in the d2l original).
        # NOTE(review): nothing in this file sets its grad_req, so here it is
        # trained like the first embedding - confirm this is intended.
        self.constant_embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Dense(2)
        # global max pooling has no weights, so one instance serves all branches
        self.pool = nn.GlobalMaxPool1D()
        self.convs = nn.Sequential()
        for channels, width in zip(num_channels, kernel_sizes):
            self.convs.add(nn.Conv1D(channels, width, activation='relu'))

    def forward(self, inputs):
        # (batch, words, 2 * embed_size): both embeddings joined on the vector axis
        embeddings = nd.concat(
            self.embedding(inputs), self.constant_embedding(inputs), dim=2)
        # Conv1D wants channels on axis 1, so move the embedding axis there
        embeddings = embeddings.transpose((0, 2, 1))
        # each branch: conv -> pool to (batch, channels, 1) -> flatten
        pooled = [nd.flatten(self.pool(conv(embeddings))) for conv in self.convs]
        encoding = nd.concat(*pooled, dim=1)
        return self.decoder(self.dropout(encoding))
+
+
class BiRNN(nn.Block):
    '''Text classifier: embedding, bidirectional LSTM encoder, and a 2-way
    dense layer over the first and last time steps.'''

    def __init__(self, vocab, embed_size, num_hiddens, num_layers, **kwargs):
        super(BiRNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(len(vocab), embed_size)
        self.encoder = rnn.LSTM(num_hiddens, num_layers=num_layers,
                                bidirectional=True, input_size=embed_size)
        self.decoder = nn.Dense(2)

    def forward(self, inputs):
        # inputs are (batch, words); the LSTM is time-major, hence the transpose
        hidden = self.encoder(self.embedding(inputs.T))
        # join the first and last time steps as the classifier input
        return self.decoder(nd.concat(hidden[0], hidden[-1]))
+
+
def download_imdb(data_dir='/tmp/data'):
    '''
    Download (if needed) and extract the IMDB dataset into `data_dir`.

    gluon.utils.download() is always called: it skips the transfer when
    `data_dir/aclImdb_v1.tar.gz` already exists with the expected sha1 and
    downloads it again otherwise. Previously an existing archive bypassed
    that check, so a truncated archive left by an interrupted run made
    every later extraction fail.
    '''
    url = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
    sha1 = '01ada507287d82875905620988597833ad4e0903'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    file_path = gluon.utils.download(url, data_dir, sha1_hash=sha1)
    # NOTE(review): extractall() trusts member paths inside the archive; the
    # sha1 check above is what vouches for its contents.
    with tarfile.open(file_path, 'r') as f:
        f.extractall(data_dir)
+
+
def read_imdb(folder='train', data_dir='/tmp/data'):
    '''
    Read the IMDB reviews of `folder` ('train' or 'test').

    Returns a shuffled list of [review_text, label] pairs, where the text is
    lower-cased with newlines removed and the label is 1 for 'pos' and 0 for
    'neg'. `data_dir` must be the directory given to download_imdb(); it
    defaults to the same '/tmp/data'.
    '''
    data = []
    for label in ['pos', 'neg']:
        folder_name = os.path.join(data_dir, 'aclImdb', folder, label)
        for fname in os.listdir(folder_name):
            with open(os.path.join(folder_name, fname), 'rb') as f:
                review = f.read().decode('utf-8').replace('\n', '').lower()
                data.append([review, 1 if label == 'pos' else 0])
    random.shuffle(data)
    return data
+
+
def get_tokenized_imdb(data):
    '''
    Split each review of `data` (a sequence of (review, label) pairs) on
    single spaces and lower-case every token.
    '''
    return [[token.lower() for token in review.split(' ')]
            for review, _ in data]
+
+
def get_vocab_imdb(data):
    '''
    Build a Vocabulary from the tokens of `data`, keeping tokens that occur
    at least 5 times.
    '''
    counter = collections.Counter()
    for tokens in get_tokenized_imdb(data):
        counter.update(tokens)
    return text.vocab.Vocabulary(counter, min_freq=5)
+
+
def preprocess_imdb(data, vocab):
    '''
    Turn [review, label] pairs into (features, labels) NDArrays.

    Each review is tokenized, mapped to vocabulary indices and then cut or
    right-padded with index 0 to exactly 500 entries.
    '''
    seq_len = 500

    def fit_length(indices):
        # [0] * negative is [], so long sequences are simply truncated
        return (indices + [0] * (seq_len - len(indices)))[:seq_len]

    rows = [fit_length(vocab.to_indices(tokens)) for tokens in get_tokenized_imdb(data)]
    features = nd.array(rows)
    labels = nd.array([label for _, label in data])
    return features, labels
+
+
def run(net, train_dataloader, test_dataloader, **kwargs):
    '''
    Train a sentiment model with the Estimator fit() API.

    Required kwargs: 'epochs', 'ctx' and 'lr'. Other keys (such as
    'batch_size') are ignored here; the dataloaders already fix the batch
    size. Returns the Accuracy metric object passed to the estimator.
    '''
    num_epochs = kwargs['epochs']
    ctx = kwargs['ctx']
    lr = kwargs['lr']

    # Define trainer
    trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
    # Define loss and evaluation metrics
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    acc = mx.metric.Accuracy()

    # Define estimator
    est = estimator.Estimator(net=net, loss=loss, metrics=acc,
                              trainer=trainer, context=ctx)
    # Begin training
    est.fit(train_data=train_dataloader, val_data=test_dataloader,
            epochs=num_epochs)
    return acc
+
+
def test_estimator_cpu(**kwargs):
    '''
    Smoke-test the estimator on random token data: build TextCNN, then BiRNN,
    and train each once via run() with the given kwargs.
    '''
    ctx = kwargs['ctx']
    batch_size = kwargs['batch_size']
    embed_size = kwargs['embed_size']

    def _random_dataset(num_samples):
        # 500 token ids in [0, 100) per sample, binary labels
        tokens = mx.nd.random.randint(low=0, high=100, shape=(num_samples, 500))
        labels = mx.nd.random.randint(low=0, high=2, shape=(num_samples,))
        return gluon.data.ArrayDataset(tokens, labels)

    train_dataloader = gluon.data.DataLoader(dataset=_random_dataset(2 * batch_size),
                                             batch_size=batch_size, shuffle=True)
    val_dataloader = gluon.data.DataLoader(dataset=_random_dataset(batch_size),
                                           batch_size=batch_size)
    # the models only call len() on the vocabulary, so 100 zeros stand in for it
    vocab_list = mx.nd.zeros(shape=(100,))

    for model in ['TextCNN', 'BiRNN']:
        if model == 'TextCNN':
            net = TextCNN(vocab_list, embed_size, [3, 4, 5], [100, 100, 100])
        else:
            net = BiRNN(vocab_list, embed_size, 100, 2)
        net.initialize(mx.init.Xavier(), ctx=ctx)
        run(net, train_dataloader, val_dataloader, **kwargs)
+
+
def test_estimator_gpu(**kwargs):
    '''
    Train a BiRNN with frozen GloVe 100d embeddings on IMDB and require the
    returned accuracy metric to exceed 0.70.
    '''
    ctx, batch_size, _, embed_size = [kwargs[key] for key in ('ctx', 'batch_size', 'epochs', 'embed_size')]

    download_imdb()
    train_data = read_imdb('train')
    test_data = read_imdb('test')
    vocab = get_vocab_imdb(train_data)

    train_set, test_set = [gluon.data.ArrayDataset(*preprocess_imdb(reviews, vocab))
                           for reviews in (train_data, test_data)]
    train_dataloader = gluon.data.DataLoader(train_set, batch_size, shuffle=True)
    test_dataloader = gluon.data.DataLoader(test_set, batch_size)

    net = BiRNN(vocab, embed_size, 100, 2)
    net.initialize(mx.init.Xavier(), ctx=ctx)

    glove_embedding = text.embedding.create(
        'glove', pretrained_file_name='glove.6B.100d.txt', vocabulary=vocab)
    net.embedding.weight.set_data(glove_embedding.idx_to_vec)
    # keep the pretrained vectors fixed during training
    net.embedding.collect_params().setattr('grad_req', 'null')

    accuracy = run(net, train_dataloader, test_dataloader, **kwargs)
    assert accuracy.get()[1] > 0.70
+
+
if __name__ == '__main__':
    # Guarded (like test_estimator_cnn.py) so that importing this module,
    # e.g. during test collection, neither parses the importer's command
    # line nor starts training.
    parser = argparse.ArgumentParser(description='test gluon estimator')
    parser.add_argument('--type', type=str, default='cpu')
    opt = parser.parse_args()
    kwargs = {
        'batch_size': 64,
        'lr': 0.01,
        'embed_size': 100
    }

    if opt.type == 'cpu':
        kwargs['ctx'] = mx.cpu()
        kwargs['epochs'] = 1
        test_estimator_cpu(**kwargs)
    elif opt.type == 'gpu':
        kwargs['ctx'] = mx.gpu()
        kwargs['epochs'] = 5
        test_estimator_gpu(**kwargs)
    else:
        raise RuntimeError("Unknown test type")
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
new file mode 100644
index 0000000..d2e8c08
--- /dev/null
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -0,0 +1,371 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+''' Unit tests for Gluon Estimator '''
+
import sys
import unittest
import warnings

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.contrib.estimator import *
from nose.tools import assert_raises
+
+
def _get_test_network():
    """Return a Sequential holding one 4-unit ReLU Dense layer (flatten=False)."""
    model = nn.Sequential()
    hidden = nn.Dense(4, activation='relu', flatten=False)
    model.add(hidden)
    return model
+
+
def _get_test_data():
    """Wrap one random regression dataset in two loader types.

    Builds 10 samples of shape (3,) with targets of shape (4,) drawn
    uniformly at random, and returns ``(DataLoader, NDArrayIter)`` over
    the same arrays, both with a batch size of 4.
    """
    batch_size = 4
    features = mx.nd.random.uniform(shape=(10, 3))
    targets = mx.nd.random.uniform(shape=(10, 4))
    loader = gluon.data.DataLoader(
        gluon.data.dataset.ArrayDataset(features, targets),
        batch_size=batch_size)
    ndarray_iter = mx.io.NDArrayIter(data=features, label=targets,
                                     batch_size=batch_size)
    return loader, ndarray_iter
+
+
def test_fit():
    ''' test estimator with different train data types '''
    net = _get_test_network()
    dataloader, dataiter = _get_test_data()
    ctx = mx.cpu()
    epochs = 1
    l2_loss = gluon.loss.L2Loss()
    accuracy = mx.metric.Accuracy()
    net.initialize(ctx=ctx)
    est = Estimator(net=net,
                    loss=l2_loss,
                    metrics=accuracy,
                    trainer=gluon.Trainer(net.collect_params(), 'sgd',
                                          {'learning_rate': 0.001}),
                    context=ctx)

    # A gluon DataLoader is accepted as training data.
    est.fit(train_data=dataloader, epochs=epochs)

    # An mx.io NDArrayIter and a plain list of NDArrays must be rejected.
    for unsupported in (dataiter, [mx.nd.ones(shape=(10, 3))]):
        with assert_raises(ValueError):
            est.fit(train_data=unsupported, epochs=epochs)
+
+
def test_validation():
    ''' test different validation data types and an explicit ValidationHandler '''
    net = _get_test_network()
    dataloader, dataiter = _get_test_data()
    num_epochs = 1
    ctx = mx.cpu()
    loss = gluon.loss.L2Loss()
    acc = mx.metric.Accuracy()
    net.initialize(ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    est = Estimator(net=net,
                    loss=loss,
                    metrics=acc,
                    trainer=trainer,
                    context=ctx)
    # Input dataloader
    est.fit(train_data=dataloader,
            val_data=dataloader,
            epochs=num_epochs)

    # using validation handler: build it from the estimator's own prepared
    # validation metrics and actually train with it (previously the handler
    # was constructed but never passed to fit, so this path was untested)
    _, val_metrics = est.prepare_loss_and_metrics()
    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate,
                                           val_metrics=val_metrics)
    est.fit(train_data=dataloader,
            event_handlers=[validation_handler],
            epochs=num_epochs)

    with assert_raises(ValueError):
        est.fit(train_data=dataiter,
                val_data=dataiter,
                epochs=num_epochs)
    # Input NDArray
    with assert_raises(ValueError):
        est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
                val_data=[mx.nd.ones(shape=(10, 3))],
                epochs=num_epochs)
+
+
@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
def test_initializer():
    ''' test with no initializer, inconsistent initializer '''
    net = _get_test_network()
    train_data, _ = _get_test_data()
    num_epochs = 1
    ctx = mx.cpu()

    loss = gluon.loss.L2Loss()
    acc = mx.metric.Accuracy()
    # no initializer: net was never initialized and none is passed in
    est = Estimator(net=net,
                    loss=loss,
                    metrics=acc,
                    context=ctx)
    est.fit(train_data=train_data,
            epochs=num_epochs)

    # different initializer for net and estimator
    net = _get_test_network()
    net.initialize(mx.init.Xavier(), ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    # catch reinit warning; 'always' guarantees it is recorded even if an
    # identical warning was issued earlier or a global filter would hide it
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        est = Estimator(net=net,
                        loss=loss,
                        metrics=acc,
                        initializer=mx.init.MSRAPrelu(),
                        trainer=trainer,
                        context=ctx)
    assert any('Network already fully initialized' in str(x.message) for x in w)

    # net partially initialized, fine tuning use case
    net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx)
    net.output = gluon.nn.Dense(10)  # last layer not initialized
    est = Estimator(net, loss=loss, metrics=acc, context=ctx)
    dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
    train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5)
    est.fit(train_data=train_data,
            epochs=num_epochs)
+
+
@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
def test_trainer():
    ''' test with no trainer and invalid trainer '''
    net = _get_test_network()
    train_data, _ = _get_test_data()
    num_epochs = 1
    ctx = mx.cpu()

    loss = gluon.loss.L2Loss()
    acc = mx.metric.Accuracy()
    net.initialize(ctx=ctx)
    # input no trainer; 'always' guarantees the warning is recorded even if
    # it was issued before or a global filter would hide it
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        est = Estimator(net=net,
                        loss=loss,
                        metrics=acc,
                        context=ctx)
    assert any('No trainer specified' in str(x.message) for x in w)
    est.fit(train_data=train_data,
            epochs=num_epochs)

    # input invalid trainer
    trainer = 'sgd'
    with assert_raises(ValueError):
        Estimator(net=net,
                  loss=loss,
                  metrics=acc,
                  trainer=trainer,
                  context=ctx)
+
+
def test_metric():
    ''' test with no metric, list of metrics, invalid metric '''
    net = _get_test_network()
    train_data, _ = _get_test_data()
    epochs = 1
    ctx = mx.cpu()

    l2_loss = gluon.loss.L2Loss()
    net.initialize(ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    # arguments shared by every Estimator built below
    shared = dict(net=net, trainer=trainer, context=ctx)

    # No metric supplied: training still runs.
    est = Estimator(loss=l2_loss, **shared)
    est.fit(train_data=train_data, epochs=epochs)

    # A list of metric objects is accepted.
    est = Estimator(loss=l2_loss,
                    metrics=[mx.metric.Accuracy(), mx.metric.Accuracy()],
                    **shared)
    est.fit(train_data=train_data, epochs=epochs)

    # A metric given as a string is rejected.
    with assert_raises(ValueError):
        Estimator(loss=l2_loss, metrics='acc', **shared)

    # With SoftmaxCrossEntropyLoss and no metric, Accuracy is used by default.
    est = Estimator(loss=gluon.loss.SoftmaxCrossEntropyLoss(), **shared)
    est.prepare_loss_and_metrics()
    assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
+
+
def test_loss():
    ''' test with invalid loss '''
    ctx = mx.cpu()
    net = _get_test_network()
    accuracy = mx.metric.Accuracy()
    net.initialize(ctx=ctx)
    sgd = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    # A loss given by name (string) instead of a loss object must be rejected.
    with assert_raises(ValueError):
        Estimator(net=net,
                  loss='mse',
                  metrics=accuracy,
                  trainer=sgd,
                  context=ctx)
+
+
def test_context():
    ''' test with no context, list of context, invalid context '''
    net = _get_test_network()
    loss = gluon.loss.L2Loss()
    metrics = mx.metric.Accuracy()

    def make_estimator(network, **extra):
        # every Estimator here shares the same loss and metric objects
        return Estimator(net=network, loss=loss, metrics=metrics, **extra)

    # input no context
    make_estimator(net)

    # input list of context: every visible GPU, or the CPU if there are none
    gpu_count = mx.context.num_gpus()
    if gpu_count > 0:
        devices = [mx.gpu(i) for i in range(gpu_count)]
    else:
        devices = [mx.cpu()]
    net = _get_test_network()
    make_estimator(net, context=devices)

    # input invalid context: a string, then a list including a missing GPU
    with assert_raises(ValueError):
        make_estimator(net, context='cpu')
    with assert_raises(AssertionError):
        make_estimator(net, context=[mx.gpu(0), mx.gpu(100)])
+
+
def test_categorize_handlers():
    ''' test that handlers are grouped by the event interfaces they inherit '''
    # Handler methods use the same (self, estimator, *args, **kwargs)
    # signature as the event-handler interfaces, so these handlers would
    # also be callable if fit() dispatched events to them.
    class CustomHandler1(TrainBegin):

        def train_begin(self, estimator, *args, **kwargs):
            print("custom train begin")

    class CustomHandler2(EpochBegin, BatchBegin, TrainEnd):

        def epoch_begin(self, estimator, *args, **kwargs):
            print("custom epoch begin")

        def batch_begin(self, estimator, *args, **kwargs):
            print("custom batch begin")

        def train_end(self, estimator, *args, **kwargs):
            print("custom train end")

    class CustomHandler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd):

        def epoch_begin(self, estimator, *args, **kwargs):
            print("custom epoch begin")

        def batch_begin(self, estimator, *args, **kwargs):
            print("custom batch begin")

        def batch_end(self, estimator, *args, **kwargs):
            print("custom batch end")

        def train_end(self, estimator, *args, **kwargs):
            print("custom train end")

    net = nn.Sequential()
    net.add(nn.Dense(10))
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    est = Estimator(net, loss=loss)
    event_handlers = [CustomHandler1(), CustomHandler2(), CustomHandler3()]
    train_begin, epoch_begin, batch_begin, \
        batch_end, epoch_end, train_end = est._categorize_handlers(event_handlers)
    assert len(train_begin) == 1
    assert len(epoch_begin) == 2
    assert len(batch_begin) == 2
    assert len(batch_end) == 1
    # no handler above inherits EpochEnd
    assert len(epoch_end) == 0
    assert len(train_end) == 2
+
+
@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
def test_default_handlers():
    ''' test default handler injection, mixing with user handlers, and ordering '''
    net = _get_test_network()
    train_data, _ = _get_test_data()

    num_epochs = 1
    ctx = mx.cpu()

    net.initialize(ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})

    train_rmse = mx.metric.RMSE()
    loss = gluon.loss.L2Loss()

    est = Estimator(net=net,
                    loss=loss,
                    metrics=train_rmse,
                    trainer=trainer,
                    context=ctx)
    # no handler; 'always' guarantees the warning is recorded even though the
    # same call site may already have warned
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        est.fit(train_data=train_data, epochs=num_epochs)
    messages = [str(x.message) for x in w]
    assert any('You are training with the' in m for m in messages)

    # handler with prepared loss and metrics
    # use mix of default and user defined handlers
    train_metrics, val_metrics = est.prepare_loss_and_metrics()
    logging_handler = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging_handler])
    messages = [str(x.message) for x in w]
    # the default-handler warning must report the injected MetricHandler
    assert any('You are training with the' in m and 'MetricHandler' in m
               for m in messages)

    # handler with all user defined metrics
    # use mix of default and user defined handlers
    metric = MetricHandler(train_metrics=[train_rmse])
    logging_handler = LoggingHandler(train_metrics=[train_rmse],
                                     val_metrics=[mx.metric.RMSE("val acc")])
    est.fit(train_data=train_data, epochs=num_epochs,
            event_handlers=[metric, logging_handler])

    # handler with mixed metrics, some handler use metrics prepared by estimator
    # some handler use metrics user prepared
    logging_handler = LoggingHandler(train_metrics=train_metrics,
                                     val_metrics=[mx.metric.RMSE("val acc")])
    with assert_raises(ValueError):
        est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging_handler])

    # test handler order
    train_metrics, val_metrics = est.prepare_loss_and_metrics()
    early_stopping = EarlyStoppingHandler(monitor=val_metrics[0])
    handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping])
    assert len(handlers) == 4
    assert isinstance(handlers[0], MetricHandler)
    assert isinstance(handlers[3], LoggingHandler)
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
new file mode 100644
index 0000000..7ea5ff3
--- /dev/null
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+
+import mxnet as mx
+from common import TemporaryDirectory
+from mxnet import nd
+from mxnet.gluon import nn, loss
+from mxnet.gluon.contrib.estimator import estimator, event_handler
+
+
def _get_test_network(net=None):
    """Append a three-layer ReLU MLP (128 -> 64 -> 10 units) to ``net``.

    Parameters
    ----------
    net : Sequential or HybridSequential, optional
        Container to add the layers to. When omitted, a fresh
        ``nn.Sequential`` is created on every call. (A default of
        ``nn.Sequential()`` would be evaluated once, at definition time, so
        every argument-less call would keep growing one shared network.)

    Returns
    -------
    The same container with the three Dense layers added.
    """
    if net is None:
        net = nn.Sequential()
    net.add(nn.Dense(128, activation='relu', flatten=False),
            nn.Dense(64, activation='relu'),
            nn.Dense(10, activation='relu'))
    return net
+
+
def _get_test_data():
    """Return a DataLoader over 32 constant samples, batched 8 at a time.

    Every input is a row of 100 ones and every label is 0, giving
    4 batches per epoch.
    """
    samples = mx.gluon.data.dataset.ArrayDataset(nd.ones((32, 100)),
                                                 nd.zeros((32, 1)))
    return mx.gluon.data.DataLoader(samples, batch_size=8)
+
+
def test_checkpoint_handler():
    """Check epoch-period and batch-period checkpointing.

    First run: save every epoch, plus the best model according to ``acc``.
    Second run (hybridized net, no monitor/save_best configured): save every
    2 batches with ``max_checkpoints=2``, and check that no '-best' files
    are written for that prefix.
    """
    with TemporaryDirectory() as tmpdir:
        model_prefix = 'test_epoch'
        file_path = os.path.join(tmpdir, model_prefix)
        test_data = _get_test_data()

        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                             model_prefix=model_prefix,
                                                             monitor=acc,
                                                             save_best=True,
                                                             epoch_period=1)
        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1)
        assert checkpoint_handler.current_epoch == 1
        assert checkpoint_handler.current_batch == 4
        assert os.path.isfile(file_path + '-best.params')
        assert os.path.isfile(file_path + '-best.states')
        assert os.path.isfile(file_path + '-epoch0batch4.params')
        assert os.path.isfile(file_path + '-epoch0batch4.states')

        model_prefix = 'test_batch'
        file_path = os.path.join(tmpdir, model_prefix)
        net = _get_test_network(nn.HybridSequential())
        net.hybridize()
        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                             model_prefix=model_prefix,
                                                             epoch_period=None,
                                                             batch_period=2,
                                                             max_checkpoints=2)
        est.fit(test_data, event_handlers=[checkpoint_handler], batches=10)
        assert checkpoint_handler.current_batch == 10
        assert checkpoint_handler.current_epoch == 3
        # Use the same '-best' naming asserted above; without the separator
        # these negative checks could never fail.
        assert not os.path.isfile(file_path + '-best.params')
        assert not os.path.isfile(file_path + '-best.states')
        assert not os.path.isfile(file_path + '-epoch0batch0.params')
        assert not os.path.isfile(file_path + '-epoch0batch0.states')
        assert os.path.isfile(file_path + '-symbol.json')
        assert os.path.isfile(file_path + '-epoch1batch7.params')
        assert os.path.isfile(file_path + '-epoch1batch7.states')
        assert os.path.isfile(file_path + '-epoch2batch9.params')
        assert os.path.isfile(file_path + '-epoch2batch9.states')
+
def test_resume_checkpoint():
    """Train 2 epochs with checkpoints, then resume toward 5 epochs.

    After resuming, the estimator should run only the remaining 3 epochs,
    and the last checkpoint should be tagged epoch 4.
    """
    with TemporaryDirectory() as tmpdir:
        model_prefix = 'test_net'
        file_path = os.path.join(tmpdir, model_prefix)
        test_data = _get_test_data()

        net = _get_test_network()
        criterion = loss.SoftmaxCrossEntropyLoss()
        accuracy = mx.metric.Accuracy()
        est = estimator.Estimator(net, loss=criterion, metrics=accuracy)

        def make_checkpointer(**extra):
            # both runs share the directory, prefix, monitor and retention
            return event_handler.CheckpointHandler(model_dir=tmpdir,
                                                   model_prefix=model_prefix,
                                                   monitor=accuracy,
                                                   max_checkpoints=1,
                                                   **extra)

        est.fit(test_data, event_handlers=[make_checkpointer()], epochs=2)
        assert os.path.isfile(file_path + '-epoch1batch8.params')
        assert os.path.isfile(file_path + '-epoch1batch8.states')

        resumed = make_checkpointer(resume_from_checkpoint=True)
        est.fit(test_data, event_handlers=[resumed], epochs=5)
        assert est.max_epoch == 3
        assert os.path.isfile(file_path + '-epoch4batch20.states')
+
+
def test_early_stopping():
    """Early stopping with patience 0 in 'min' mode, then patience 2 in 'auto'."""
    test_data = _get_test_data()

    net = _get_test_network()
    criterion = loss.SoftmaxCrossEntropyLoss()
    accuracy = mx.metric.Accuracy()
    est = estimator.Estimator(net, loss=criterion, metrics=accuracy)

    stopper = event_handler.EarlyStoppingHandler(monitor=accuracy, patience=0, mode='min')
    est.fit(test_data, event_handlers=[stopper], epochs=5)
    assert stopper.current_epoch == 2
    assert stopper.stopped_epoch == 1

    stopper = event_handler.EarlyStoppingHandler(monitor=accuracy, patience=2, mode='auto')
    est.fit(test_data, event_handlers=[stopper], epochs=1)
    assert stopper.current_epoch == 1
+
+
def test_logging():
    """LoggingHandler tracks epochs and writes its log file into tmpdir."""
    with TemporaryDirectory() as tmpdir:
        test_data = _get_test_data()
        file_name = 'test_log'
        log_path = os.path.join(tmpdir, file_name)

        net = _get_test_network()
        criterion = loss.SoftmaxCrossEntropyLoss()
        accuracy = mx.metric.Accuracy()
        est = estimator.Estimator(net, loss=criterion, metrics=accuracy)
        train_metrics, val_metrics = est.prepare_loss_and_metrics()
        handler = event_handler.LoggingHandler(file_name=file_name,
                                               file_location=tmpdir,
                                               train_metrics=train_metrics,
                                               val_metrics=val_metrics)
        est.fit(test_data, event_handlers=[handler], epochs=3)

        assert handler.batch_index == 0
        assert handler.current_epoch == 3
        assert os.path.isfile(log_path)
+
+
def test_custom_handler():
    """A user handler can stop training after N batches or M epochs."""

    class CustomStopHandler(event_handler.TrainBegin,
                            event_handler.BatchEnd,
                            event_handler.EpochEnd):
        """Set ``stop_training`` once a batch or epoch count is reached."""

        def __init__(self, batch_stop=None, epoch_stop=None):
            self.batch_stop, self.epoch_stop = batch_stop, epoch_stop
            self.num_batch = self.num_epoch = 0
            self.stop_training = False

        def train_begin(self, estimator, *args, **kwargs):
            # counters restart at the beginning of every fit() call
            self.num_batch = self.num_epoch = 0

        def batch_end(self, estimator, *args, **kwargs):
            self.num_batch += 1
            if self.num_batch == self.batch_stop:
                self.stop_training = True
            return self.stop_training

        def epoch_end(self, estimator, *args, **kwargs):
            self.num_epoch += 1
            if self.num_epoch == self.epoch_stop:
                self.stop_training = True
            return self.stop_training

    # 32 samples with batch size 8 -> 4 batches per epoch
    test_data = _get_test_data()
    net = _get_test_network()
    criterion = loss.SoftmaxCrossEntropyLoss()
    accuracy = mx.metric.Accuracy()
    est = estimator.Estimator(net, loss=criterion, metrics=accuracy)

    # the batch limit (3) is reached inside the first epoch
    handler = CustomStopHandler(3, 2)
    est.fit(test_data, event_handlers=[handler], epochs=3)
    assert handler.num_batch == 3
    assert handler.num_epoch == 1

    # the epoch limit (5) is reached before the batch limit (100)
    handler = CustomStopHandler(100, 5)
    est.fit(test_data, event_handlers=[handler], epochs=10)
    assert handler.num_batch == 5 * 4
    assert handler.num_epoch == 5