You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2019/03/16 05:17:16 UTC
[incubator-mxnet] branch fit-api updated: [MXNet-1334][Fit API]base
class for estimator and eventhandler (#14346)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/fit-api by this push:
new 41392fa [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)
41392fa is described below
commit 41392fa1cdc4c4a49451678d9df7fdbad5b42faa
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Fri Mar 15 22:16:42 2019 -0700
[MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)
* base class for estimator and eventhandler
* add license
* add event handlers
* fix pylint
* improve arg check
* fix pylint
* add unit tests
---
python/mxnet/gluon/estimator/__init__.py | 21 ++
python/mxnet/gluon/estimator/estimator.py | 267 +++++++++++++++++++
python/mxnet/gluon/estimator/event_handler.py | 307 ++++++++++++++++++++++
tests/python/unittest/test_gluon_event_handler.py | 92 +++++++
4 files changed, 687 insertions(+)
diff --git a/python/mxnet/gluon/estimator/__init__.py b/python/mxnet/gluon/estimator/__init__.py
new file mode 100644
index 0000000..58600da
--- /dev/null
+++ b/python/mxnet/gluon/estimator/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=wildcard-import
+"""Gluon Estimator Module"""
+from .estimator import *
+from .event_handler import *
diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
new file mode 100644
index 0000000..159f7e2
--- /dev/null
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Gluon Estimator"""
+
+import warnings
+
+from .event_handler import LoggingHandler
+from ... import gluon, autograd
+from ...context import Context, cpu, gpu, num_gpus
+from ...io import DataIter
+from ...metric import EvalMetric, Loss
+
# public names exported by ``from .estimator import *`` in the package __init__
__all__ = ['Estimator']
+
+
class Estimator(object):
    """Estimator Class for easy model training

    :py:class:`Estimator` can be used to facilitate the training & validation process


    Parameters
    ----------
    net : Block
        The gluon network to train (must provide ``initialize``,
        ``collect_params`` and be callable on a data batch)
    loss : Loss or list of Loss
        Loss(objective functions) to calculate during training
    metrics : EvalMetric or list of EvalMetric
        Metrics for evaluating models
    initializer : Initializer
        initializer to initialize the network
    trainers : Trainer or list of Trainer
        Trainers to apply optimizers on network parameters;
        defaults to a single SGD trainer with learning rate 0.001
    context : Context or list of Context
        devices to run the training on; defaults to gpu(0) when a GPU
        is available, otherwise cpu()

    Raises
    ------
    ValueError
        If `loss`, `metrics` or `context` contain objects of the wrong type.
    """

    def __init__(self, net,
                 loss=None,
                 metrics=None,
                 initializer=None,
                 trainers=None,
                 context=None):

        self.net = net
        # event handlers (e.g. EarlyStoppingHandler) set this to end fit() early
        self.stop_training = False

        if isinstance(loss, gluon.loss.Loss):
            self.loss = [loss]
        else:
            self.loss = loss or []
        for l in self.loss:
            # validate each element (not the container) so lists of Loss are accepted
            if not isinstance(l, gluon.loss.Loss):
                raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")

        if isinstance(metrics, EvalMetric):
            self.metrics = [metrics]
        else:
            self.metrics = metrics or []
        for metric in self.metrics:
            if not isinstance(metric, EvalMetric):
                raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")

        self.initializer = initializer
        # store training statistics
        self.train_stats = {}
        self.train_stats['epochs'] = []
        self.train_stats['learning_rate'] = []
        # current step of the epoch
        self.train_stats['step'] = ''
        for metric in self.metrics:
            # record a history of metrics over each epoch
            self.train_stats['train_' + metric.name] = []
            # only record the latest metric numbers after each batch
            self.train_stats['batch_' + metric.name] = 0.
        self.loss_metrics = []
        # using the metric wrapper for loss to record loss value
        for l in self.loss:
            self.loss_metrics.append(Loss(l.name))
            self.train_stats['train_' + l.name] = []
            # only record the latest loss numbers after each batch
            self.train_stats['batch_' + l.name] = 0.

        # handle context
        if isinstance(context, Context):
            self.context = [context]
        elif context:
            # accept a non-empty list/tuple of Context objects
            if not isinstance(context, (list, tuple)) or \
                    not all(isinstance(ctx, Context) for ctx in context):
                raise ValueError("context must be a Context or a list of Context, "
                                 "refer to mxnet.context.Context")
            self.context = list(context)
        elif num_gpus() > 0:
            # only use 1 GPU by default
            if num_gpus() > 1:
                warnings.warn("You have multiple GPUs, gpu(0) will be used by default. "
                              "To utilize all your GPUs, specify context as a list of gpus, "
                              "e.g. context=[mx.gpu(0), mx.gpu(1)] ")
            self.context = [gpu(0)]
        else:
            self.context = [cpu()]

        # initialize the network
        if self.initializer:
            if self._is_initialized():
                # if already initialized, re-init with user specified initializer
                warnings.warn("Network already initialized, re-initializing with %s. "
                              "You don't need to pass initializer if you already "
                              "initialized your net." % type(self.initializer).__name__)
                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
            else:
                # initialize with user specified initializer
                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False)
        else:
            if not self._is_initialized():
                self.net.initialize(ctx=self.context)

        # handle trainers
        if isinstance(trainers, gluon.Trainer):
            self.trainers = [trainers]
        else:
            self.trainers = trainers or []
        if not self.trainers:
            warnings.warn("No trainer specified, default SGD optimizer "
                          "with learning rate 0.001 is used.")
            self.trainers = [gluon.Trainer(self.net.collect_params(),
                                           'sgd', {'learning_rate': 0.001})]

    def _is_initialized(self):
        """Return True if every parameter of the network has been initialized."""
        param_dict = self.net.collect_params()
        for param in param_dict:
            try:
                # list_ctx raises RuntimeError for uninitialized parameters
                param_dict[param].list_ctx()
            except RuntimeError:
                return False
        return True

    def _batch_fn(self, batch, ctx, is_iterator=False):
        """Extract data/label from a batch and split them across `ctx`.

        DataIter batches expose ``.data``/``.label`` lists, DataLoader
        batches are ``(data, label)`` sequences.
        """
        if is_iterator:
            data = batch.data[0]
            label = batch.label[0]
        else:
            data = batch[0]
            label = batch[1]
        data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
        return data, label

    def fit(self, train_data,
            epochs=1,
            batch_size=None,
            event_handlers=None,
            batch_fn=None):
        """Main training loop

        Parameters
        ----------
        train_data : DataLoader or DataIter
            training data with data and labels. A DataIter is reset at the
            start of every epoch.
        epochs : int, default 1
            number of epochs to iterate on the training data; training ends
            earlier if an event handler sets ``stop_training``.
        batch_size : int
            number of samples per gradient update.
            default will be 32 per device
        event_handlers : EventHandler or list of EventHandler
            EventHandlers to apply during training. A default
            :py:class:`LoggingHandler` is appended when none is given.
            The caller's list is not modified.
        batch_fn : function
            custom batch function to extract data and label
            from a data batch and load into contexts(devices).
            Called as ``batch_fn(batch, context)``; must return ``(data, label)``
            as per-device iterables (like ``gluon.utils.split_and_load`` output).
            Required when `train_data` is neither a DataLoader nor a DataIter.
        """

        self.epochs = epochs
        # a previous fit() may have been stopped early; start a fresh run
        self.stop_training = False
        if not batch_size:
            batch_size = 32 * len(self.context)

        # accept a single handler and copy so the caller's list is not mutated
        if event_handlers is None:
            event_handlers = []
        elif isinstance(event_handlers, (list, tuple)):
            event_handlers = list(event_handlers)
        else:
            event_handlers = [event_handlers]
        # provide default logging handler
        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
            event_handlers.append(LoggingHandler(self))

        # training begin
        for handler in event_handlers:
            handler.train_begin()

        for epoch in range(epochs):
            # epoch begin
            self.train_stats['epochs'].append(epoch)
            self.train_stats['learning_rate'].append(self.trainers[0].learning_rate)

            for handler in event_handlers:
                handler.epoch_begin()

            for metric in self.metrics + self.loss_metrics:
                metric.reset()

            # a DataIter is exhausted after one pass and must be rewound,
            # otherwise every epoch after the first would see no batches
            if isinstance(train_data, DataIter):
                train_data.reset()

            for i, batch in enumerate(train_data):
                if not batch_fn:
                    if isinstance(train_data, gluon.data.DataLoader):
                        data, label = self._batch_fn(batch, self.context)
                    elif isinstance(train_data, DataIter):
                        data, label = self._batch_fn(batch, self.context, is_iterator=True)
                    else:
                        raise ValueError("You are using a custom iteration, please also provide "
                                         "batch_fn to extract data and label")
                else:
                    data, label = batch_fn(batch, self.context)

                # batch begin
                for handler in event_handlers:
                    handler.batch_begin()

                with autograd.record():
                    pred = [self.net(x) for x in data]
                    losses = []
                    for loss in self.loss:
                        losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])

                # a single backward over all loss heads: with several losses the
                # heads share the prediction graph, and separate backward() calls
                # would free that graph / overwrite each other's gradients
                all_losses = [l for loss in losses for l in loss]
                if all_losses:
                    autograd.backward(all_losses)

                # update metrics
                for metric in self.metrics:
                    metric.update(label, pred)
                    self.train_stats['batch_' + metric.name] = metric.get()[1]
                for loss, loss_metric in zip(losses, self.loss_metrics):
                    loss_metric.update(0, loss)
                    self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]

                try:
                    num_samples = len(train_data._dataset)
                except AttributeError:
                    # no dataset length available (e.g. DataIter): report batch index
                    self.train_stats['step'] = i
                else:
                    # the last batch may be partial; never report more than the dataset size
                    self.train_stats['step'] = "{}/{}".format(
                        min(batch_size * (i + 1), num_samples), num_samples)

                for trainer in self.trainers:
                    trainer.step(batch_size)

                # batch end
                for handler in event_handlers:
                    handler.batch_end()

            for metric in self.metrics + self.loss_metrics:
                self.train_stats['train_' + metric.name].append(metric.get()[1])
            # epoch end
            for handler in event_handlers:
                handler.epoch_end()

            if self.stop_training:
                break

        # train end
        for handler in event_handlers:
            handler.train_end()
diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py
new file mode 100644
index 0000000..0162c36
--- /dev/null
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -0,0 +1,307 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Gluon EventHandlers for Estimators"""
+
# every handler defined in this module is public; the package __init__
# re-exports them via ``from .event_handler import *``
__all__ = ['EventHandler', 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
+import logging
+import os
+import time
+import warnings
+
+import numpy as np
+
+
class EventHandler(object):
    """Base class for estimator event handlers.

    Subclasses override any of the hook methods to run custom code at a given
    stage of training: train begin, epoch begin, batch begin, batch end,
    epoch end and train end. Every hook does nothing by default.

    Parameters
    ----------
    estimator : Estimator
        The :py:class:`Estimator` to get training statistics
    """

    def __init__(self, estimator):
        self._estimator = estimator

    def train_begin(self):
        """Hook invoked once before the first epoch."""

    def train_end(self):
        """Hook invoked once after training finishes."""

    def batch_begin(self):
        """Hook invoked before each batch is processed."""

    def batch_end(self):
        """Hook invoked after each batch is processed."""

    def epoch_begin(self):
        """Hook invoked at the start of each epoch."""

    def epoch_end(self):
        """Hook invoked at the end of each epoch."""
+
class LoggingHandler(EventHandler):
    """Basic Logging Handler that applies to every Gluon estimator by default.

    :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
    and other useful information during training

    Parameters
    ----------
    estimator : Estimator
        The :py:class:`Estimator` to get training statistics
    file_name : str
        file name to save the logs
    file_location: str
        file location to save the logs
    """

    def __init__(self, estimator, file_name=None, file_location=None):
        super(LoggingHandler, self).__init__(estimator)
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Output handlers are attached to the shared module logger only for the
        # duration of a training run (train_begin/train_end). Attaching them here
        # permanently would add another StreamHandler each time a LoggingHandler
        # is created, printing every message once per instance ever made.
        self._handlers = [logging.StreamHandler()]
        # save logger to file only if file name or location is specified
        if file_name or file_location:
            file_name = file_name or 'estimator_log'
            file_location = file_location or './'
            self._handlers.append(logging.FileHandler(os.path.join(file_location, file_name)))

    def train_begin(self):
        for handler in self._handlers:
            # addHandler ignores a handler that is already attached
            self.logger.addHandler(handler)

    def train_end(self):
        for handler in self._handlers:
            self.logger.removeHandler(handler)
            handler.flush()
            if isinstance(handler, logging.FileHandler):
                # release the file; FileHandler reopens it (append mode) if reused
                handler.close()

    def batch_begin(self):
        self.batch_start = time.time()

    def batch_end(self):
        batch_time = time.time() - self.batch_start
        epoch = self._estimator.train_stats['epochs'][-1]
        step = self._estimator.train_stats['step']
        msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time)
        for key in self._estimator.train_stats.keys():
            if key.startswith('batch_'):
                msg += key[6:] + ': ' + '%.4f ' % self._estimator.train_stats[key]
        self.logger.info(msg)

    def epoch_begin(self):
        self.epoch_start = time.time()

    def epoch_end(self):
        epoch_time = time.time() - self.epoch_start
        epoch = self._estimator.train_stats['epochs'][-1]
        msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
        for key in self._estimator.train_stats.keys():
            if key.startswith('train_') or key.startswith('test_'):
                # use the latest entry: the history keeps growing across fit()
                # calls while the epoch number restarts at 0 for each call
                msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][-1]
        self.logger.info(msg)
