Posted to commits@mxnet.apache.org by ns...@apache.org on 2019/03/16 05:17:16 UTC

[incubator-mxnet] branch fit-api updated: [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch fit-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/fit-api by this push:
     new 41392fa  [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)
41392fa is described below

commit 41392fa1cdc4c4a49451678d9df7fdbad5b42faa
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Fri Mar 15 22:16:42 2019 -0700

    [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)
    
    * base class for estimator and eventhandler
    
    * add license
    
    * add event handlers
    
    * fix pylint
    
    * improve arg check
    
    * fix pylint
    
    * add unit tests
---
 python/mxnet/gluon/estimator/__init__.py          |  21 ++
 python/mxnet/gluon/estimator/estimator.py         | 267 +++++++++++++++++++
 python/mxnet/gluon/estimator/event_handler.py     | 307 ++++++++++++++++++++++
 tests/python/unittest/test_gluon_event_handler.py |  92 +++++++
 4 files changed, 687 insertions(+)

diff --git a/python/mxnet/gluon/estimator/__init__.py b/python/mxnet/gluon/estimator/__init__.py
new file mode 100644
index 0000000..58600da
--- /dev/null
+++ b/python/mxnet/gluon/estimator/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=wildcard-import
+"""Gluon Estimator Module"""
+from .estimator import *
+from .event_handler import *
diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
new file mode 100644
index 0000000..159f7e2
--- /dev/null
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Gluon Estimator"""
+
+import warnings
+
+from .event_handler import LoggingHandler
+from ... import gluon, autograd
+from ...context import Context, cpu, gpu, num_gpus
+from ...io import DataIter
+from ...metric import EvalMetric, Loss
+
+__all__ = ['Estimator']
+
+
+class Estimator(object):
+    """Estimator Class for easy model training
+
+    :py:class:`Estimator` can be used to facilitate the training & validation process
+
+
+    Parameters
+    ----------
+    net : gluon.Block
+        The neural network to be trained
+    loss : Loss or list of Loss
+        Loss (objective functions) to calculate during training
+    metrics : EvalMetric or list of EvalMetric
+        Metrics for evaluating models
+    initializer : Initializer
+        initializer to initialize the network
+    trainers : Trainer or list of Trainer
+        Trainers to apply optimizers on network parameters
+    context : Context or list of Context
+        devices to run the training on
+    """
+
+    def __init__(self, net,
+                 loss=None,
+                 metrics=None,
+                 initializer=None,
+                 trainers=None,
+                 context=None):
+
+        self.net = net
+        self.stop_training = False
+
+        if isinstance(loss, gluon.loss.Loss):
+            self.loss = [loss]
+        else:
+            self.loss = loss or []
+            for l in self.loss:
+                if not isinstance(l, gluon.loss.Loss):
+                    raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
+
+        if isinstance(metrics, EvalMetric):
+            self.metrics = [metrics]
+        else:
+            self.metrics = metrics or []
+            for metric in self.metrics:
+                if not isinstance(metric, EvalMetric):
+                    raise ValueError("metrics must be a Metric or a list of Metric, refer to mxnet.metric.EvalMetric")
+
+        self.initializer = initializer
+        # store training statistics
+        self.train_stats = {}
+        self.train_stats['epochs'] = []
+        self.train_stats['learning_rate'] = []
+        # current step of the epoch
+        self.train_stats['step'] = ''
+        for metric in self.metrics:
+            # record a history of metrics over each epoch
+            self.train_stats['train_' + metric.name] = []
+            # only record the latest metric numbers after each batch
+            self.train_stats['batch_' + metric.name] = 0.
+        self.loss_metrics = []
+        # using the metric wrapper for loss to record loss value
+        for l in self.loss:
+            self.loss_metrics.append(Loss(l.name))
+            self.train_stats['train_' + l.name] = []
+            # only record the latest loss numbers after each batch
+            self.train_stats['batch_' + l.name] = 0.
+
+        # handle context
+        if isinstance(context, Context):
+            self.context = [context]
+        elif context:
+            # assume a list of Context objects was passed in
+            self.context = context
+        else:
+            if num_gpus() > 0:
+                # only use 1 GPU by default
+                if num_gpus() > 1:
+                    warnings.warn("You have multiple GPUs, gpu(0) will be used by default."
+                                  "To utilize all your GPUs, specify context as a list of gpus, "
+                                  "e.g. context=[mx.gpu(0), mx.gpu(1)] ")
+                self.context = [gpu(0)]
+            else:
+                self.context = [cpu()]
+
+        # initialize the network
+        if self.initializer:
+            if self._is_initialized():
+                # if already initialized, re-init with user specified initializer
+                warnings.warn("Network already initialized, re-initializing with %s. "
+                              "You don't need to pass initializer if you already "
+                              "initialized your net."% type(self.initializer).__name__)
+                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
+            else:
+                # initialize with user specified initializer
+                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False)
+        else:
+            if not self._is_initialized():
+                self.net.initialize(ctx=self.context)
+
+        # handle trainers
+        if isinstance(trainers, gluon.Trainer):
+            self.trainers = [trainers]
+        else:
+            self.trainers = trainers or []
+        if not self.trainers:
+            warnings.warn("No trainer specified, default SGD optimizer "
+                          "with learning rate 0.001 is used.")
+            self.trainers = [gluon.Trainer(self.net.collect_params(),
+                                           'sgd', {'learning_rate': 0.001})]
+
+    def _is_initialized(self):
+        param_dict = self.net.collect_params()
+        for param in param_dict:
+            try:
+                param_dict[param].list_ctx()
+            except RuntimeError:
+                return False
+        return True
+
+    def _batch_fn(self, batch, ctx, is_iterator=False):
+        if is_iterator:
+            data = batch.data[0]
+            label = batch.label[0]
+        else:
+            data = batch[0]
+            label = batch[1]
+        data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0)
+        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
+        return data, label
+
+    def fit(self, train_data,
+            epochs=1,
+            batch_size=None,
+            event_handlers=None,
+            batch_fn=None):
+        """Main training loop
+
+        Parameters
+        ----------
+        train_data : DataLoader or DataIter
+            training data with data and labels
+        epochs : int, default 1
+            number of epochs to iterate on the training data.
+        batch_size : int
+            number of samples per gradient update.
+            default will be 32 per device
+        event_handlers : EventHandler or list of EventHandler
+            list of EventHandlers to apply during training
+        batch_fn : function
+            custom batch function to extract data and label
+            from a data batch and load them into contexts (devices)
+        """
+
+
+        self.epochs = epochs
+        if not batch_size:
+            batch_size = 32 * len(self.context)
+
+        event_handlers = event_handlers or []
+        # provide default logging handler
+        if not event_handlers or \
+                not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
+            event_handlers.append(LoggingHandler(self))
+
+        # training begin
+        for handler in event_handlers:
+            handler.train_begin()
+
+        for epoch in range(epochs):
+            # epoch begin
+            self.train_stats['epochs'].append(epoch)
+            self.train_stats['learning_rate'].append(self.trainers[0].learning_rate)
+
+            for handler in event_handlers:
+                handler.epoch_begin()
+
+            for metric in self.metrics + self.loss_metrics:
+                metric.reset()
+
+            for i, batch in enumerate(train_data):
+                if not batch_fn:
+                    if isinstance(train_data, gluon.data.DataLoader):
+                        data, label = self._batch_fn(batch, self.context)
+                    elif isinstance(train_data, DataIter):
+                        data, label = self._batch_fn(batch, self.context, is_iterator=True)
+                    else:
+                        raise ValueError("You are using a custom iteration, please also provide "
+                                         "batch_fn to extract data and label")
+                else:
+                    data, label = batch_fn(batch, self.context)
+
+                # batch begin
+                for handler in event_handlers:
+                    handler.batch_begin()
+
+                with autograd.record():
+                    pred = [self.net(x) for x in data]
+                    losses = []
+                    for loss in self.loss:
+                        losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+
+                for loss in losses:
+                    for l in loss:
+                        l.backward()
+
+                # update metrics
+                for metric in self.metrics:
+                    metric.update(label, pred)
+                    self.train_stats['batch_' + metric.name] = metric.get()[1]
+                for loss, loss_metric in zip(losses, self.loss_metrics):
+                    loss_metric.update(0, [l for l in loss])
+                    self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1]
+
+                try:
+                    self.train_stats['step'] = "{}/{}".format(batch_size * (i + 1), len(train_data._dataset))
+                except AttributeError:
+                    self.train_stats['step'] = i
+
+                for trainer in self.trainers:
+                    trainer.step(batch_size)
+
+                # batch end
+                for handler in event_handlers:
+                    handler.batch_end()
+
+            for metric in self.metrics + self.loss_metrics:
+                self.train_stats['train_' + metric.name].append(metric.get()[1])
+            # epoch end
+            for handler in event_handlers:
+                handler.epoch_end()
+
+            if self.stop_training:
+                break
+
+        # train end
+        for handler in event_handlers:
+            handler.train_end()
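
For readers following the diff, here is a minimal usage sketch of the Estimator API added above. The Estimator construction and fit() call follow the code in this commit; the small network, synthetic dataset, and hyper-parameters are illustrative placeholders only:

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon.estimator import Estimator

    # toy network and synthetic data, purely for illustration
    net = gluon.nn.Sequential()
    net.add(gluon.nn.Dense(64, activation='relu'),
            gluon.nn.Dense(10))

    x = mx.nd.random.uniform(shape=(100, 20))
    y = mx.nd.random.randint(0, 10, (100,)).astype('float32')
    train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(x, y), batch_size=16)

    ce_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    acc = mx.metric.Accuracy()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})

    # Estimator wires net, loss, metric, trainer and context together;
    # fit() runs the training loop and attaches a default LoggingHandler.
    est = Estimator(net, loss=ce_loss, metrics=acc, trainers=trainer, context=mx.cpu())
    est.fit(train_data, epochs=2, batch_size=16)
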
diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py
new file mode 100644
index 0000000..0162c36
--- /dev/null
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -0,0 +1,307 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Gluon EventHandlers for Estimators"""
+
+__all__ = ['EventHandler', 'LoggingHandler']
+import logging
+import os
+import time
+import warnings
+
+import numpy as np
+
+
+class EventHandler(object):
+    """Basic for event handlers
+
+        :py:class:`EventHandler` can perform user defined functions at
+        different stages of training: train begin, epoch begin, batch begin,
+        batch end, epoch end, train end.
+
+        Parameters
+        ----------
+        estimator : Estimator
+            The :py:class:`Estimator` to get training statistics
+        """
+    def __init__(self, estimator):
+        self._estimator = estimator
+
+    def train_begin(self):
+        pass
+
+    def train_end(self):
+        pass
+
+    def batch_begin(self):
+        pass
+
+    def batch_end(self):
+        pass
+
+    def epoch_begin(self):
+        pass
+
+    def epoch_end(self):
+        pass
+
+
+class LoggingHandler(EventHandler):
+    """Basic Logging Handler that applies to every Gluon estimator by default.
+
+    :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
+    and other useful information during training
+
+    Parameters
+    ----------
+    estimator : Estimator
+        The :py:class:`Estimator` to get training statistics
+    file_name : str
+        file name to save the logs
+    file_location: str
+        file location to save the logs
+    """
+
+    def __init__(self, estimator, file_name=None, file_location=None):
+        super(LoggingHandler, self).__init__(estimator)
+        self.logger = logging.getLogger(__name__)
+        self.logger.setLevel(logging.INFO)
+        stream_handler = logging.StreamHandler()
+        self.logger.addHandler(stream_handler)
+        # save logger to file only if file name or location is specified
+        if file_name or file_location:
+            file_name = file_name or 'estimator_log'
+            file_location = file_location or './'
+            file_handler = logging.FileHandler(os.path.join(file_location, file_name))
+            self.logger.addHandler(file_handler)
+
+    def train_begin(self):
+        pass
+
+    def train_end(self):
+        pass
+
+    def batch_begin(self):
+        self.batch_start = time.time()
+
+    def batch_end(self):
+        batch_time = time.time() - self.batch_start
+        epoch = self._estimator.train_stats['epochs'][-1]
+        step = self._estimator.train_stats['step']
+        msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time)
+        for key in self._estimator.train_stats.keys():
+            if key.startswith('batch_'):
+                msg += key[6:] + ': ' + '%.4f ' % self._estimator.train_stats[key]
+        self.logger.info(msg)
+
+    def epoch_begin(self):
+        self.epoch_start = time.time()
+
+    def epoch_end(self):
+        epoch_time = time.time() - self.epoch_start
+        epoch = self._estimator.train_stats['epochs'][-1]
+        msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
+        for key in self._estimator.train_stats.keys():
+            if key.startswith('train_') or key.startswith('test_'):
+                msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
+        self.logger.info(msg)
+
+
+class CheckpointHandler(EventHandler):
+    """Save the model after every epoch.
+
+    :py:class:`CheckpointHandler` saves the network parameters every epoch
+
+    Parameters
+    ----------
+    estimator : Estimator
+        The :py:class:`Estimator` to get training statistics
+    filepath : str
+        file name to save the parameters, it can contain directories,
+        for example: ./saved_model/resnet.params
+    monitor: str
+        the metrics to monitor
+    verbose: int, default 0
+        verbosity mode
+    save_best_only: bool
+        if True, only save the parameters if monitored value improved
+    mode: str, default 'auto'
+        one of {auto, min, max}, if `save_best_only=True`, the comparison to make
+        and determine if the monitored value has improved
+    period: int, default 1
+        intervals between saving the network
+    """
+
+    def __init__(self, estimator,
+                 filepath,
+                 monitor='val_loss',
+                 verbose=0,
+                 save_best_only=False,
+                 mode='auto',
+                 period=1):
+        super(CheckpointHandler, self).__init__(estimator)
+        self.monitor = monitor
+        self.verbose = verbose
+        self.filepath = filepath
+        self.save_best_only = save_best_only
+        self.period = period
+        self.epochs_since_last_save = 0
+        self.logger = logging.getLogger(__name__)
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn('CheckpointHandler mode %s is unknown, '
+                          'falling back to auto mode.' % (mode),
+                          RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+            self.best = np.Inf
+        elif mode == 'max':
+            self.monitor_op = np.greater
+            self.best = -np.Inf
+        else:
+            # use greater for accuracy and less otherwise
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+                self.best = -np.Inf
+            else:
+                self.monitor_op = np.less
+                self.best = np.Inf
+
+    def epoch_end(self):
+        epoch = self._estimator.train_stats['epochs'][-1]
+        # add extension for weights
+        if '.params' not in self.filepath:
+            self.filepath += '.params'
+        self.epochs_since_last_save += 1
+        if self.epochs_since_last_save >= self.period:
+            self.epochs_since_last_save = 0
+            if self.save_best_only:
+                # check if monitor exists in train_stats
+                if self.monitor not in self._estimator.train_stats:
+                    warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure '
+                                                 'you are passing one of the metric names as monitor' % self.monitor))
+                    self._estimator.net.save_parameters(self.filepath)
+                else:
+                    current = self._estimator.train_stats[self.monitor][-1]
+                    if self.monitor_op(current, self.best):
+                        if self.verbose > 0:
+                            self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,'
+                                             ' saving model to %s',
+                                             epoch, self.monitor, self.best, current, self.filepath)
+                        self.best = current
+                        self._estimator.net.save_parameters(self.filepath)
+                    else:
+                        if self.verbose > 0:
+                            self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping model save',
+                                             epoch, self.monitor, self.best)
+            else:
+                if self.verbose > 0:
+                    logging.info('\nEpoch %d: saving model to %s', epoch, self.filepath)
+                self._estimator.net.save_parameters(self.filepath)
+
+
+class EarlyStoppingHandler(EventHandler):
+    """Early stop training if monitored value is not improving
+
+    Parameters
+    ----------
+    estimator : Estimator
+        The :py:class:`Estimator` to get training statistics
+    monitor: str
+        the metrics to monitor
+    min_delta: float, default 0
+        minimal change in monitored value to be considered as an improvement
+    patience: int, default 0
+        number of epochs to wait for improvement before terminating training
+    mode: str, default 'auto'
+        one of {auto, min, max}, the comparison to make
+        and determine if the monitored value has improved
+    baseline: float
+        baseline value to compare the monitored value with
+    """
+
+    def __init__(self, estimator,
+                 monitor='val_loss',
+                 min_delta=0,
+                 patience=0,
+                 mode='auto',
+                 baseline=None):
+        super(EarlyStoppingHandler, self).__init__(estimator)
+
+        self._estimator = estimator
+        self.monitor = monitor
+        self.baseline = baseline
+        self.patience = patience
+        self.min_delta = min_delta
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.logger = logging.getLogger(__name__)
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn(RuntimeWarning('EarlyStoppingHandler mode %s is unknown, '
+                                         'falling back to auto mode.' % mode))
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+        elif mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+            else:
+                self.monitor_op = np.less
+
+        if self.monitor_op == np.greater:
+            self.min_delta *= 1
+        else:
+            self.min_delta *= -1
+
+    def train_begin(self):
+        self.wait = 0
+        self.stopped_epoch = 0
+        if self.baseline is not None:
+            self.best = self.baseline
+        else:
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def epoch_end(self):
+        epoch = self._estimator.train_stats['epochs'][-1]
+        if self.monitor not in self._estimator.train_stats:
+            warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure '
+                                         'you are passing one of the metric names as monitor' % self.monitor))
+        else:
+            current = self._estimator.train_stats[self.monitor][-1]
+            if current is None:
+                return
+
+            if self.monitor_op(current - self.min_delta, self.best):
+                self.best = current
+                self.wait = 0
+            else:
+                self.wait += 1
+                if self.wait >= self.patience:
+                    self.stopped_epoch = epoch
+                    self._estimator.stop_training = True
+
+    def train_end(self):
+        if self.stopped_epoch > 0:
+            self.logger.info('Epoch %d: early stopping due to %s not improving', self.stopped_epoch, self.monitor)
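
To illustrate how the hook methods above are meant to be used, here is a rough sketch of a custom handler combined with the built-in ones. The constructor arguments mirror the classes in this commit; the custom handler itself and the est/train_data objects (built as in the earlier Estimator sketch) are illustrative assumptions:

    from mxnet.gluon.estimator.event_handler import (EventHandler, CheckpointHandler,
                                                     EarlyStoppingHandler)

    class LearningRateLogger(EventHandler):
        """Hypothetical handler: print the learning rate recorded at the start of each epoch."""
        def epoch_begin(self):
            epoch = self._estimator.train_stats['epochs'][-1]
            lr = self._estimator.train_stats['learning_rate'][-1]
            print('[Epoch %d] learning rate: %f' % (epoch, lr))

    # est and train_data are assumed to come from the earlier Estimator sketch;
    # a default LoggingHandler is appended automatically by fit().
    handlers = [LearningRateLogger(est),
                CheckpointHandler(est, filepath='./my_model.params', verbose=1),
                EarlyStoppingHandler(est, monitor='train_accuracy', patience=2, mode='max')]
    est.fit(train_data, epochs=5, event_handlers=handlers)
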
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
new file mode 100644
index 0000000..a551594
--- /dev/null
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+import mxnet as mx
+from mxnet import nd
+from mxnet.gluon import nn, loss
+from mxnet.gluon.estimator import estimator, event_handler
+
+def _get_test_network():
+    net = nn.Sequential()
+    net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False),
+            nn.Dense(64, activation='relu', in_units=128),
+            nn.Dense(10, activation='relu', in_units=64))
+    return net
+
+def _get_test_data():
+    return mx.io.NDArrayIter(data=nd.ones((32, 100)), label=nd.random.randint(0, 10, (32, 1)))
+
+
+def test_checkpoint_handler():
+    tmpdir = tempfile.mkdtemp()
+    file_path = os.path.join(tmpdir, "model.params")
+    test_data  = _get_test_data()
+
+    save_best_only = False
+    mode = 'auto'
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    checkpoint_handler = [event_handler.CheckpointHandler(est, file_path,
+                                                          save_best_only=save_best_only,
+                                                          mode=mode)]
+    est.fit(test_data, event_handlers=checkpoint_handler, epochs=1)
+    assert os.path.isfile(file_path)
+    os.remove(file_path)
+
+def test_early_stopping():
+    test_data = _get_test_data()
+
+    mode = 'max'
+    monitor = 'train_accuracy'
+    patience = 0
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+                                                         patience=patience,
+                                                          mode=mode)]
+    est.fit(test_data, event_handlers=early_stopping, epochs=1)
+
+    mode = 'auto'
+    monitor = 'train_accuracy'
+    patience = 2
+    early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+                                                         patience=patience,
+                                                          mode=mode)]
+    est.fit(test_data, event_handlers=early_stopping, epochs=1)
+
+def test_logging():
+    tmpdir = tempfile.mkdtemp()
+    test_data = _get_test_data()
+    file_name = 'test_log'
+    output_dir = os.path.join(tmpdir, file_name)
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)]
+    est.fit(test_data, event_handlers=logging_handler, epochs=1)
+    assert os.path.isfile(output_dir)
+    os.remove(output_dir)
\ No newline at end of file
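
The tests above do not exercise the batch_fn argument of fit(). For completeness, a rough sketch of what such a function could look like; the function name and body are assumptions that simply mirror the built-in _batch_fn, and only the (batch, ctx) -> (data, label) contract comes from the commit:

    from mxnet import gluon

    def tuple_batch_fn(batch, ctx):
        """Hypothetical batch_fn: split a (data, label) batch across the given contexts."""
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        return data, label

    # est and train_data as in the earlier sketches
    est.fit(train_data, epochs=1, batch_fn=tuple_batch_fn)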