Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/18 18:04:00 UTC

[incubator-mxnet] branch master updated: [MXNET-1333] Estimator and Fit API (#14629)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f451fb  [MXNET-1333] Estimator and Fit API (#14629)
9f451fb is described below

commit 9f451fb6f4265f7e122ca08a386e85595a5030a2
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Sat May 18 11:03:18 2019 -0700

    [MXNET-1333] Estimator and Fit API (#14629)
    
    * [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)
    
    * base class for estimator and eventhandler
    
    * add license
    
    * add event handlers
    
    * fix pylint
    
    * improve arg check
    
    * fix pylint
    
    * add unit tests
    
    * Fixed issue where the estimator was printing beyond the dataset size … (#14464)
    
    * Fixed issue where the estimator was printing beyond the dataset size for the last batch
    
    * Added comments
    
    * Nudge to CI
    
    * [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442)
    
    * added estimator unittests
    
    * add more tests for estimator
    
    * added validation logic
    
    * added error handlers, unittests
    
    * improve val stats
    
    * fix pylint
    
    * fix pylint
    
    * update unit test
    
    * fix tests
    
    * fix tests
    
    * updated metrics, val logic
    
    * trigger ci
    
    * trigger ci
    
    * update metric, batch_fn error handler
    
    * update context logic, add default metric
    
    * [MXNet-1340][Fit API]Update train stats (#14494)
    
    * add train history
    
    * update history
    
    * update test
    
    * avoid calling empty methods
    
    * remove train history object
    
    * fix pylint
    
    * add unit test
    
    * fix test
    
    * update categorize handlers
    
    * [MXNet-1375][Fit API]Added RNN integration test for fit() API (#14547)
    
    * Added RNN integration test for fit() API
    
    * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports
    
    * CPU test doesn't require nvidiadocker container
    
    * Modified the structure by removing the redundant code
    
    * [MXNet-1343][Fit API]Add CNN integration test for fit() API (#14405)
    
    * added cnn intg tests for fit api
    
    * updated cnn intg tests
    
    * added functions for nightly test
    
    * updated runtime_function
    
    * updated intg tests
    
    * updated init, datapath, refs
    
    * added validation data
    
    * update cpu test
    
    * refactor code
    
    * updated context
    
    * [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (#14587)
    
    * Retrieve Batch size and Logging verbose support for Gluon fit() API
    
    * NIT changes
    
    * Addressed review comments: shifted the batch size code to a separate method, sentence correction
    
    * Modified unittest
    
    * removed redundant parameter
    
    * Resolve CI test failure
    
    * only support DataLoader for now, future PRs will include DataIter to DataLoader converter
    
    * Get the number of samples from shape attribute instead of length due to low space complexity
    
    * Simplified batch size retrieval code
    
    * removed batch_size parameter from fit() method and fixed the tests
    
    * Verbose exception handling
    
    * Assigning constant to a verbose
    
    * Modified exception message
    
    * Resolved undefined class reference
    
    * Addressed review comments: Modified verbose level names, docs, variable names
    
    * Update estimator.py
    
    * move estimator to contrib (#14633)
    
    * move to gluon contrib (#14635)
    
    * [Fit API] improve event handlers (#14685)
    
    * improve event handlers
    
    * update tests
    
    * passing weakref of estimator
    
    * fix unit test
    
    * fix test
    
    * fix pylint
    
    * fix test
    
    * fix pylint
    
    * move default metric logic
    
    * combine nightly tests
    
    * [MXNET-1396][Fit-API] Update default handler logic (#14765)
    
    * move to nightly for binaries
    
    * update default handler
    
    * fix pylint
    
    * trigger ci
    
    * trigger ci
    
    * [Fit API] update estimator (#14849)
    
    * address comments
    
    * add comment
    
    * check available context
    
    * fix bug
    
    * change cpu check
    
    * [Fit-API] Address PR comments (#14885)
    
    * address comments
    
    * update checkpoint
    
    * test symbol save
    
    * address comments
    
    * add resume
    
    * update doc and resume checkpoint
    
    * update docs
    
    * trigger ci
    
    * trigger ci
---
 ci/docker/runtime_functions.sh                     |  10 +
 python/mxnet/gluon/contrib/estimator/__init__.py   |  21 +
 python/mxnet/gluon/contrib/estimator/estimator.py  | 408 ++++++++++++
 .../mxnet/gluon/contrib/estimator/event_handler.py | 705 +++++++++++++++++++++
 python/mxnet/gluon/trainer.py                      |   7 +
 tests/nightly/JenkinsfileForBinaries               |   8 +
 tests/nightly/estimator/test_estimator_cnn.py      | 151 +++++
 tests/nightly/estimator/test_sentiment_rnn.py      | 276 ++++++++
 tests/python/unittest/test_gluon_estimator.py      | 371 +++++++++++
 tests/python/unittest/test_gluon_event_handler.py  | 198 ++++++
 10 files changed, 2155 insertions(+)
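The commit message notes that handlers are categorized "to avoid calling empty methods". The sketch below illustrates that design in plain Python so it runs without MXNet. Handlers opt into events by inheriting no-op mixins (the same idea as `TrainBegin`, `EpochEnd` and `BatchEnd` in event_handler.py), are sorted by an optional `priority` attribute, and are then grouped per event. The `categorize` function and the `MetricUpdater` and `Logger` classes are hypothetical; in the actual diff the sort happens in `_prepare_default_handlers()` and the grouping in `_categorize_handlers()`.

```python
"""Sketch of mixin-based event-handler dispatch (no MXNet required)."""


class TrainBegin(object):
    def train_begin(self, estimator, *args, **kwargs):
        pass


class EpochEnd(object):
    def epoch_end(self, estimator, *args, **kwargs):
        return False


class BatchEnd(object):
    def batch_end(self, estimator, *args, **kwargs):
        return False


EVENTS = [TrainBegin, EpochEnd, BatchEnd]


def categorize(handlers):
    """Group handlers by the event mixins they inherit from.

    Sorting by ``priority`` first means handlers with a lower value
    (for example -inf for metric updates) run earlier in every event list.
    """
    ordered = sorted(handlers, key=lambda h: getattr(h, 'priority', 0))
    return {event.__name__: [h for h in ordered if isinstance(h, event)]
            for event in EVENTS}


class MetricUpdater(BatchEnd):
    # Hypothetical handler that must run before anything reading metrics.
    priority = -float('inf')


class Logger(TrainBegin, EpochEnd, BatchEnd):
    # Hypothetical handler with the default priority (0).
    pass


groups = categorize([Logger(), MetricUpdater()])
```

Because the sort is stable and runs before grouping, a handler with `priority = -inf` (as `MetricHandler` and `ValidationHandler` use) runs ahead of default-priority handlers in every event list it belongs to, and a handler that does not inherit a given mixin is never called for that event.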

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index e1da222..58e39ef 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1350,6 +1350,16 @@ nightly_scala_demo_test_cpu() {
     bash bin/run_im.sh
 }
 
+nightly_estimator() {
+    set -ex
+    cd /work/mxnet/tests/nightly/estimator
+    export PYTHONPATH=/work/mxnet/python/
+    python test_estimator_cnn.py --type gpu
+    python test_sentiment_rnn.py --type gpu
+    python test_estimator_cnn.py --type cpu
+    python test_sentiment_rnn.py --type cpu
+}
+
 # Deploy
 
 deploy_docs() {
diff --git a/python/mxnet/gluon/contrib/estimator/__init__.py b/python/mxnet/gluon/contrib/estimator/__init__.py
new file mode 100644
index 0000000..58600da
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=wildcard-import
+"""Gluon Estimator Module"""
+from .estimator import *
+from .event_handler import *
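`prepare_loss_and_metrics()` in the estimator.py diff below appends one `Loss` metric per loss function, strips trailing digits from the loss name, and deep-copies every training metric to get a separate validation metric. A plain-Python sketch of that naming logic, assuming a hypothetical `Metric` class with only a `name` attribute in place of `mxnet.metric.EvalMetric`; unlike the original, the sketch copies the input list instead of appending to it in place.

```python
"""Sketch of the metric naming done by prepare_loss_and_metrics()."""
import copy


class Metric(object):
    """Hypothetical stand-in for mxnet.metric.EvalMetric (name only)."""

    def __init__(self, name):
        self.name = name


def prepare(loss_names, train_metrics):
    """Return (train_metrics, val_metrics) named the way the Estimator does."""
    train_metrics = list(train_metrics)
    for loss_name in loss_names:
        # Auto-generated block names carry a numeric suffix, for example
        # 'softmaxcrossentropyloss0'; strip the trailing digits.
        train_metrics.append(Metric(loss_name.rstrip('1234567890')))
    val_metrics = []
    for metric in train_metrics:
        # Deep-copy before renaming so each validation metric is an
        # independent object that starts from the unprefixed name.
        val_metric = copy.deepcopy(metric)
        metric.name = 'train ' + metric.name
        val_metric.name = 'validation ' + val_metric.name
        val_metrics.append(val_metric)
    return train_metrics, val_metrics


train, val = prepare(['softmaxcrossentropyloss0'], [Metric('accuracy')])
```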
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
new file mode 100644
index 0000000..da1a391
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -0,0 +1,408 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-variable
+"""Gluon Estimator"""
+
+import copy
+import warnings
+
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
+from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
+from .... import gluon, autograd
+from ....context import Context, cpu, gpu, num_gpus
+from ....metric import EvalMetric, Loss, Accuracy
+
+__all__ = ['Estimator']
+
+
+class Estimator(object):
+    """Estimator class for easy model training.
+
+    :py:class:`Estimator` can be used to facilitate the training and validation process.
+
+
+    Parameters
+    ----------
+    net : Block
+        The model used for training.
+    loss : gluon.loss.Loss or list of gluon.loss.Loss
+        Loss (objective function) to calculate during training.
+    metrics : EvalMetric or list of EvalMetric
+        Metrics for evaluating models.
+    initializer : Initializer
+        Initializer to initialize the network.
+    trainer : Trainer
+        Trainer to apply optimizer on network parameters.
+    context : Context or list of Context
+        Device(s) to run the training on.
+    """
+
+    def __init__(self, net,
+                 loss,
+                 metrics=None,
+                 initializer=None,
+                 trainer=None,
+                 context=None):
+
+        self.net = net
+        self.loss = self._check_loss(loss)
+        self.train_metrics = self._check_metrics(metrics)
+
+        self.context = self._check_context(context)
+        self._initialize(initializer)
+        self.trainer = self._check_trainer(trainer)
+
+    def _check_loss(self, loss):
+        if isinstance(loss, gluon.loss.Loss):
+            loss = [loss]
+        elif isinstance(loss, list) and all([isinstance(l, gluon.loss.Loss) for l in loss]):
+            loss = loss
+        else:
+            raise ValueError("loss must be a Loss or a list of Loss, "
+                             "refer to gluon.loss.Loss:{}".format(loss))
+        return loss
+
+    def _check_metrics(self, metrics):
+        if isinstance(metrics, EvalMetric):
+            metrics = [metrics]
+        else:
+            metrics = metrics or []
+            if not all([isinstance(metric, EvalMetric) for metric in metrics]):
+                raise ValueError("metrics must be a Metric or a list of Metric, "
+                                 "refer to mxnet.metric.EvalMetric:{}".format(metrics))
+        return metrics
+
+    def _check_context(self, context):
+        # infer available context
+        gpus = num_gpus()
+        available_gpus = [gpu(i) for i in range(gpus)]
+
+        if context:
+            # check context values, only accept Context or a list of Context
+            if isinstance(context, Context):
+                context = [context]
+            elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
+                context = context
+            else:
+                raise ValueError("context must be a Context or a list of Context, "
+                                 "for example mx.cpu() or [mx.gpu(0), mx.gpu(1)], "
+                                 "refer to mxnet.Context:{}".format(context))
+            for ctx in context:
+                assert ctx in available_gpus or str(ctx).startswith('cpu'), \
+                    "%s is not available, please make sure " \
+                    "your context is in one of: mx.cpu(), %s" % \
+                    (ctx, ", ".join([str(ctx) for ctx in available_gpus]))
+        else:
+            # provide default context
+            if gpus > 0:
+                # only use 1 GPU by default
+                if gpus > 1:
+                    warnings.warn("You have multiple GPUs, gpu(0) will be used by default. "
+                                  "To utilize all your GPUs, specify context as a list of gpus, "
+                                  "e.g. context=[mx.gpu(0), mx.gpu(1)] ")
+                context = [gpu(0)]
+            else:
+                context = [cpu()]
+        return context
+
+    def _initialize(self, initializer):
+        # initialize the network
+        if not self._is_initialized():
+            # net is partially or not initialized,
+            # initialize with user specified initializer
+            # if initializer is None, default initializer will be used
+            # do not re-init layers already initialized
+            if initializer:
+                self.net.initialize(init=initializer, ctx=self.context)
+            else:
+                self.net.initialize(ctx=self.context)
+        elif initializer:
+            # net is fully initialized, and user passed not None initializer
+            # do not force reinitialize, give warning
+            warnings.warn("Network already fully initialized, skipping initialization. "
+                          "You don't need to pass initializer if you already "
+                          "initialized your net. "
+                          "You can use net.initialize(init=your_initializer, force_reinit=True) "
+                          "to force re-initialize.")
+
+    def _check_trainer(self, trainer):
+        # handle trainer
+        if not trainer:
+            warnings.warn("No trainer specified, default SGD optimizer "
+                          "with learning rate 0.001 is used.")
+            trainer = gluon.Trainer(self.net.collect_params(),
+                                    'sgd', {'learning_rate': 0.001})
+        elif not isinstance(trainer, gluon.Trainer):
+            raise ValueError("Trainer must be a Gluon Trainer instance, refer to "
+                             "gluon.Trainer:{}".format(trainer))
+        return trainer
+
+    def _is_initialized(self):
+        param_dict = self.net.collect_params()
+        for param in param_dict:
+            try:
+                param_dict[param].list_ctx()
+            except RuntimeError:
+                return False
+        return True
+
+    def _get_data_and_label(self, batch, ctx, batch_axis=0):
+        data = batch[0]
+        label = batch[1]
+        data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
+        label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
+        return data, label
+
+    def prepare_loss_and_metrics(self):
+        """
+        Based on the loss functions and training metrics in the estimator,
+        create metric wrappers to record loss values, and
+        create copies of the train loss/metric objects to record validation values.
+        Returns train_metrics and val_metrics.
+
+        """
+        if any(not hasattr(self, attribute) for attribute in
+               ['train_metrics', 'val_metrics']):
+            # Use default mx.metric.Accuracy() for gluon.loss.SoftmaxCrossEntropyLoss()
+            if not self.train_metrics and any([isinstance(l, gluon.loss.SoftmaxCrossEntropyLoss) for l in self.loss]):
+                self.train_metrics = [Accuracy()]
+            self.val_metrics = []
+            for loss in self.loss:
+                # remove trailing numbers from loss name to avoid confusion
+                self.train_metrics.append(Loss(loss.name.rstrip('1234567890')))
+            for metric in self.train_metrics:
+                val_metric = copy.deepcopy(metric)
+                metric.name = "train " + metric.name
+                val_metric.name = "validation " + val_metric.name
+                self.val_metrics.append(val_metric)
+        return self.train_metrics, self.val_metrics
+
+    def evaluate(self,
+                 val_data,
+                 val_metrics,
+                 batch_axis=0):
+        """Evaluate model on validation data
+
+        Parameters
+        ----------
+        val_data : DataLoader
+            Validation data loader with data and labels.
+        val_metrics : EvalMetric or list of EvalMetrics
+            Metrics to update with validation results.
+        batch_axis : int, default 0
+            Batch axis to split the validation data into devices.
+        """
+        if not isinstance(val_data, gluon.data.DataLoader):
+            raise ValueError("Estimator only supports input as Gluon DataLoader. Alternatively, you "
+                             "can transform your DataIter or any NDArray into Gluon DataLoader. "
+                             "Refer to gluon.data.dataloader")
+
+        for metric in val_metrics:
+            metric.reset()
+
+        for _, batch in enumerate(val_data):
+            data, label = self._get_data_and_label(batch, self.context, batch_axis)
+            pred = [self.net(x) for x in data]
+            loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
+            # update metrics
+            for metric in val_metrics:
+                if isinstance(metric, Loss):
+                    metric.update(0, loss)
+                else:
+                    metric.update(label, pred)
+
+    def fit(self, train_data,
+            val_data=None,
+            epochs=None,
+            event_handlers=None,
+            batches=None,
+            batch_axis=0):
+        """Trains the model with a given :py:class:`DataLoader` for a specified
+        number of epochs or batches. The batch size is inferred from the
+        data loader's batch_size.
+
+        Parameters
+        ----------
+        train_data : DataLoader
+            Training data loader with data and labels.
+        val_data : DataLoader, default None
+            Validation data loader with data and labels.
+        epochs : int, default None
+            Number of epochs to iterate on the training data.
+            Specify exactly one type of iteration (epochs or batches).
+        event_handlers : EventHandler or list of EventHandler
+            List of :py:class:`EventHandlers` to apply during training.
+        batches : int, default None
+            Number of batches to iterate on the training data.
+            Specify exactly one type of iteration (epochs or batches).
+        batch_axis : int, default 0
+            Batch axis to split the training data into devices.
+        """
+        if not isinstance(train_data, gluon.data.DataLoader):
+            raise ValueError("Estimator only supports input as Gluon DataLoader. Alternatively, you "
+                             "can transform your DataIter or any NDArray into Gluon DataLoader. "
+                             "Refer to gluon.data.dataloader")
+
+        # must specify one and only one of epochs or batches
+        if (not epochs) == (not batches):
+            raise ValueError(
+                "Fit only supports exactly one type of iteration: "
+                "train by number of epochs or number of batches. "
+                "Please specify one and only one of: epochs or batches.")
+
+        self.max_epoch = epochs
+        self.max_batch = batches
+
+        # provide default handlers
+        event_handlers = self._prepare_default_handlers(val_data, event_handlers)
+
+        train_begin, epoch_begin, batch_begin, \
+        batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
+
+        # pass a reference to all event handlers
+        estimator_ref = self
+        # training begin
+        for handler in train_begin:
+            handler.train_begin(estimator_ref)
+
+        while True:
+            # epoch begin
+            for handler in epoch_begin:
+                handler.epoch_begin(estimator_ref)
+
+            for i, batch in enumerate(train_data):
+                data, label = self._get_data_and_label(batch, self.context, batch_axis)
+
+                batch_size = batch[0].shape[batch_axis]
+
+                # batch begin
+                for handler in batch_begin:
+                    handler.batch_begin(estimator_ref, batch=batch)
+
+                with autograd.record():
+                    pred = [self.net(x) for x in data]
+                    loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)]
+
+                for l in loss:
+                    l.backward()
+
+                self.trainer.step(batch_size)
+                # batch end
+
+                batch_end_result = []
+                for handler in batch_end:
+                    batch_end_result.append(handler.batch_end(estimator_ref, batch=batch,
+                                                              pred=pred, label=label, loss=loss))
+                # if any handler signaled to stop
+                if any(batch_end_result):
+                    break
+
+            # epoch end
+            epoch_end_result = []
+            for handler in epoch_end:
+                epoch_end_result.append(handler.epoch_end(estimator_ref))
+            # if any handler signaled to stop
+            if any(epoch_end_result):
+                break
+
+        # train end
+        for handler in train_end:
+            handler.train_end(estimator_ref)
+
+    def _prepare_default_handlers(self, val_data, event_handlers):
+        event_handlers = event_handlers or []
+        default_handlers = []
+        train_metrics, val_metrics = self.prepare_loss_and_metrics()
+
+        # no need to add to default handler check as StoppingHandler does not use metrics
+        event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
+
+        if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
+            event_handlers.append(MetricHandler(train_metrics=train_metrics))
+            default_handlers.append("MetricHandler")
+
+        if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
+            event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
+                                                    val_metrics=val_metrics))
+            default_handlers.append("ValidationHandler")
+
+        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
+            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+                                                 val_metrics=val_metrics))
+            default_handlers.append("LoggingHandler")
+
+        # if there is a mix of user defined event handlers and default event handlers
+        # they should have the same set of loss and metrics
+        if default_handlers:
+            msg = "You are training with the following default event handlers: %s. " \
+                  "They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
+                  "Please use the same set of metrics for all your other handlers." % \
+                  ", ".join(default_handlers)
+            warnings.warn(msg)
+            # check if all handlers have the same set of references to loss and metrics
+            references = []
+            for handler in event_handlers:
+                for attribute in dir(handler):
+                    if any(keyword in attribute for keyword in ['metric', 'monitor']):
+                        reference = getattr(handler, attribute)
+                        if isinstance(reference, list):
+                            references += reference
+                        else:
+                            references.append(reference)
+            # remove None metric references
+            references = set([ref for ref in references if ref])
+            for metric in references:
+                if metric not in train_metrics + val_metrics:
+                    msg = "We have added the following default handlers for you: %s and used " \
+                          "estimator.prepare_loss_and_metrics() to pass metrics to " \
+                          "those handlers. Please use the same set of metrics " \
+                          "for all your handlers." % \
+                          ", ".join(default_handlers)
+                    raise ValueError(msg)
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
+        return event_handlers
+
+    def _categorize_handlers(self, event_handlers):
+        """
+        Categorize handlers into 6 event lists to avoid calling empty methods.
+        For example, only event handlers with the train_begin method
+        implemented will be called at train begin.
+        """
+
+        train_begin = []
+        epoch_begin = []
+        batch_begin = []
+        batch_end = []
+        epoch_end = []
+        train_end = []
+        for handler in event_handlers:
+            if isinstance(handler, TrainBegin):
+                train_begin.append(handler)
+            if isinstance(handler, EpochBegin):
+                epoch_begin.append(handler)
+            if isinstance(handler, BatchBegin):
+                batch_begin.append(handler)
+            if isinstance(handler, BatchEnd):
+                batch_end.append(handler)
+            if isinstance(handler, EpochEnd):
+                epoch_end.append(handler)
+            if isinstance(handler, TrainEnd):
+                train_end.append(handler)
+        return train_begin, epoch_begin, batch_begin, batch_end, epoch_end, train_end
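Two details of `fit()` above interact: the `(not epochs) == (not batches)` guard, and the way a batch-level stop signal propagates to the epoch loop (the inner `for` loop breaks, then `epoch_end` still runs and returns `True`). The sketch below re-implements only that control flow in plain Python so it runs without MXNet. `fit` here is a simplified, hypothetical stand-in rather than the Estimator API, and it omits metrics, validation and every handler except a trimmed copy of `StoppingHandler`.

```python
"""Sketch of how fit() combines the epochs/batches check with stop signals."""


class StoppingHandler(object):
    """Simplified copy of the StoppingHandler logic in event_handler.py."""

    def __init__(self, max_epoch=None, max_batch=None):
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.current_batch = 0
        self.current_epoch = 0
        self.stop_training = False

    def batch_end(self):
        self.current_batch += 1
        if self.current_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self):
        self.current_epoch += 1
        if self.current_epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training


def fit(train_data, epochs=None, batches=None):
    """Return the number of batches processed before training stops.

    ``train_data`` is any re-iterable sequence standing in for a DataLoader.
    """
    # (not a) == (not b) is True when both are missing or both are given,
    # so this enforces "exactly one of epochs or batches".
    if (not epochs) == (not batches):
        raise ValueError("Specify exactly one of: epochs or batches.")
    stopper = StoppingHandler(epochs, batches)
    processed = 0
    while True:
        for _ in train_data:
            processed += 1  # forward, backward and trainer.step would go here
            if stopper.batch_end():
                break
        # epoch_end still runs after a batch-level stop and returns True,
        # because stop_training is already set.
        if stopper.epoch_end():
            break
    return processed
```

For example, with a four-batch loader, `fit(range(4), batches=6)` runs one full epoch plus two batches of the next, while `fit(range(4), epochs=2)` runs eight batches.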
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
new file mode 100644
index 0000000..ce5890e
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -0,0 +1,705 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-argument
+"""Gluon EventHandlers for Estimators"""
+
+import logging
+import os
+import time
+import warnings
+
+import numpy as np
+
+from ....metric import EvalMetric, Loss
+
+
+class TrainBegin(object):
+    def train_begin(self, estimator, *args, **kwargs):
+        pass
+
+
+class TrainEnd(object):
+    def train_end(self, estimator, *args, **kwargs):
+        pass
+
+
+class EpochBegin(object):
+    def epoch_begin(self, estimator, *args, **kwargs):
+        pass
+
+
+class EpochEnd(object):
+    def epoch_end(self, estimator, *args, **kwargs):
+        return False
+
+
+class BatchBegin(object):
+    def batch_begin(self, estimator, *args, **kwargs):
+        pass
+
+
+class BatchEnd(object):
+    def batch_end(self, estimator, *args, **kwargs):
+        return False
+
+
+class StoppingHandler(TrainBegin, BatchEnd, EpochEnd):
+    """Stop conditions to stop training.
+    Stop training once the maximum number of batches or epochs
+    is reached.
+
+    Parameters
+    ----------
+    max_epoch : int, default None
+        Number of maximum epochs to train.
+    max_batch : int, default None
+        Number of maximum batches to train.
+
+    """
+
+    def __init__(self, max_epoch=None, max_batch=None):
+        self.max_epoch = max_epoch
+        self.max_batch = max_batch
+        self.current_batch = 0
+        self.current_epoch = 0
+        self.stop_training = False
+
+    def train_begin(self, estimator, *args, **kwargs):
+        self.max_epoch = estimator.max_epoch
+        self.max_batch = estimator.max_batch
+        self.current_batch = 0
+        self.current_epoch = 0
+
+    def batch_end(self, estimator, *args, **kwargs):
+        self.current_batch += 1
+        if self.current_batch == self.max_batch:
+            self.stop_training = True
+        return self.stop_training
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        self.current_epoch += 1
+        if self.current_epoch == self.max_epoch:
+            self.stop_training = True
+        return self.stop_training
+
+
+class MetricHandler(EpochBegin, BatchEnd):
+    """Metric handler that updates metric values at batch end.
+
+    :py:class:`MetricHandler` takes model predictions and true labels
+    and updates the metrics; it also updates the metric wrappers for loss with loss values.
+    Validation loss and metrics will be handled by :py:class:`ValidationHandler`
+
+    Parameters
+    ----------
+    train_metrics : List of EvalMetrics
+        Training metrics to be updated at batch end.
+    """
+
+    def __init__(self, train_metrics):
+        self.train_metrics = train_metrics or []
+        # order to be called among all callbacks
+        # metrics need to be calculated before other callbacks can access them
+        self.priority = -np.Inf
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        for metric in self.train_metrics:
+            metric.reset()
+
+    def batch_end(self, estimator, *args, **kwargs):
+        pred = kwargs['pred']
+        label = kwargs['label']
+        loss = kwargs['loss']
+        for metric in self.train_metrics:
+            if isinstance(metric, Loss):
+                # metric wrapper for loss values
+                metric.update(0, loss)
+            else:
+                metric.update(label, pred)
+
+
+class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
+    """Validation handler that evaluates the model on a validation dataset.
+
+    :py:class:`ValidationHandler` takes a validation dataset, an evaluation function,
+    metrics to be evaluated, and how often to run the validation. You can provide a custom
+    evaluation function or use the one provided by :py:class:`Estimator`.
+
+    Parameters
+    ----------
+    val_data : DataLoader
+        Validation data set to run evaluation.
+    eval_fn : function
+        A function that defines how to run evaluation and
+        calculate loss and metrics.
+    val_metrics : List of EvalMetrics
+        Validation metrics to be updated.
+    epoch_period : int, default 1
+        How often to run validation at epoch end, by default
+        :py:class:`ValidationHandler` validates every epoch.
+    batch_period : int, default None
+        How often to run validation at batch end, by default
+        :py:class:`ValidationHandler` does not validate at batch end.
+    """
+
+    def __init__(self,
+                 val_data,
+                 eval_fn,
+                 val_metrics=None,
+                 epoch_period=1,
+                 batch_period=None):
+        self.val_data = val_data
+        self.eval_fn = eval_fn
+        self.epoch_period = epoch_period
+        self.batch_period = batch_period
+        self.val_metrics = val_metrics
+        self.current_batch = 0
+        self.current_epoch = 0
+        # order to be called among all callbacks
+        # validation metrics need to be calculated before other callbacks can access them
+        self.priority = -np.Inf
+        self.logger = logging.getLogger(__name__)
+
+    def train_begin(self, estimator, *args, **kwargs):
+        # reset epoch and batch counter
+        self.current_batch = 0
+        self.current_epoch = 0
+
+    def batch_end(self, estimator, *args, **kwargs):
+        self.current_batch += 1
+        if self.batch_period and self.current_batch % self.batch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+            msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \
+                  % (self.current_epoch, self.current_batch)
+            for monitor in self.val_metrics:
+                name, value = monitor.get()
+                msg += '%s: %.4f, ' % (name, value)
+            self.logger.info(msg.rstrip(', '))
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        self.current_epoch += 1
+        if self.epoch_period and self.current_epoch % self.epoch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+
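ValidationHandler keeps a single batch counter for the whole run (it is only reset in `train_begin`) and increments both counters before the modulo check, so batch-level validation fires on batches `batch_period`, `2 * batch_period`, and so on, counted across epochs. The following sketch replays that counter logic without MXNet; `validation_schedule` is a hypothetical helper, not part of the estimator API.

```python
def validation_schedule(num_epochs, batches_per_epoch, epoch_period=1, batch_period=None):
    """Return (epoch, batch_counter, trigger) tuples for every validation run.

    For 'batch_end' the epoch is the 0-based index of the running epoch;
    for 'epoch_end' it is the number of completed epochs.
    """
    events = []
    current_batch = 0
    current_epoch = 0
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            # the batch counter is never reset at epoch end
            current_batch += 1
            if batch_period and current_batch % batch_period == 0:
                events.append((current_epoch, current_batch, 'batch_end'))
        current_epoch += 1
        if epoch_period and current_epoch % epoch_period == 0:
            events.append((current_epoch, current_batch, 'epoch_end'))
    return events


for event in validation_schedule(num_epochs=2, batches_per_epoch=3,
                                 epoch_period=2, batch_period=4):
    print(event)
```

With 3 batches per epoch and `batch_period=4`, the first batch-level validation happens in the second epoch, which shows that `batch_period` is not relative to the start of each epoch.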
+
+class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
+    """Basic Logging Handler that applies to every Gluon estimator by default.
+
+    :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
+    and other useful information during training.
+
+    Parameters
+    ----------
+    file_name : str, default None
+        File name to save the logs. If only file_location is given, 'estimator_log' is used.
+    file_location : str, default None
+        File location to save the logs. If only file_name is given, './' is used.
+        If neither is given, logs are not written to a file.
+    filemode : str, default 'a'
+        Logging file mode; defaults to append mode.
+    verbose : int, default LOG_PER_EPOCH
+        Granularity of the metrics displayed during training:
+        verbose=LOG_PER_EPOCH: display metrics every epoch
+        verbose=LOG_PER_BATCH: display metrics every batch
+    train_metrics : list of EvalMetrics
+        Training metrics to be logged at batch end, epoch end, and train end.
+    val_metrics : list of EvalMetrics
+        Validation metrics to be logged at epoch end and train end.
+    """
+
+    LOG_PER_EPOCH = 1
+    LOG_PER_BATCH = 2
+
+    def __init__(self, file_name=None,
+                 file_location=None,
+                 filemode='a',
+                 verbose=LOG_PER_EPOCH,
+                 train_metrics=None,
+                 val_metrics=None):
+        super(LoggingHandler, self).__init__()
+        self.logger = logging.getLogger(__name__)
+        self.logger.setLevel(logging.INFO)
+        stream_handler = logging.StreamHandler()
+        self.logger.addHandler(stream_handler)
+        # save logger to file only if file name or location is specified
+        if file_name or file_location:
+            file_name = file_name or 'estimator_log'
+            file_location = file_location or './'
+            file_handler = logging.FileHandler(os.path.join(file_location, file_name), mode=filemode)
+            self.logger.addHandler(file_handler)
+        if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
+            raise ValueError("verbose level must be either LOG_PER_EPOCH or "
+                             "LOG_PER_BATCH, received %s. "
+                             "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
+                             % verbose)
+        self.verbose = verbose
+        self.train_metrics = train_metrics or []
+        self.val_metrics = val_metrics or []
+        self.batch_index = 0
+        self.current_epoch = 0
+        self.processed_samples = 0
+        # the logging handler needs to be called last to make sure all states are updated;
+        # it also shuts down logging at train end
+        self.priority = np.Inf
+
+    def train_begin(self, estimator, *args, **kwargs):
+        self.train_start = time.time()
+        trainer = estimator.trainer
+        optimizer = trainer.optimizer.__class__.__name__
+        lr = trainer.learning_rate
+        self.logger.info("Training begin: using optimizer %s "
+                         "with current learning rate %.4f ",
+                         optimizer, lr)
+        if estimator.max_epoch:
+            self.logger.info("Train for %d epochs.", estimator.max_epoch)
+        else:
+            self.logger.info("Train for %d batches.", estimator.max_batch)
+        # reset all counters
+        self.current_epoch = 0
+        self.batch_index = 0
+        self.processed_samples = 0
+
+    def train_end(self, estimator, *args, **kwargs):
+        train_time = time.time() - self.train_start
+        msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
+        # log every result in train stats including train/validation loss & metrics
+        for metric in self.train_metrics + self.val_metrics:
+            name, value = metric.get()
+            msg += '%s: %.4f, ' % (name, value)
+        self.logger.info(msg.rstrip(', '))
+        # make a copy of handler list and remove one by one
+        # as removing handler will edit the handler list
+        for handler in self.logger.handlers[:]:
+            handler.close()
+            self.logger.removeHandler(handler)
+        logging.shutdown()
+
+    def batch_begin(self, estimator, *args, **kwargs):
+        if self.verbose == self.LOG_PER_BATCH:
+            self.batch_start = time.time()
+
+    def batch_end(self, estimator, *args, **kwargs):
+        if self.verbose == self.LOG_PER_BATCH:
+            batch_time = time.time() - self.batch_start
+            msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
+            self.processed_samples += kwargs['batch'][0].shape[0]
+            msg += '[Samples %s] ' % (self.processed_samples)
+            msg += 'time/batch: %.3fs ' % batch_time
+            for metric in self.train_metrics:
+                # only log current training loss & metric after each batch
+                name, value = metric.get()
+                msg += '%s: %.4f, ' % (name, value)
+            self.logger.info(msg.rstrip(', '))
+        self.batch_index += 1
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        if self.verbose >= self.LOG_PER_EPOCH:
+            self.epoch_start = time.time()
+            self.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
+                             self.current_epoch, estimator.trainer.learning_rate)
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        if self.verbose >= self.LOG_PER_EPOCH:
+            epoch_time = time.time() - self.epoch_start
+            msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
+            for monitor in self.train_metrics + self.val_metrics:
+                name, value = monitor.get()
+                msg += '%s: %.4f, ' % (name, value)
+            self.logger.info(msg.rstrip(', '))
+        self.current_epoch += 1
+        self.batch_index = 0
+
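LoggingHandler builds each log line by appending `'name: value, '` for every metric returned by `EvalMetric.get()` and then trimming the trailing separator. Here is a minimal sketch of that formatting, assuming `(name, value)` pairs as input; `format_metrics` is a hypothetical helper. Note that `str.rstrip` removes any trailing characters from the given set, so the full separator `', '` has to be passed.

```python
def format_metrics(prefix, metrics):
    """Build a log line in the style used by LoggingHandler.

    `metrics` is a list of (name, value) pairs, i.e. what EvalMetric.get() returns.
    """
    msg = prefix
    for name, value in metrics:
        msg += '%s: %.4f, ' % (name, value)
    # strip the trailing separator; rstrip(',') alone would leave ', ' in place
    return msg.rstrip(', ')


line = format_metrics('[Epoch 0] Finished in 12.345s, ',
                      [('train accuracy', 0.91234), ('validation accuracy', 0.8876)])
print(line)
```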
+
+class CheckpointHandler(TrainBegin, BatchEnd, EpochEnd):
+    """Save the model after a user-defined period.
+
+    :py:class:`CheckpointHandler` saves the network architecture after the first batch if the model
+    can be fully hybridized, and saves model parameters and trainer states after a user-defined
+    period (every epoch by default).
+
+    Parameters
+    ----------
+    model_dir : str
+        File directory to save all the model related files including model architecture,
+        model parameters, and trainer states.
+    model_prefix : str default 'model'
+        Prefix to add for all checkpoint file names.
+    monitor: EvalMetric, default None
+        The metric to monitor to determine whether the model has improved.
+    verbose: int, default 0
+        Verbosity mode; 1 means the user is informed every time a checkpoint is saved.
+    save_best: bool, default False
+        If True, monitor must not be None, and :py:class:`CheckpointHandler` will save the
+        model parameters and trainer states with the best monitored value.
+    mode: str, default 'auto'
+        One of {auto, min, max}. If `save_best=True`, the comparison used to determine
+        whether the monitored value has improved. In 'auto' mode,
+        :py:class:`CheckpointHandler` uses max for accuracy and F1 metrics and min
+        otherwise, based on the monitored metric name.
+    epoch_period: int, default 1
+        Epoch intervals between saving the network. By default, checkpoints are
+        saved every epoch.
+    batch_period: int, default None
+        Batch intervals between saving the network.
+        By default, checkpoints are not saved based on the number of batches.
+    max_checkpoints : int, default 5
+        Maximum number of checkpoint files to keep in model_dir; older checkpoints
+        are removed. The best checkpoint file is not counted.
+    resume_from_checkpoint : bool, default False
+        Whether to resume training from a checkpoint in model_dir. If True and checkpoints
+        are found, :py:class:`CheckpointHandler` will load the net parameters and trainer states,
+        and train for the remaining epochs or batches.
+    """
+
+    def __init__(self,
+                 model_dir,
+                 model_prefix='model',
+                 monitor=None,
+                 verbose=0,
+                 save_best=False,
+                 mode='auto',
+                 epoch_period=1,
+                 batch_period=None,
+                 max_checkpoints=5,
+                 resume_from_checkpoint=False):
+        self.monitor = monitor
+        self.verbose = verbose
+        if not os.path.exists(model_dir):
+            os.makedirs(model_dir)
+        self.model_dir = model_dir
+        self.model_prefix = model_prefix
+        self.save_best = save_best
+        if self.save_best and not isinstance(self.monitor, EvalMetric):
+            raise ValueError("To save only the best model, please provide one of the metric objects as monitor. "
+                             "You can get these objects using estimator.prepare_loss_and_metrics()")
+        self.epoch_period = epoch_period
+        self.batch_period = batch_period
+        self.current_batch = 0
+        self.current_epoch = 0
+        self.max_checkpoints = max_checkpoints
+        self.resume_from_checkpoint = resume_from_checkpoint
+        self.saved_checkpoints = []
+        self.logger = logging.getLogger(__name__)
+        if self.save_best:
+            if mode not in ['auto', 'min', 'max']:
+                warnings.warn('CheckpointHandler mode %s is unknown, '
+                              'falling back to auto mode. CheckpointHandler will use '
+                              'max mode for f1 and accuracy metric comparison and '
+                              'min mode otherwise' % (mode),
+                              RuntimeWarning)
+                mode = 'auto'
+
+            if mode == 'min':
+                self.monitor_op = np.less
+                self.best = np.Inf
+            elif mode == 'max':
+                self.monitor_op = np.greater
+                self.best = -np.Inf
+            else:
+                # use greater for accuracy and f1 and less otherwise
+                monitor_name = self.monitor.get()[0].lower()
+                if 'acc' in monitor_name or 'f1' in monitor_name:
+                    self.logger.info("`greater` operator will be used to determine "
+                                     "if %s has improved, please use `min` for mode "
+                                     "if you want otherwise", self.monitor.get()[0])
+                    self.monitor_op = np.greater
+                else:
+                    self.logger.info("`less` operator will be used to determine "
+                                     "if %s has improved, please use `max` for mode "
+                                     "if you want otherwise", self.monitor.get()[0])
+                    self.monitor_op = np.less
+
+    def train_begin(self, estimator, *args, **kwargs):
+        # reset all counters
+        self.current_epoch = 0
+        self.current_batch = 0
+        if self.save_best:
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        if self.resume_from_checkpoint:
+            error_msg = "To resume from a checkpoint, you must only specify " \
+                        "the same type of period you used for training. " \
+                        "For example, if you are training based on number of epochs, " \
+                        "you must save only based on epochs, and set batch_period to None."
+            if estimator.max_batch:
+                assert self.batch_period, error_msg
+                assert not self.epoch_period, error_msg
+            if estimator.max_epoch:
+                assert self.epoch_period, error_msg
+                assert not self.batch_period, error_msg
+
+            self._resume_from_checkpoint(estimator)
+
+    def batch_end(self, estimator, *args, **kwargs):
+        # only save symbol once after first batch
+        if self.current_batch == 0:
+            self._save_symbol(estimator)
+        if self.batch_period and (self.current_batch + 1) % self.batch_period == 0:
+            self._save_checkpoint(estimator)
+        self.current_batch += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        if self.epoch_period and (self.current_epoch + 1) % self.epoch_period == 0:
+            self._save_checkpoint(estimator)
+        self.current_epoch += 1
+
+    def _save_checkpoint(self, estimator):
+        # if resumed from checkpoint, increment checkpoint number
+        if self.resume_from_checkpoint:
+            save_epoch_number = self.current_epoch + self.trained_epoch + 1
+            if estimator.max_epoch:
+                # checkpoint saved at epoch end, batch number already incremented
+                save_batch_number = self.current_batch + self.trained_batch
+            else:
+                save_batch_number = self.current_batch + self.trained_batch + 1
+        else:
+            save_epoch_number = self.current_epoch
+            save_batch_number = self.current_batch
+        prefix = "%s-epoch%dbatch%d" % (self.model_prefix, save_epoch_number, save_batch_number)
+        self._save_params_and_trainer(estimator, prefix)
+        if self.verbose > 0:
+            self.logger.info('[Epoch %d] CheckpointHandler: trained total %d batches, '
+                             'saving model at %s with prefix: %s',
+                             self.current_epoch, self.current_batch + 1, self.model_dir, prefix)
+
+        if self.save_best:
+            monitor_name, monitor_value = self.monitor.get()
+            # check if monitor exists in train stats
+            if np.isnan(monitor_value):
+                warnings.warn(RuntimeWarning('Skipping save best because %s is not updated, make sure you '
+                                             'pass one of the metric objects as monitor, '
+                                             'you can use estimator.prepare_loss_and_metrics to '
+                                             'create all metric objects' % monitor_name))
+            else:
+                if self.monitor_op(monitor_value, self.best):
+                    prefix = self.model_prefix + '-best'
+                    self._save_params_and_trainer(estimator, prefix)
+                    if self.verbose > 0:
+                        # log before updating self.best so the previous best value is reported
+                        self.logger.info('[Epoch %d] CheckpointHandler: '
+                                         '%s improved from %0.5f to %0.5f, '
+                                         'updating best model at %s with prefix: %s',
+                                         self.current_epoch, monitor_name,
+                                         self.best, monitor_value, self.model_dir, prefix)
+                    self.best = monitor_value
+                else:
+                    if self.verbose > 0:
+                        self.logger.info('[Epoch %d] CheckpointHandler: '
+                                         '%s did not improve from %0.5f, '
+                                         'skipping updating best model',
+                                         self.current_epoch, monitor_name,
+                                         self.best)
+
+    def _save_symbol(self, estimator):
+        symbol_file = os.path.join(self.model_dir, self.model_prefix + '-symbol.json')
+        # _cached_graph stays empty until a hybridized net has run a forward pass
+        if getattr(estimator.net, '_cached_graph', None):
+            sym = estimator.net._cached_graph[1]
+            sym.save(symbol_file)
+        else:
+            self.logger.info("Model architecture (symbol file) is not saved. Please use HybridBlock "
+                             "to construct your model and call net.hybridize() before passing it to "
+                             "Estimator in order to save the model architecture as %s.", symbol_file)
+
+    def _save_params_and_trainer(self, estimator, file_prefix):
+        param_file = os.path.join(self.model_dir, file_prefix + '.params')
+        trainer_file = os.path.join(self.model_dir, file_prefix + '.states')
+        estimator.net.save_parameters(param_file)
+        estimator.trainer.save_states(trainer_file)
+
+        # only count checkpoints with epoch or batch number in file name
+        if 'best' not in file_prefix:
+            self.saved_checkpoints.append(file_prefix)
+        # remove old checkpoint when max number of checkpoints reached
+        if len(self.saved_checkpoints) > self.max_checkpoints:
+            prefix = self.saved_checkpoints.pop(0)
+            for fname in os.listdir(self.model_dir):
+                if fname.startswith(prefix):
+                    os.remove(os.path.join(self.model_dir, fname))
+
+    def _resume_from_checkpoint(self, estimator):
+        prefix = self.model_prefix + '-epoch'
+        self.trained_epoch = self._find_max_iteration(
+            dir=self.model_dir,
+            prefix=prefix,
+            start='epoch',
+            end='batch',
+            saved_checkpoints=self.saved_checkpoints)
+        prefix += str(self.trained_epoch)
+        self.trained_batch = self._find_max_iteration(
+            dir=self.model_dir,
+            prefix=prefix,
+            start='batch',
+            end='.params')
+
+        if self.trained_epoch == -1:
+            msg = "CheckpointHandler: No checkpoint found, training from scratch for "
+            if estimator.max_batch:
+                msg += "%d batches" % estimator.max_batch
+            else:
+                msg += "%d epochs" % estimator.max_epoch
+            self.logger.info(msg)
+        else:
+            msg = "CheckpointHandler: Checkpoint resumed from epoch %d batch %d, " \
+                  "continue to train for " % (self.trained_epoch, self.trained_batch)
+            # change maximum number of epoch or batch to train if resumed from epoch checkpoint
+            if estimator.max_epoch:
+                if self.trained_epoch >= estimator.max_epoch - 1:
+                    raise ValueError("Found checkpoint with maximum number of epochs %d reached, please specify "
+                                     "resume_from_checkpoint=False (default value) if you want to train from scratch."
+                                     % estimator.max_epoch)
+                estimator.max_epoch = estimator.max_epoch - self.trained_epoch - 1
+                msg += "%d epochs " % estimator.max_epoch
+            if estimator.max_batch:
+                if self.trained_batch >= estimator.max_batch - 1:
+                    raise ValueError("Found checkpoint with maximum number of batches %d reached, please specify "
+                                     "resume_from_checkpoint=False (default value) if you want to train from scratch."
+                                     % estimator.max_batch)
+                estimator.max_batch = estimator.max_batch - self.trained_batch - 1
+                msg += "%d batches " % estimator.max_batch
+            # load checkpoint
+            param_file = "%s-epoch%dbatch%d.params" % (self.model_prefix, self.trained_epoch, self.trained_batch)
+            param_file = os.path.join(self.model_dir, param_file)
+            trainer_file = "%s-epoch%dbatch%d.states" % (self.model_prefix, self.trained_epoch, self.trained_batch)
+            trainer_file = os.path.join(self.model_dir, trainer_file)
+            assert os.path.exists(param_file), "Failed to load checkpoint, %s does not exist" % param_file
+            assert os.path.exists(trainer_file), "Failed to load checkpoint, %s does not exist" % trainer_file
+            estimator.net.load_parameters(param_file, ctx=estimator.context)
+            estimator.trainer.load_states(trainer_file)
+            self.logger.warning(msg)
+
+    def _find_max_iteration(self, dir, prefix, start, end, saved_checkpoints=None):
+        error_msg = "Error parsing checkpoint file, please check your " \
+                    "checkpoints have the format: " \
+                    "{model_name}-epoch{epoch_number}batch{batch_number}.params, " \
+                    "there should also be a .states file for each .params file "
+        max_iter = -1
+        for fname in os.listdir(dir):
+            if fname.startswith(prefix) and '.params' in fname:
+                if saved_checkpoints is not None:
+                    # save prefix of existing checkpoints
+                    saved_checkpoints.append(fname[:fname.find('.params')])
+                try:
+                    # parse the epoch or batch number between `start` and `end`
+                    iteration = int(fname[fname.find(start) + len(start): fname.find(end)])
+                    if iteration > max_iter:
+                        max_iter = iteration
+                except ValueError:
+                    raise ValueError(error_msg)
+        return max_iter
+
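CheckpointHandler names checkpoints `{model_prefix}-epoch{E}batch{B}.params` (plus a matching `.states` file), keeps at most `max_checkpoints` of them, and on resume looks for the highest epoch and then the highest batch within that epoch. The sketch below shows one way to do the lookup and the rotation in plain Python, under stated assumptions: `latest_checkpoint` and `rotate_checkpoints` are hypothetical helpers, and they parse names with a regular expression rather than the handler's `str.find` slicing.

```python
import os
import re
import tempfile

# Hypothetical pattern for names such as "model-epoch3batch120.params"
CHECKPOINT_RE = re.compile(r'^(?P<prefix>.+)-epoch(?P<epoch>\d+)batch(?P<batch>\d+)\.params$')


def latest_checkpoint(model_dir, model_prefix='model'):
    """Return (epoch, batch) of the newest checkpoint in model_dir, or (-1, -1) if none."""
    latest = (-1, -1)
    for fname in os.listdir(model_dir):
        match = CHECKPOINT_RE.match(fname)
        if match and match.group('prefix') == model_prefix:
            # tuples compare epoch first, then batch
            latest = max(latest, (int(match.group('epoch')), int(match.group('batch'))))
    return latest


def rotate_checkpoints(saved_prefixes, new_prefix, max_checkpoints):
    """Record new_prefix and return the oldest prefixes that exceed max_checkpoints."""
    saved_prefixes.append(new_prefix)
    expired = []
    while len(saved_prefixes) > max_checkpoints:
        expired.append(saved_prefixes.pop(0))
    return expired


with tempfile.TemporaryDirectory() as model_dir:
    for name in ['model-epoch0batch99.params', 'model-epoch1batch150.params',
                 'model-epoch1batch199.params', 'model-best.params',
                 'other-epoch9batch9.params']:
        open(os.path.join(model_dir, name), 'w').close()
    # the "-best" file and files with a different prefix are ignored
    print(latest_checkpoint(model_dir))
```

In a real handler, each expired prefix would be used to delete both its `.params` and `.states` files.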
+
+class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd):
+    """Stop training early if the monitored value is not improving.
+
+    Parameters
+    ----------
+    monitor: EvalMetric
+        The metric to monitor, and stop training if this metric does not improve.
+    min_delta: float, default 0
+        Minimal change in monitored value to be considered as an improvement.
+    patience: int, default 0
+        Number of epochs to wait for improvement before terminating training.
+    mode: str, default 'auto'
+        One of {auto, min, max}, the comparison used to determine whether the
+        monitored value has improved. In 'auto' mode, :py:class:`EarlyStoppingHandler`
+        uses max for accuracy and F1 metrics and min otherwise, based on the
+        monitored metric name.
+    baseline: float, default None
+        Baseline value to compare the monitored value with.
+    """
+
+    def __init__(self,
+                 monitor,
+                 min_delta=0,
+                 patience=0,
+                 mode='auto',
+                 baseline=None):
+        super(EarlyStoppingHandler, self).__init__()
+
+        if not isinstance(monitor, EvalMetric):
+            raise ValueError("Please provide one of the metric objects as monitor. "
+                             "You can create these objects using estimator.prepare_loss_and_metrics()")
+        self.monitor = monitor
+        self.baseline = baseline
+        self.patience = patience
+        self.min_delta = min_delta
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.current_epoch = 0
+        self.stop_training = False
+        self.logger = logging.getLogger(__name__)
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn('EarlyStoppingHandler mode %s is unknown, '
+                          'falling back to auto mode. EarlyStoppingHandler will use '
+                          'max mode for f1 and accuracy metric comparison and '
+                          'min mode otherwise' % (mode),
+                          RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+        elif mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            monitor_name = self.monitor.get()[0].lower()
+            if 'acc' in monitor_name or 'f1' in monitor_name:
+                self.logger.info("`greater` operator is used to determine "
+                                 "if %s has improved, please use `min` for mode "
+                                 "if you want otherwise", self.monitor.get()[0])
+                self.monitor_op = np.greater
+            else:
+                self.logger.info("`less` operator is used to determine "
+                                 "if %s has improved, please use `max` for mode "
+                                 "if you want otherwise", self.monitor.get()[0])
+                self.monitor_op = np.less
+
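+        # flip the sign of min_delta for `less` comparisons so that an improvement
+        # always means beating the best value by at least min_delta
+        if self.monitor_op == np.less:
+            self.min_delta *= -1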
+
+    def train_begin(self, estimator, *args, **kwargs):
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.current_epoch = 0
+        self.stop_training = False
+        if self.baseline is not None:
+            self.best = self.baseline
+        else:
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        monitor_name, monitor_value = self.monitor.get()
+        if np.isnan(monitor_value):
+            warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects '
+                                         'as monitor, you can use estimator.prepare_loss_and_metrics to '
+                                         'create all metric objects' % monitor_name))
+        else:
+            if self.monitor_op(monitor_value - self.min_delta, self.best):
+                self.best = monitor_value
+                self.wait = 0
+            else:
+                self.wait += 1
+                if self.wait >= self.patience:
+                    self.stopped_epoch = self.current_epoch
+                    self.stop_training = True
+        self.current_epoch += 1
+        return self.stop_training
+
+    def train_end(self, estimator, *args, **kwargs):
+        if self.stopped_epoch > 0:
+            self.logger.info('[Epoch %d] EarlyStoppingHandler: early stopping due to %s not improving',
+                             self.stopped_epoch, self.monitor.get()[0])
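EarlyStoppingHandler shifts the monitored value by `min_delta` (with the sign flipped for `less` comparisons), resets its wait counter on every improvement, and stops once `wait >= patience`. The sketch below replays that rule on a list of per-epoch values in plain Python. `greater_is_better` and `early_stop_epoch` are hypothetical names, and the 'auto' mode follows the documented intent of treating accuracy and F1 as higher-is-better and everything else as lower-is-better.

```python
import math


def greater_is_better(metric_name, mode='auto'):
    """Decide the comparison direction; 'auto' treats accuracy and F1 as higher-is-better."""
    if mode == 'max':
        return True
    if mode == 'min':
        return False
    name = metric_name.lower()
    # each substring is tested separately
    return 'acc' in name or 'f1' in name


def early_stop_epoch(values, metric_name, min_delta=0.0, patience=0, mode='auto', baseline=None):
    """Return the 0-based epoch at which training would stop, or None if it never stops."""
    higher = greater_is_better(metric_name, mode)
    # flip the sign for 'less' comparisons, as the handler does with min_delta
    delta = min_delta if higher else -min_delta
    if baseline is not None:
        best = baseline
    else:
        best = -math.inf if higher else math.inf
    wait = 0
    for epoch, value in enumerate(values):
        if math.isnan(value):
            continue  # metric was not updated this epoch; skip it
        shifted = value - delta
        improved = shifted > best if higher else shifted < best
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [0.9, 0.7, 0.69, 0.71, 0.70]
print(early_stop_epoch(losses, 'loss', min_delta=0.05, patience=2))
```

With `min_delta=0.05`, the drop from 0.7 to 0.69 is too small to count as an improvement, so the wait counter reaches `patience` two epochs later.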
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 6935c27..0939490 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -255,6 +255,13 @@ class Trainer(object):
         else:
             return self._optimizer.learning_rate
 
+    @property
+    def optimizer(self):
+        if isinstance(self._optimizer, opt.Optimizer):
+            return self._optimizer
+        else:
+            raise UserWarning("Optimizer has not been initialized yet")
+
     def set_learning_rate(self, lr):
         """Sets a new learning rate of the optimizer.
 
diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries
index ea6db1a..e4b9ff1 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -141,6 +141,14 @@ core_logic: {
           utils.docker_run('ubuntu_nightly_gpu', 'nightly_tutorial_test_ubuntu_python3_gpu', true, '1500m')
         }
       }
+    },
+    'Gluon estimator: GPU': {
+      node(NODE_LINUX_GPU) {
+        ws('workspace/estimator-test-gpu') {
+          utils.unpack_and_init('gpu', mx_lib)
+          utils.docker_run('ubuntu_nightly_gpu', 'nightly_estimator', true)
+        }
+      }
     }
   }
 }
diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py
new file mode 100644
index 0000000..c60dc54
--- /dev/null
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Test gluon estimator on CNN models
+
+import argparse
+import numpy as np
+import mxnet as mx
+from mxnet import gluon, init, nd
+from mxnet.gluon import data
+from mxnet.gluon.contrib.estimator import estimator
+from mxnet.gluon.model_zoo import vision
+
+def load_data_mnist(batch_size, resize=None, num_workers=4):
+    '''
+    Load MNIST dataset
+    '''
+    transformer = []
+    if resize:
+        transformer += [data.vision.transforms.Resize(resize)]
+    transformer += [data.vision.transforms.ToTensor()]
+    transformer = data.vision.transforms.Compose(transformer)
+    mnist_train = data.vision.MNIST(train=True)
+    mnist_test = data.vision.MNIST(train=False)
+    train_iter = data.DataLoader(
+        mnist_train.transform_first(transformer), batch_size, shuffle=True,
+        num_workers=num_workers)
+    test_iter = data.DataLoader(
+        mnist_test.transform_first(transformer), batch_size, shuffle=False,
+        num_workers=num_workers)
+    return train_iter, test_iter
+
+def bilinear_kernel(in_channels, out_channels, kernel_size):
+    '''
+    Bilinear interpolation using transposed convolution
+    https://github.com/d2l-ai/d2l-en/blob/master/chapter_computer-vision/fcn.md
+    '''
+    factor = (kernel_size + 1) // 2
+    if kernel_size % 2 == 1:
+        center = factor - 1
+    else:
+        center = factor - 0.5
+    og = np.ogrid[:kernel_size, :kernel_size]
+    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
+    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
+    weight[range(in_channels), range(out_channels), :, :] = filt
+    return nd.array(weight)
+
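`bilinear_kernel` builds each 2-D filter as the product of two 1-D triangular profiles centered in the kernel; the FCN model in the test uses `kernel_size=64` with `strides=32`, i.e. a kernel twice the upsampling factor. The following NumPy sketch (`bilinear_filter_1d` is a hypothetical helper) computes the 1-D profile with the same `factor`/`center` rule and forms the 2-D filter with `np.outer`, which is equivalent to the `np.ogrid` broadcasting used above.

```python
import numpy as np


def bilinear_filter_1d(kernel_size):
    """1-D bilinear upsampling weights, using the same factor/center rule as bilinear_kernel."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    positions = np.arange(kernel_size)
    return 1 - np.abs(positions - center) / factor


row = bilinear_filter_1d(4)     # weights for an upsampling factor of 2
filt = np.outer(row, row)       # 2-D filter = outer product of the 1-D weights
print(row)
print(filt.shape)
```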
+def get_net(model_name, context):
+    if model_name == 'FCN':
+        num_classes = 21
+        pretrained_net = vision.resnet18_v2(pretrained=True, ctx=context)
+        net = gluon.nn.HybridSequential()
+        for layer in pretrained_net.features[:-2]:
+            net.add(layer)
+        net.add(gluon.nn.Conv2D(num_classes, kernel_size=1),
+                gluon.nn.Conv2DTranspose(num_classes, kernel_size=64, padding=16, strides=32))
+        net[-1].initialize(init.Constant(bilinear_kernel(num_classes, num_classes, 64)), ctx=context)
+        net[-2].initialize(init=init.Xavier(), ctx=context)
+        input_shape = (1, 3, 320, 480)
+        label_shape = (1, 320, 480)
+        loss_axis = 1
+    else:
+        net = vision.get_model(model_name, classes=10)
+        net.initialize(mx.init.Xavier(), ctx=context)
+        input_shape = (1, 1, 224, 224)
+        label_shape = 1
+        loss_axis = -1
+    return net, input_shape, label_shape, loss_axis
+
+def test_estimator_cpu():
+    '''
+    Test estimator by doing one pass over each model with synthetic data
+    '''
+    models = ['resnet18_v1',
+              'FCN'
+              ]
+    context = mx.cpu()
+    for model_name in models:
+        net, input_shape, label_shape, loss_axis = get_net(model_name, context)
+        train_dataset = gluon.data.dataset.ArrayDataset(mx.nd.random.uniform(shape=input_shape),
+                                                        mx.nd.zeros(shape=label_shape))
+        val_dataset = gluon.data.dataset.ArrayDataset(mx.nd.random.uniform(shape=input_shape),
+                                                      mx.nd.zeros(shape=label_shape))
+        loss = gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis)
+        train_data = gluon.data.DataLoader(train_dataset, batch_size=1)
+        val_data = gluon.data.DataLoader(val_dataset, batch_size=1)
+        net.hybridize()
+        trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+        # Define estimator
+        est = estimator.Estimator(net=net,
+                                  loss=loss,
+                                  metrics=mx.metric.Accuracy(),
+                                  trainer=trainer,
+                                  context=context)
+        # Call fit()
+        est.fit(train_data=train_data,
+                val_data=val_data,
+                epochs=1)
+
+def test_estimator_gpu():
+    '''
+    Test estimator by training resnet18_v1 for 5 epochs on MNIST and verify accuracy
+    '''
+    model_name = 'resnet18_v1'
+    batch_size = 128
+    num_epochs = 5
+    context = mx.gpu(0)
+    net, _, _, _ = get_net(model_name, context)
+    train_data, test_data = load_data_mnist(batch_size, resize=224)
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    net.hybridize()
+    acc = mx.metric.Accuracy()
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # Define estimator
+    est = estimator.Estimator(net=net,
+                              loss=loss,
+                              metrics=acc,
+                              trainer=trainer,
+                              context=context)
+    # Call fit()
+    est.fit(train_data=train_data,
+            val_data=test_data,
+            epochs=num_epochs)
+
+    assert acc.get()[1] > 0.80
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='test gluon estimator')
+    parser.add_argument('--type', type=str, default='cpu')
+    opt = parser.parse_args()
+    if opt.type == 'cpu':
+        test_estimator_cpu()
+    elif opt.type == 'gpu':
+        test_estimator_gpu()
+    else:
+        raise RuntimeError("Unknown test type")
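The FCN branch of `get_net` relies on two pieces of arithmetic: a `Conv2DTranspose` with `kernel_size=64`, `padding=16` and `strides=32` upsamples by exactly 32x, and its weights are initialized from `bilinear_kernel`, whose opening lines fall outside this excerpt. The pure-Python sketch below checks both under stated assumptions: it uses the standard transposed-convolution size formula (no output padding, dilation 1) and the common d2l-style bilinear filter, and it assumes the truncated ResNet-18 backbone reduces the 320x480 input by 32 to 10x15. `conv_transpose_out`, `bilinear_filter_1d` and `bilinear_filter_2d` are hypothetical helpers, not MXNet APIs.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output length of a 1-D transposed convolution (assumes no output
    padding and dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel


def bilinear_filter_1d(kernel):
    """1-D bilinear interpolation weights for an upsampling kernel
    (hypothetical helper using the d2l-style formula)."""
    factor = (kernel + 1) // 2
    # Odd kernels center on a pixel; even kernels center between two pixels
    center = factor - 1 if kernel % 2 == 1 else factor - 0.5
    return [1 - abs(i - center) / factor for i in range(kernel)]


def bilinear_filter_2d(kernel):
    """Separable 2-D filter: outer product of the 1-D weights."""
    row = bilinear_filter_1d(kernel)
    return [[a * b for b in row] for a in row]


# Assumed backbone output for a 320x480 input: 320 / 32 = 10, 480 / 32 = 15
height = conv_transpose_out(10, kernel=64, stride=32, padding=16)
width = conv_transpose_out(15, kernel=64, stride=32, padding=16)
print(height, width)          # 320 480
print(bilinear_filter_1d(4))  # [0.25, 0.75, 0.75, 0.25]
```

In the source, the 2-D filter is written only to the diagonal entries `weight[i, i, :, :]`, so each of the 21 class channels is upsampled independently of the others.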
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py
new file mode 100644
index 0000000..404bf83
--- /dev/null
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -0,0 +1,276 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Gluon Text Sentiment Classification Example using RNN/CNN
+Example modified from below link:
+https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-rnn.md
+https://github.com/d2l-ai/d2l-en/blob/master/chapter_natural-language-processing/sentiment-analysis-cnn.md"""
+
+import argparse
+import os
+import tarfile
+import random
+import collections
+import mxnet as mx
+from mxnet import nd, gluon
+from mxnet.contrib import text
+from mxnet.gluon import nn, rnn
+from mxnet.gluon.contrib.estimator import estimator
+
+
+class TextCNN(nn.Block):
+    def __init__(self, vocab, embed_size, kernel_sizes, num_channels,
+                 **kwargs):
+        super(TextCNN, self).__init__(**kwargs)
+        self.embedding = nn.Embedding(len(vocab), embed_size)
+        # Second embedding layer; the d2l example freezes it with
+        # grad_req='null', but this test leaves it trainable
+        self.constant_embedding = nn.Embedding(len(vocab), embed_size)
+        self.dropout = nn.Dropout(0.5)
+        self.decoder = nn.Dense(2)
+        # The max-over-time pooling layer has no weight, so it can share an
+        # instance
+        self.pool = nn.GlobalMaxPool1D()
+        # Create multiple one-dimensional convolutional layers
+        self.convs = nn.Sequential()
+        for c, k in zip(num_channels, kernel_sizes):
+            self.convs.add(nn.Conv1D(c, k, activation='relu'))
+
+    def forward(self, inputs):
+        # Concatenate the outputs of the two embedding layers, each of shape
+        # (batch size, number of words, word vector dimension), along dim=2
+        embeddings = nd.concat(
+            self.embedding(inputs), self.constant_embedding(inputs), dim=2)
+        # Conv1D expects (batch size, channels, width), so move the word
+        # vector dimension, which acts as the channel dimension of the
+        # one-dimensional convolutional layer, to axis 1
+        embeddings = embeddings.transpose((0, 2, 1))
+        # For each one-dimensional convolutional layer, after max-over-time
+        # pooling, an NDArray with the shape of (batch size, channel size, 1)
+        # can be obtained. Use the flatten function to remove the last
+        # dimension and then concatenate on the channel dimension
+        encoding = nd.concat(*[nd.flatten(
+            self.pool(conv(embeddings))) for conv in self.convs], dim=1)
+        # After applying the dropout method, use a fully connected layer to
+        # obtain the output
+        outputs = self.decoder(self.dropout(encoding))
+        return outputs
+
+
+class BiRNN(nn.Block):
+    def __init__(self, vocab, embed_size, num_hiddens, num_layers, **kwargs):
+        super(BiRNN, self).__init__(**kwargs)
+        self.embedding = nn.Embedding(len(vocab), embed_size)
+        # Set bidirectional=True to get a bidirectional recurrent neural
+        # network
+        self.encoder = rnn.LSTM(num_hiddens, num_layers=num_layers,
+                                bidirectional=True, input_size=embed_size)
+        self.decoder = nn.Dense(2)
+
+    def forward(self, inputs):
+        # The shape of inputs is (batch size, number of words). Because LSTM
+        # needs to use sequence as the first dimension, the input is
+        # transformed and the word feature is then extracted. The output shape
+        # is (number of words, batch size, word vector dimension).
+        embeddings = self.embedding(inputs.T)
+        # The shape of states is (number of words, batch size, 2 * number of
+        # hidden units).
+        states = self.encoder(embeddings)
+        # Concatenate the hidden states of the initial time step and final
+        # time step to use as the input of the fully connected layer. Its
+        # shape is (batch size, 4 * number of hidden units)
+        encoding = nd.concat(states[0], states[-1])
+        outputs = self.decoder(encoding)
+        return outputs
+
+
+def download_imdb(data_dir='/tmp/data'):
+    '''
+    Download and extract the IMDB dataset
+    '''
+    url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz')
+    sha1 = '01ada507287d82875905620988597833ad4e0903'
+    if not os.path.exists(data_dir):
+        os.makedirs(data_dir)
+    file_path = os.path.join(data_dir, 'aclImdb_v1.tar.gz')
+    if not os.path.isfile(file_path):
+        file_path = gluon.utils.download(url, data_dir, sha1_hash=sha1)
+    with tarfile.open(file_path, 'r') as f:
+        f.extractall(data_dir)
+
+
+def read_imdb(folder='train', data_dir='/tmp/data'):
+    '''
+    Read the IMDB dataset extracted by download_imdb
+    '''
+    data = []
+    for label in ['pos', 'neg']:
+        folder_name = os.path.join(data_dir, 'aclImdb', folder, label)
+        for file in os.listdir(folder_name):
+            with open(os.path.join(folder_name, file), 'rb') as f:
+                review = f.read().decode('utf-8').replace('\n', '').lower()
+                data.append([review, 1 if label == 'pos' else 0])
+    random.shuffle(data)
+    return data
+
+
+def get_tokenized_imdb(data):
+    '''
+    Tokenize each review by splitting on spaces and lowercasing
+    '''
+
+    def tokenizer(text):
+        return [tok.lower() for tok in text.split(' ')]
+
+    return [tokenizer(review) for review, _ in data]
+
+
+def get_vocab_imdb(data):
+    '''
+    Build a vocabulary of the tokens that occur at least 5 times
+    '''
+    tokenized_data = get_tokenized_imdb(data)
+    counter = collections.Counter([tk for st in tokenized_data for tk in st])
+    return text.vocab.Vocabulary(counter, min_freq=5)
+
+
+def preprocess_imdb(data, vocab):
+    '''
+    Truncate or zero-pad each review to exactly 500 token indices
+    '''
+    max_l = 500
+
+    def pad(x):
+        return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x))
+
+    tokenized_data = get_tokenized_imdb(data)
+    features = nd.array([pad(vocab.to_indices(x)) for x in tokenized_data])
+    labels = nd.array([score for _, score in data])
+    return features, labels
+
+
+def run(net, train_dataloader, test_dataloader, **kwargs):
+    '''
+    Train a sentiment model with Estimator.fit and return the accuracy metric
+    '''
+    num_epochs = kwargs['epochs']
+    ctx = kwargs['ctx']
+    batch_size = kwargs['batch_size']
+    lr = kwargs['lr']
+
+    # Define trainer
+    trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
+    # Define loss and evaluation metrics
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+
+    # Define estimator
+    est = estimator.Estimator(net=net, loss=loss, metrics=acc,
+                              trainer=trainer, context=ctx)
+    # Begin training
+    est.fit(train_data=train_dataloader, val_data=test_dataloader,
+            epochs=num_epochs)
+    return acc
+
+
+def test_estimator_cpu(**kwargs):
+    '''
+    Test estimator by doing one pass over each model with synthetic data
+    '''
+    models = ['TextCNN', 'BiRNN']
+    ctx = kwargs['ctx']
+    batch_size = kwargs['batch_size']
+    embed_size = kwargs['embed_size']
+
+    train_data = mx.nd.random.randint(low=0, high=100, shape=(2 * batch_size, 500))
+    train_label = mx.nd.random.randint(low=0, high=2, shape=(2 * batch_size,))
+    val_data = mx.nd.random.randint(low=0, high=100, shape=(batch_size, 500))
+    val_label = mx.nd.random.randint(low=0, high=2, shape=(batch_size,))
+
+    train_dataloader = gluon.data.DataLoader(dataset=gluon.data.ArrayDataset(train_data, train_label),
+                                             batch_size=batch_size, shuffle=True)
+    val_dataloader = gluon.data.DataLoader(dataset=gluon.data.ArrayDataset(val_data, val_label),
+                                           batch_size=batch_size)
+    vocab_list = mx.nd.zeros(shape=(100,))
+
+    # Get the model
+    for model in models:
+        if model == 'TextCNN':
+            kernel_sizes, nums_channels = [3, 4, 5], [100, 100, 100]
+            net = TextCNN(vocab_list, embed_size, kernel_sizes, nums_channels)
+        else:
+            num_hiddens, num_layers = 100, 2
+            net = BiRNN(vocab_list, embed_size, num_hiddens, num_layers)
+        net.initialize(mx.init.Xavier(), ctx=ctx)
+
+        run(net, train_dataloader, val_dataloader, **kwargs)
+
+
+def test_estimator_gpu(**kwargs):
+    '''
+    Test estimator by training Bidirectional RNN for 5 epochs on the IMDB dataset
+    and verify accuracy
+    '''
+    ctx = kwargs['ctx']
+    batch_size = kwargs['batch_size']
+    num_epochs = kwargs['epochs']
+    embed_size = kwargs['embed_size']
+
+    # data
+    download_imdb()
+    train_data, test_data = read_imdb('train'), read_imdb('test')
+    vocab = get_vocab_imdb(train_data)
+
+    train_set = gluon.data.ArrayDataset(*preprocess_imdb(train_data, vocab))
+    test_set = gluon.data.ArrayDataset(*preprocess_imdb(test_data, vocab))
+    train_dataloader = gluon.data.DataLoader(train_set, batch_size, shuffle=True)
+    test_dataloader = gluon.data.DataLoader(test_set, batch_size)
+
+    # Model
+    num_hiddens, num_layers = 100, 2
+    net = BiRNN(vocab, embed_size, num_hiddens, num_layers)
+    net.initialize(mx.init.Xavier(), ctx=ctx)
+
+    glove_embedding = text.embedding.create(
+        'glove', pretrained_file_name='glove.6B.100d.txt', vocabulary=vocab)
+
+    net.embedding.weight.set_data(glove_embedding.idx_to_vec)
+    net.embedding.collect_params().setattr('grad_req', 'null')
+
+    acc = run(net, train_dataloader, test_dataloader, **kwargs)
+
+    assert acc.get()[1] > 0.70
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='test gluon estimator')
+    parser.add_argument('--type', type=str, default='cpu')
+    opt = parser.parse_args()
+    kwargs = {
+        'batch_size': 64,
+        'lr': 0.01,
+        'embed_size': 100
+    }
+    if opt.type == 'cpu':
+        kwargs['ctx'] = mx.cpu()
+        kwargs['epochs'] = 1
+        test_estimator_cpu(**kwargs)
+    elif opt.type == 'gpu':
+        kwargs['ctx'] = mx.gpu()
+        kwargs['epochs'] = 5
+        test_estimator_gpu(**kwargs)
+    else:
+        raise RuntimeError("Unknown test type")
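`get_vocab_imdb` and `preprocess_imdb` turn each review into a fixed-length list of token indices. The stdlib-only sketch below mirrors those steps under stated assumptions: `build_vocab`, `to_indices` and `pad` are hypothetical stand-ins for `text.vocab.Vocabulary(counter, min_freq=5)`, `vocab.to_indices` and the nested `pad` helper, and index 0 is assumed to be reserved for unknown tokens, so in this sketch the padding value 0 coincides with the unknown index.

```python
import collections


def build_vocab(tokenized, min_freq=5, unknown_token='<unk>'):
    """Map tokens seen at least min_freq times to integer indices
    (hypothetical stand-in for a vocabulary class)."""
    counter = collections.Counter(tok for sent in tokenized for tok in sent)
    idx_to_token = [unknown_token]  # index 0 reserved for unknown tokens
    # Most frequent first; ties broken alphabetically for determinism
    for tok, freq in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq:
            idx_to_token.append(tok)
    return {tok: idx for idx, tok in enumerate(idx_to_token)}


def to_indices(vocab, tokens):
    """Look up each token, falling back to the unknown index 0."""
    return [vocab.get(tok, 0) for tok in tokens]


def pad(indices, max_len=500, pad_value=0):
    """Truncate or right-pad to exactly max_len entries, like preprocess_imdb."""
    if len(indices) > max_len:
        return indices[:max_len]
    return indices + [pad_value] * (max_len - len(indices))


reviews = [['good', 'movie'], ['bad', 'movie'], ['good', 'good']]
vocab = build_vocab(reviews, min_freq=2)
print(vocab)                                        # {'<unk>': 0, 'good': 1, 'movie': 2}
print(pad(to_indices(vocab, ['bad', 'movie']), 4))  # [0, 2, 0, 0]
```

Fixing every review to the same length lets `ArrayDataset` stack the features into one `(num_reviews, 500)` array, which is what the `DataLoader` batches expect.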
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
new file mode 100644
index 0000000..d2e8c08
--- /dev/null
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -0,0 +1,371 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+''' Unit tests for Gluon Estimator '''
+
+import sys
+import unittest
+import warnings
+
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon.contrib.estimator import *
+from nose.tools import assert_raises
+
+
+def _get_test_network():
+    net = nn.Sequential()
+    net.add(nn.Dense(4, activation='relu', flatten=False))
+    return net
+
+
+def _get_test_data():
+    batch_size = 4
+    in_data = mx.nd.random.uniform(shape=(10, 3))
+    out_data = mx.nd.random.uniform(shape=(10, 4))
+    # Input dataloader
+    dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
+    dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
+    dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
+    return dataloader, dataiter
+
+
+def test_fit():
+    ''' test estimator with different train data types '''
+    net = _get_test_network()
+    dataloader, dataiter = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    trainer=trainer,
+                    context=ctx)
+
+    est.fit(train_data=dataloader,
+            epochs=num_epochs)
+
+    with assert_raises(ValueError):
+        est.fit(train_data=dataiter,
+                epochs=num_epochs)
+
+    # Input NDArray
+    with assert_raises(ValueError):
+        est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
+                epochs=num_epochs)
+
+
+def test_validation():
+    ''' test different validation data types'''
+    net = _get_test_network()
+    dataloader, dataiter = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    trainer=trainer,
+                    context=ctx)
+    # Input dataloader
+    est.fit(train_data=dataloader,
+            val_data=dataloader,
+            epochs=num_epochs)
+
+    # using validation handler
+    train_metrics, val_metrics = est.prepare_loss_and_metrics()
+    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate,
+                                           val_metrics=val_metrics)
+    est.fit(train_data=dataloader,
+            event_handlers=[validation_handler],
+            epochs=num_epochs)
+
+    with assert_raises(ValueError):
+        est.fit(train_data=dataiter,
+                val_data=dataiter,
+                epochs=num_epochs)
+    # Input NDArray
+    with assert_raises(ValueError):
+        est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
+                val_data=[mx.nd.ones(shape=(10, 3))],
+                epochs=num_epochs)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_initializer():
+    ''' test with no initializer, inconsistent initializer '''
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    # no initializer
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=acc,
+                    context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+
+    # different initializer for net and estimator
+    net = _get_test_network()
+    net.initialize(mx.init.Xavier(), ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # catch reinit warning
+    with warnings.catch_warnings(record=True) as w:
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        initializer=mx.init.MSRAPrelu(),
+                        trainer=trainer,
+                        context=ctx)
+        assert 'Network already fully initialized' in str(w[-1].message)
+    # net partially initialized, fine tuning use case
+    net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx)
+    net.output = gluon.nn.Dense(10)  # last layer not initialized
+    est = Estimator(net, loss=loss, metrics=acc, context=ctx)
+    dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
+    train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_trainer():
+    ''' test with no trainer and invalid trainer '''
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    # input no trainer
+    with warnings.catch_warnings(record=True) as w:
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        context=ctx)
+        assert 'No trainer specified' in str(w[-1].message)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+
+    # input invalid trainer
+    trainer = 'sgd'
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=acc,
+                        trainer=trainer,
+                        context=ctx)
+
+
+def test_metric():
+    ''' test with no metric, list of metrics, invalid metric '''
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    loss = gluon.loss.L2Loss()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # input no metric
+    est = Estimator(net=net,
+                    loss=loss,
+                    trainer=trainer,
+                    context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+    # input list of metrics
+    metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics,
+                    trainer=trainer,
+                    context=ctx)
+    est.fit(train_data=train_data,
+            epochs=num_epochs)
+    # input invalid metric
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics='acc',
+                        trainer=trainer,
+                        context=ctx)
+    # test default metric
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    est = Estimator(net=net,
+                    loss=loss,
+                    trainer=trainer,
+                    context=ctx)
+    est.prepare_loss_and_metrics()
+    assert isinstance(est.train_metrics[0], mx.metric.Accuracy)
+
+
+def test_loss():
+    ''' test with invalid loss '''
+    net = _get_test_network()
+    ctx = mx.cpu()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    # input invalid loss
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss='mse',
+                        metrics=acc,
+                        trainer=trainer,
+                        context=ctx)
+
+
+def test_context():
+    ''' test with no context, list of context, invalid context '''
+    net = _get_test_network()
+    loss = gluon.loss.L2Loss()
+    metrics = mx.metric.Accuracy()
+    # input no context
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics)
+    # input list of context
+    gpus = mx.context.num_gpus()
+    ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
+    net = _get_test_network()
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=metrics,
+                    context=ctx)
+    # input invalid context
+    with assert_raises(ValueError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=metrics,
+                        context='cpu')
+
+    with assert_raises(AssertionError):
+        est = Estimator(net=net,
+                        loss=loss,
+                        metrics=metrics,
+                        context=[mx.gpu(0), mx.gpu(100)])
+
+
+def test_categorize_handlers():
+    class CustomHandler1(TrainBegin):
+
+        def train_begin(self):
+            print("custom train begin")
+
+    class CustomHandler2(EpochBegin, BatchBegin, TrainEnd):
+
+        def epoch_begin(self):
+            print("custom epoch begin")
+
+        def batch_begin(self):
+            print("custom batch begin")
+
+        def train_end(self):
+            print("custom train end")
+
+    class CustomHandler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd):
+
+        def epoch_begin(self):
+            print("custom epoch begin")
+
+        def batch_begin(self):
+            print("custom batch begin")
+
+        def batch_end(self):
+            print("custom batch end")
+
+        def train_end(self):
+            print("custom train end")
+
+    net = nn.Sequential()
+    net.add(nn.Dense(10))
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    est = Estimator(net, loss=loss)
+    event_handlers = [CustomHandler1(), CustomHandler2(), CustomHandler3()]
+    train_begin, epoch_begin, batch_begin, \
+    batch_end, epoch_end, train_end = est._categorize_handlers(event_handlers)
+    assert len(train_begin) == 1
+    assert len(epoch_begin) == 2
+    assert len(batch_begin) == 2
+    assert len(batch_end) == 1
+    assert len(train_end) == 2
+
+
+@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
+def test_default_handlers():
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+
+    num_epochs = 1
+    ctx = mx.cpu()
+
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+
+    train_acc = mx.metric.RMSE()
+    loss = gluon.loss.L2Loss()
+
+    est = Estimator(net=net,
+                    loss=loss,
+                    metrics=train_acc,
+                    trainer=trainer,
+                    context=ctx)
+    # no handler
+    with warnings.catch_warnings(record=True) as w:
+        est.fit(train_data=train_data, epochs=num_epochs)
+        assert 'You are training with the' in str(w[-1].message)
+
+    # handler with prepared loss and metrics
+    # use mix of default and user defined handlers
+    train_metrics, val_metrics = est.prepare_loss_and_metrics()
+    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
+    with warnings.catch_warnings(record=True) as w:
+        est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
+        assert 'You are training with the' in str(w[-1].message)
+        # provide metric handler by default
+        assert 'MetricHandler' in str(w[-1].message)
+
+    # handler with all user defined metrics
+    # use mix of default and user defined handlers
+    metric = MetricHandler(train_metrics=[train_acc])
+    logging = LoggingHandler(train_metrics=[train_acc], val_metrics=[mx.metric.RMSE("val acc")])
+    est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging])
+
+    # handler with mixed metrics, some handler use metrics prepared by estimator
+    # some handler use metrics user prepared
+    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")])
+    with assert_raises(ValueError):
+        est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
+
+    # test handler order
+    train_metrics, val_metrics = est.prepare_loss_and_metrics()
+    early_stopping = EarlyStoppingHandler(monitor=val_metrics[0])
+    handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping])
+    assert len(handlers) == 4
+    assert isinstance(handlers[0], MetricHandler)
+    assert isinstance(handlers[3], LoggingHandler)
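`test_categorize_handlers` expects the estimator to bucket each handler by the event mixins it inherits from, so a single handler can land in several buckets. A minimal sketch of that dispatch, assuming plain `isinstance` checks; the empty mixin classes and `categorize_handlers` below are local stand-ins, not the `mxnet.gluon.contrib.estimator` classes, and this is not the library's `_categorize_handlers`.

```python
# Local stand-ins for the event mixins; a handler subclasses each event it handles
class TrainBegin:
    pass


class EpochBegin:
    pass


class BatchBegin:
    pass


class BatchEnd:
    pass


class EpochEnd:
    pass


class TrainEnd:
    pass


# Order matches the tuple unpacked in test_categorize_handlers
EVENTS = (TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd)


def categorize_handlers(handlers):
    """Return one list per event type, each holding the handlers that
    inherit from that event's mixin (a handler may appear in several)."""
    return tuple([h for h in handlers if isinstance(h, event)] for event in EVENTS)


class Handler1(TrainBegin):
    pass


class Handler2(EpochBegin, BatchBegin, TrainEnd):
    pass


class Handler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd):
    pass


buckets = categorize_handlers([Handler1(), Handler2(), Handler3()])
print([len(b) for b in buckets])  # [1, 2, 2, 1, 0, 2]
```

The bucket sizes match the assertions in the test. Because dispatch is by type, a handler that also needs to react at epoch end only has to add `EpochEnd` to its bases.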
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
new file mode 100644
index 0000000..7ea5ff3
--- /dev/null
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+
+import mxnet as mx
+from common import TemporaryDirectory
+from mxnet import nd
+from mxnet.gluon import nn, loss
+from mxnet.gluon.contrib.estimator import estimator, event_handler
+
+
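+def _get_test_network(net=None):
+    # Create a new container per call; a mutable default would be shared
+    # across calls and keep accumulating layers
+    net = nn.Sequential() if net is None else net
+    net.add(nn.Dense(128, activation='relu', flatten=False),
+            nn.Dense(64, activation='relu'),
+            nn.Dense(10, activation='relu'))
+    return net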
+
+
+def _get_test_data():
+    data = nd.ones((32, 100))
+    label = nd.zeros((32, 1))
+    data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
+    return mx.gluon.data.DataLoader(data_arr, batch_size=8)
+
+
+def test_checkpoint_handler():
+    with TemporaryDirectory() as tmpdir:
+        model_prefix = 'test_epoch'
+        file_path = os.path.join(tmpdir, model_prefix)
+        test_data = _get_test_data()
+
+        net = _get_test_network()
+        ce_loss = loss.SoftmaxCrossEntropyLoss()
+        acc = mx.metric.Accuracy()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             monitor=acc,
+                                                             save_best=True,
+                                                             epoch_period=1)
+        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1)
+        assert checkpoint_handler.current_epoch == 1
+        assert checkpoint_handler.current_batch == 4
+        assert os.path.isfile(file_path + '-best.params')
+        assert os.path.isfile(file_path + '-best.states')
+        assert os.path.isfile(file_path + '-epoch0batch4.params')
+        assert os.path.isfile(file_path + '-epoch0batch4.states')
+
+        model_prefix = 'test_batch'
+        file_path = os.path.join(tmpdir, model_prefix)
+        net = _get_test_network(nn.HybridSequential())
+        net.hybridize()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             epoch_period=None,
+                                                             batch_period=2,
+                                                             max_checkpoints=2)
+        est.fit(test_data, event_handlers=[checkpoint_handler], batches=10)
+        assert checkpoint_handler.current_batch == 10
+        assert checkpoint_handler.current_epoch == 3
+        assert not os.path.isfile(file_path + '-best.params')
+        assert not os.path.isfile(file_path + '-best.states')
+        assert not os.path.isfile(file_path + '-epoch0batch0.params')
+        assert not os.path.isfile(file_path + '-epoch0batch0.states')
+        assert os.path.isfile(file_path + '-symbol.json')
+        assert os.path.isfile(file_path + '-epoch1batch7.params')
+        assert os.path.isfile(file_path + '-epoch1batch7.states')
+        assert os.path.isfile(file_path + '-epoch2batch9.params')
+        assert os.path.isfile(file_path + '-epoch2batch9.states')
+
+def test_resume_checkpoint():
+    with TemporaryDirectory() as tmpdir:
+        model_prefix = 'test_net'
+        file_path = os.path.join(tmpdir, model_prefix)
+        test_data = _get_test_data()
+
+        net = _get_test_network()
+        ce_loss = loss.SoftmaxCrossEntropyLoss()
+        acc = mx.metric.Accuracy()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             monitor=acc,
+                                                             max_checkpoints=1)
+        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=2)
+        assert os.path.isfile(file_path + '-epoch1batch8.params')
+        assert os.path.isfile(file_path + '-epoch1batch8.states')
+        checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
+                                                             model_prefix=model_prefix,
+                                                             monitor=acc,
+                                                             max_checkpoints=1,
+                                                             resume_from_checkpoint=True)
+        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=5)
+        # should only continue to train 3 epochs and last checkpoint file is epoch4
+        assert est.max_epoch == 3
+        assert os.path.isfile(file_path + '-epoch4batch20.states')
+
+
+def test_early_stopping():
+    test_data = _get_test_data()
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    early_stopping = event_handler.EarlyStoppingHandler(monitor=acc,
+                                                        patience=0,
+                                                        mode='min')
+    est.fit(test_data, event_handlers=[early_stopping], epochs=5)
+    assert early_stopping.current_epoch == 2
+    assert early_stopping.stopped_epoch == 1
+
+    early_stopping = event_handler.EarlyStoppingHandler(monitor=acc,
+                                                        patience=2,
+                                                        mode='auto')
+    est.fit(test_data, event_handlers=[early_stopping], epochs=1)
+    assert early_stopping.current_epoch == 1
+
+
+def test_logging():
+    with TemporaryDirectory() as tmpdir:
+        test_data = _get_test_data()
+        file_name = 'test_log'
+        output_dir = os.path.join(tmpdir, file_name)
+
+        net = _get_test_network()
+        ce_loss = loss.SoftmaxCrossEntropyLoss()
+        acc = mx.metric.Accuracy()
+        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        train_metrics, val_metrics = est.prepare_loss_and_metrics()
+        logging_handler = event_handler.LoggingHandler(file_name=file_name,
+                                                       file_location=tmpdir,
+                                                       train_metrics=train_metrics,
+                                                       val_metrics=val_metrics)
+        est.fit(test_data, event_handlers=[logging_handler], epochs=3)
+        assert logging_handler.batch_index == 0
+        assert logging_handler.current_epoch == 3
+        assert os.path.isfile(output_dir)
+
+
+def test_custom_handler():
+    class CustomStopHandler(event_handler.TrainBegin,
+                            event_handler.BatchEnd,
+                            event_handler.EpochEnd):
+        def __init__(self, batch_stop=None, epoch_stop=None):
+            self.batch_stop = batch_stop
+            self.epoch_stop = epoch_stop
+            self.num_batch = 0
+            self.num_epoch = 0
+            self.stop_training = False
+
+        def train_begin(self, estimator, *args, **kwargs):
+            self.num_batch = 0
+            self.num_epoch = 0
+
+        def batch_end(self, estimator, *args, **kwargs):
+            self.num_batch += 1
+            if self.num_batch == self.batch_stop:
+                self.stop_training = True
+            return self.stop_training
+
+        def epoch_end(self, estimator, *args, **kwargs):
+            self.num_epoch += 1
+            if self.num_epoch == self.epoch_stop:
+                self.stop_training = True
+            return self.stop_training
+
+    # total data size is 32 and batch size is 8,
+    # so there are 4 batches per epoch
+    test_data = _get_test_data()
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    custom_handler = CustomStopHandler(3, 2)
+    est.fit(test_data, event_handlers=[custom_handler], epochs=3)
+    assert custom_handler.num_batch == 3
+    assert custom_handler.num_epoch == 1
+    custom_handler = CustomStopHandler(100, 5)
+    est.fit(test_data, event_handlers=[custom_handler], epochs=10)
+    assert custom_handler.num_batch == 5 * 4
+    assert custom_handler.num_epoch == 5