+
+
class CheckpointHandler(EventHandler):
    """Save the model after every epoch.

    :py:class:`CheckpointHandler` saves the network parameters every `period` epochs

    Parameters
    ----------
    estimator : Estimator
        The :py:class:`Estimator` to get training statistics
    filepath : str
        file name to save the parameters, it can contain directories,
        for example: ./saved_model/resnet.params. ``.params`` is appended
        if the name does not already end with it.
    monitor: str
        the metrics to monitor
    verbose: int, default 0
        verbosity mode
    save_best_only: bool
        if True, only save the parameters if monitored value improved
    mode: str, default 'auto'
        one of {auto, min, max}, if `save_best_only=True`, the comparison to make
        and determine if the monitored value has improved
    period: int, default 1
        intervals between saving the network
    """

    def __init__(self, estimator,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 mode='auto',
                 period=1):
        super(CheckpointHandler, self).__init__(estimator)
        self.monitor = monitor
        self.verbose = verbose
        # add extension for weights once, rather than re-checking every epoch
        if not filepath.endswith('.params'):
            filepath += '.params'
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.period = period
        self.epochs_since_last_save = 0
        self.logger = logging.getLogger(__name__)

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('CheckpointHandler mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            # use greater for accuracy and less otherwise
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def epoch_end(self):
        epoch = self._estimator.train_stats['epochs'][-1]
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0

        if not self.save_best_only:
            if self.verbose > 0:
                self.logger.info('\nEpoch %d: saving model to %s', epoch, self.filepath)
            self._estimator.net.save_parameters(self.filepath)
            return

        # check if monitor exists in train_stats; save anyway if it does not
        if self.monitor not in self._estimator.train_stats:
            warnings.warn('Unable to find %s in training statistics, make sure '
                          'you are passing one of the metric names as monitor' % self.monitor,
                          RuntimeWarning)
            self._estimator.net.save_parameters(self.filepath)
            return

        current = self._estimator.train_stats[self.monitor][-1]
        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,'
                                 ' saving model to %s',
                                 epoch, self.monitor, self.best, current, self.filepath)
            self.best = current
            self._estimator.net.save_parameters(self.filepath)
        elif self.verbose > 0:
            self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model',
                             epoch, self.monitor, self.best)
+
+
class EarlyStoppingHandler(EventHandler):
    """Early stop training if monitored value is not improving

    Parameters
    ----------
    estimator : Estimator
        The :py:class:`Estimator` to get training statistics
    monitor: str
        the metrics to monitor
    min_delta: float, default 0
        minimal change in monitored value to be considered as an improvement
    patience: int, default 0
        number of epochs to wait for improvement before terminate training
    mode: str, default 'auto'
        one of {auto, min, max}, the comparison to make
        and determine if the monitored value has improved
    baseline: float
        baseline value to compare the monitored value with
    """

    def __init__(self, estimator,
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 mode='auto',
                 baseline=None):
        super(EarlyStoppingHandler, self).__init__(estimator)

        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0
        # whether this handler requested the stop; stopped_epoch alone cannot
        # distinguish "never stopped" from "stopped at epoch 0"
        self._stopped_early = False
        self.logger = logging.getLogger(__name__)

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode, RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # use greater for accuracy and less otherwise
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        # epoch_end compares (current - min_delta) against best, so when lower is
        # better the delta is negated to require a decrease of at least min_delta
        if self.monitor_op == np.less:
            self.min_delta *= -1

    def train_begin(self):
        self.wait = 0
        self.stopped_epoch = 0
        self._stopped_early = False
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.inf if self.monitor_op == np.less else -np.inf

    def epoch_end(self):
        epoch = self._estimator.train_stats['epochs'][-1]
        if self.monitor not in self._estimator.train_stats:
            warnings.warn('Unable to find %s in training statistics, make sure '
                          'you are passing one of the metric names as monitor' % self.monitor,
                          RuntimeWarning)
            return

        current = self._estimator.train_stats[self.monitor][-1]
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self._stopped_early = True
                self._estimator.stop_training = True

    def train_end(self):
        if self._stopped_early:
            self.logger.info('Epoch %d: early stopping due to %s not improving', self.stopped_epoch, self.monitor)
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
new file mode 100644
index 0000000..a551594
--- /dev/null
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+import mxnet as mx
+from mxnet import nd
+from mxnet.gluon import nn, loss
+from mxnet.gluon.estimator import estimator, event_handler
+
def _get_test_network():
    """Build the small three-layer Dense network (100 -> 128 -> 64 -> 10) used by the tests."""
    model = nn.Sequential()
    layers = (
        nn.Dense(128, activation='relu', in_units=100, flatten=False),
        nn.Dense(64, activation='relu', in_units=128),
        nn.Dense(10, activation='relu', in_units=64),
    )
    model.add(*layers)
    return model
+
def _get_test_data():
    """Return an NDArrayIter over 32 all-ones samples of width 100 with random labels from nd.random.randint(0, 10)."""
    features = nd.ones((32, 100))
    targets = nd.random.randint(0, 10, (32, 1))
    return mx.io.NDArrayIter(data=features, label=targets)
+
+
def test_checkpoint_handler():
    """A one-epoch fit with CheckpointHandler writes the parameter file."""
    params_path = os.path.join(tempfile.mkdtemp(), "model.params")
    train_iter = _get_test_data()

    est = estimator.Estimator(_get_test_network(),
                              loss=loss.SoftmaxCrossEntropyLoss(),
                              metrics=mx.metric.Accuracy())
    handlers = [event_handler.CheckpointHandler(est, params_path,
                                                save_best_only=False,
                                                mode='auto')]
    est.fit(train_iter, event_handlers=handlers, epochs=1)

    assert os.path.isfile(params_path)
    os.remove(params_path)
+
def test_early_stopping():
    """EarlyStoppingHandler runs with explicit 'max' mode and with 'auto' mode."""
    train_iter = _get_test_data()
    est = estimator.Estimator(_get_test_network(),
                              loss=loss.SoftmaxCrossEntropyLoss(),
                              metrics=mx.metric.Accuracy())

    # (mode, patience) configurations exercised against the same estimator
    for stop_mode, wait_epochs in (('max', 0), ('auto', 2)):
        handlers = [event_handler.EarlyStoppingHandler(est, 'train_accuracy',
                                                       patience=wait_epochs,
                                                       mode=stop_mode)]
        est.fit(train_iter, event_handlers=handlers, epochs=1)
+
def test_logging():
    """LoggingHandler with a file name and location writes a log file."""
    log_dir = tempfile.mkdtemp()
    train_iter = _get_test_data()
    log_name = 'test_log'
    log_path = os.path.join(log_dir, log_name)

    est = estimator.Estimator(_get_test_network(),
                              loss=loss.SoftmaxCrossEntropyLoss(),
                              metrics=mx.metric.Accuracy())
    handlers = [event_handler.LoggingHandler(est, file_name=log_name, file_location=log_dir)]
    est.fit(train_iter, event_handlers=handlers, epochs=1)

    assert os.path.isfile(log_path)
    os.remove(log_path)
\ No newline at end of file