You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2019/12/11 21:58:13 UTC

[incubator-mxnet] branch v1.6.x updated: Backport Gluon estimator changes to 1.6 (#17048)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new c7dab17  Backport Gluon estimator changes to 1.6 (#17048)
c7dab17 is described below

commit c7dab1704d2604f26a9b058dec712987c9ac3801
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Wed Dec 11 13:57:33 2019 -0800

    Backport Gluon estimator changes to 1.6 (#17048)
    
    * Include eval_net the validation model in the gluon estimator api (#16957)
    
    * Include eval_net the validation model in the estimator api
    
    * fix small issue
    
    * Extend estimator.evaluate() to support event handlers (#16971)
    
    
    
    * fix unittest failures for the new api interface
    
    * Add comments in the code for readability
    
    * Remove unused argument val_metrics
    
    * merge changes with the master branch
    
    * fix some regression errors
    
    * fix bugs introduced in the merging phase
    
    * Add support of plug and play fit_batch and evaluate_batch (#16982)
    
    * Add support of plug and play fit_batch and evaluate_batch
    
    * Add check for the validity of the estimator model
    
    * Rename estimator model as batch processor
    
    * Remove unused import
    
    * Add documentation of the batch processor class
    
    * refine the documentation of the batch processor
    
    * Fix merge bugs
    
    * fix bugs introduced during merge
    
    * fix sanity check failures
    
    * fix CI bugs
    
    * Fix Gluon Estimator nightly test (#17042)
---
 python/mxnet/gluon/contrib/estimator/__init__.py   |   2 +
 .../gluon/contrib/estimator/batch_processor.py     | 105 +++++++++++
 python/mxnet/gluon/contrib/estimator/estimator.py  | 199 ++++++++++++---------
 .../mxnet/gluon/contrib/estimator/event_handler.py |  59 +++---
 tests/nightly/estimator/test_estimator_cnn.py      |   4 +-
 tests/nightly/estimator/test_sentiment_rnn.py      |   2 +-
 .../python/unittest/test_gluon_batch_processor.py  | 117 ++++++++++++
 tests/python/unittest/test_gluon_estimator.py      | 142 ++++++++++++---
 tests/python/unittest/test_gluon_event_handler.py  |  23 ++-
 9 files changed, 498 insertions(+), 155 deletions(-)

diff --git a/python/mxnet/gluon/contrib/estimator/__init__.py b/python/mxnet/gluon/contrib/estimator/__init__.py
index bb0a091..5ffd603 100644
--- a/python/mxnet/gluon/contrib/estimator/__init__.py
+++ b/python/mxnet/gluon/contrib/estimator/__init__.py
@@ -19,5 +19,7 @@
 """Gluon Estimator Module"""
 from . import estimator
 from . import event_handler
+from . import batch_processor
 from .estimator import *
 from .event_handler import *
+from .batch_processor import *
diff --git a/python/mxnet/gluon/contrib/estimator/batch_processor.py b/python/mxnet/gluon/contrib/estimator/batch_processor.py
new file mode 100644
index 0000000..4985f8c
--- /dev/null
+++ b/python/mxnet/gluon/contrib/estimator/batch_processor.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
+"""Gluon Batch Processor for Estimators"""
+
+from ...utils import split_and_load
+from .... import autograd
+
+__all__ = ['BatchProcessor']
+
+class BatchProcessor(object):
+    """BatchProcessor Class for plug and play fit_batch & evaluate_batch
+
+    During training or validation, data are divided into minibatches for processing. This
+    class aims at providing hooks of training or validating on a minibatch of data. Users
+    may provide customized fit_batch() and evaluate_batch() methods by inheriting from
+    this class and overriding class methods.
+
+    :py:class:`BatchProcessor` can be used to replace fit_batch() and evaluate_batch()
+    in the base estimator class
+    """
+
+    def __init__(self):
+        pass
+
+    def _get_data_and_label(self, batch, ctx, batch_axis=0):
+        data = batch[0]
+        label = batch[1]
+        data = split_and_load(data, ctx_list=ctx, batch_axis=batch_axis)
+        label = split_and_load(label, ctx_list=ctx, batch_axis=batch_axis)
+        return data, label
+
+    def evaluate_batch(self, estimator,
+                       val_batch,
+                       batch_axis=0):
+        """Evaluate the estimator model on a batch of validation data.
+
+        Parameters
+        ----------
+        estimator : Estimator
+            Reference to the estimator
+        val_batch : tuple
+            Data and label of a batch from the validation data loader.
+        batch_axis : int, default 0
+            Batch axis to split the validation data into devices.
+        """
+        data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)
+        pred = [estimator.eval_net(x) for x in data]
+        loss = [estimator.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
+
+        return data, label, pred, loss
+
+    def fit_batch(self, estimator,
+                  train_batch,
+                  batch_axis=0):
+        """Trains the estimator model on a batch of training data.
+
+        Parameters
+        ----------
+        estimator : Estimator
+            Reference to the estimator
+        train_batch : tuple
+            Data and label of a batch from the training data loader.
+        batch_axis : int, default 0
+            Batch axis to split the training data into devices.
+
+        Returns
+        -------
+        data: List of NDArray
+            Sharded data from the batch. Data is sharded with
+            `gluon.split_and_load`.
+        label: List of NDArray
+            Sharded label from the batch. Labels are sharded with
+            `gluon.split_and_load`.
+        pred: List of NDArray
+            Prediction on each of the sharded inputs.
+        loss: List of NDArray
+            Loss on each of the sharded inputs.
+        """
+        data, label = self._get_data_and_label(train_batch, estimator.context, batch_axis)
+
+        with autograd.record():
+            pred = [estimator.net(x) for x in data]
+            loss = [estimator.loss(y_hat, y) for y_hat, y in zip(pred, label)]
+
+        for l in loss:
+            l.backward()
+
+        return data, label, pred, loss
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index ab7018f..09f4315 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -32,9 +32,9 @@ from ...data import DataLoader
 from ...loss import Loss as gluon_loss
 from ...trainer import Trainer
 from ...utils import split_and_load
-from .... import autograd
 from ....context import Context, cpu, gpu, num_gpus
 from ....metric import Loss as metric_loss
+from .batch_processor import BatchProcessor
 
 __all__ = ['Estimator']
 
@@ -51,18 +51,41 @@ class Estimator(object):
         The model used for training.
     loss : gluon.loss.Loss
         Loss (objective) function to calculate during training.
-    metrics : EvalMetric or list of EvalMetric
-        Metrics for evaluating models.
+    train_metrics : EvalMetric or list of EvalMetric
+        Training metrics for evaluating models on training dataset.
+    val_metrics : EvalMetric or list of EvalMetric
+        Validation metrics for evaluating models on validation dataset.
     initializer : Initializer
         Initializer to initialize the network.
     trainer : Trainer
         Trainer to apply optimizer on network parameters.
     context : Context or list of Context
         Device(s) to run the training on.
-    evaluation_loss: gluon.loss.loss
-        Loss (objective) function to calculate during evaluation. If set evaluation_loss
+    evaluation_loss : gluon.loss.loss
+        Loss (objective) function to calculate during validation. If set evaluation_loss
         None, it will use the same loss function as self.loss
-
+    eval_net : gluon.Block
+        The model used for validation. The validation model does not necessarily belong to
+        the same model class as the training model. But the two models typically share the
+        same architecture. Therefore the validation model can reuse parameters of the
+        training model.
+
+        The code example of construction of eval_net sharing the same network parameters as
+        the training net is given below:
+
+        >>> net = _get_train_network()
+        >>> eval_net = _get_test_network(params=net.collect_params())
+        >>> net.initialize(ctx=ctx)
+        >>> est = Estimator(net, loss, eval_net=eval_net)
+
+        Proper namespace match is required for weight sharing between two networks. Most networks
+        inheriting :py:class:`Block` can share their parameters correctly. An exception is
+        Sequential networks, where the Block scope must be specified for correct weight sharing.
+        For the naming in mxnet Gluon API, please refer to the site
+        (https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
+        for further information.
+    batch_processor: BatchProcessor
+        BatchProcessor provides customized fit_batch() and evaluate_batch() methods
     """
 
     logger = None
@@ -85,19 +108,26 @@ class Estimator(object):
 
     def __init__(self, net,
                  loss,
-                 metrics=None,
+                 train_metrics=None,
+                 val_metrics=None,
                  initializer=None,
                  trainer=None,
                  context=None,
-                 evaluation_loss=None):
+                 evaluation_loss=None,
+                 eval_net=None,
+                 batch_processor=None):
         self.net = net
         self.loss = self._check_loss(loss)
-        self._train_metrics = _check_metrics(metrics)
+        self._train_metrics = _check_metrics(train_metrics)
+        self._val_metrics = _check_metrics(val_metrics)
         self._add_default_training_metrics()
         self._add_validation_metrics()
         self.evaluation_loss = self.loss
         if evaluation_loss is not None:
             self.evaluation_loss = self._check_loss(evaluation_loss)
+        self.eval_net = self.net
+        if eval_net is not None:
+            self.eval_net = eval_net
 
         self.logger = logging.Logger(name='Estimator', level=logging.INFO)
         self.logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -105,6 +135,7 @@ class Estimator(object):
         self.context = self._check_context(context)
         self._initialize(initializer)
         self.trainer = self._check_trainer(trainer)
+        self.batch_processor = self._check_batch_processor(batch_processor)
 
     def _check_loss(self, loss):
         if not isinstance(loss, gluon_loss):
@@ -145,6 +176,18 @@ class Estimator(object):
                 context = [cpu()]
         return context
 
+    def _check_batch_processor(self, batch_processor):
+        # check whether the batch processor contains fit_batch() and evaluate_batch() methods
+        if batch_processor is not None:
+            model_fit = getattr(batch_processor, 'fit_batch', None)
+            model_evaluate = getattr(batch_processor, 'evaluate_batch', None)
+            if not callable(model_fit) or not callable(model_evaluate):
+                raise ValueError('Customized Batch Processor must contain fit_batch()'
+                                 ' and evaluate_batch() methods')
+        else:
+            batch_processor = BatchProcessor()
+        return batch_processor
+
     def _initialize(self, initializer):
         # initialize the network
         if not self._is_initialized():
@@ -202,13 +245,21 @@ class Estimator(object):
             self._train_metrics.append(metric_loss(loss_name))
 
         for metric in self._train_metrics:
-            metric.name = "training " + metric.name
+            # add training prefix to the metric name
+            # it is useful for event handlers to distinguish them from validation metrics
+            metric.name = 'training ' + metric.name
 
     def _add_validation_metrics(self):
-        self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
+        if not self._val_metrics:
+            self._val_metrics = [copy.deepcopy(metric) for metric in self._train_metrics]
 
         for metric in self._val_metrics:
-            metric.name = "validation " + metric.name
+            # add validation prefix to the metric name
+            # it is useful for event handlers to distinguish them from training metrics
+            if 'training' in metric.name:
+                metric.name = metric.name.replace('training', 'validation')
+            else:
+                metric.name = 'validation ' + metric.name
 
     @property
     def train_metrics(self):
@@ -218,35 +269,10 @@ class Estimator(object):
     def val_metrics(self):
         return self._val_metrics
 
-    def evaluate_batch(self,
-                       val_batch,
-                       val_metrics,
-                       batch_axis=0):
-        """Evaluate model on a batch of validation data.
-
-        Parameters
-        ----------
-        val_batch : tuple
-            Data and label of a batch from the validation data loader.
-        val_metrics : EvalMetric or list of EvalMetrics
-            Metrics to update validation result.
-        batch_axis : int, default 0
-            Batch axis to split the validation data into devices.
-        """
-        data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
-        pred = [self.net(x) for x in data]
-        loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
-        # update metrics
-        for metric in val_metrics:
-            if isinstance(metric, metric_loss):
-                metric.update(0, loss)
-            else:
-                metric.update(label, pred)
-
     def evaluate(self,
                  val_data,
-                 val_metrics,
-                 batch_axis=0):
+                 batch_axis=0,
+                 event_handlers=None):
         """Evaluate model on validation data.
 
         This function calls :py:func:`evaluate_batch` on each of the batches from the
@@ -257,57 +283,45 @@ class Estimator(object):
         ----------
         val_data : DataLoader
             Validation data loader with data and labels.
-        val_metrics : EvalMetric or list of EvalMetrics
-            Metrics to update validation result.
         batch_axis : int, default 0
             Batch axis to split the validation data into devices.
+        event_handlers : EventHandler or list of EventHandler
+            List of :py:class:`EventHandlers` to apply during validation. Besides
+            event handlers specified here, a default MetricHandler and a LoggingHandler
+            will be added if not specified explicitly.
         """
         if not isinstance(val_data, DataLoader):
             raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you "
                              "can transform your DataIter or any NDArray into Gluon DataLoader. "
                              "Refer to gluon.data.DataLoader")
 
-        for metric in val_metrics:
+        for metric in self.val_metrics:
             metric.reset()
+        estimator_ref = self
 
-        for _, batch in enumerate(val_data):
-            self.evaluate_batch(batch, val_metrics, batch_axis)
+        event_handlers = self._prepare_default_validation_handlers(event_handlers)
 
-    def fit_batch(self, train_batch, batch_axis=0):
-        """Trains the model on a batch of training data.
+        _, epoch_begin, batch_begin, batch_end, \
+        epoch_end, _ = self._categorize_handlers(event_handlers)
 
-        Parameters
-        ----------
-        train_batch : tuple
-            Data and label of a batch from the training data loader.
-        batch_axis : int, default 0
-            Batch axis to split the training data into devices.
+        estimator_ref = self
 
-        Returns
-        -------
-        data: List of NDArray
-            Sharded data from the batch. Data is sharded with
-            `gluon.split_and_load`.
-        label: List of NDArray
-            Sharded label from the batch. Labels are sharded with
-            `gluon.split_and_load`.
-        pred: List of NDArray
-            Prediction on each of the sharded inputs.
-        loss: List of NDArray
-            Loss on each of the sharded inputs.
-        """
-        data, label = self._get_data_and_label(train_batch, self.context, batch_axis)
+        for handler in epoch_begin:
+            handler.epoch_begin(estimator_ref)
 
-        batch_size = train_batch[0].shape[batch_axis]
+        for _, batch in enumerate(val_data):
+            for handler in batch_begin:
+                handler.batch_begin(estimator_ref, batch=batch)
 
-        with autograd.record():
-            pred = [self.net(x) for x in data]
-            loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
+            _, label, pred, loss = \
+            self.batch_processor.evaluate_batch(estimator_ref, batch,
+                                                batch_axis)
 
-        for l in loss:
-            l.backward()
+            for handler in batch_end:
+                handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)
 
-        return data, label, pred, loss
+        for handler in epoch_end:
+            handler.epoch_end(estimator_ref)
 
     def fit(self, train_data,
             val_data=None,
@@ -382,8 +396,8 @@ class Estimator(object):
                 for handler in batch_begin:
                     handler.batch_begin(estimator_ref, batch=batch)
 
-                _, label, pred, loss = self.fit_batch(batch, batch_axis)
-
+                _, label, pred, loss = self.batch_processor.fit_batch(estimator_ref,
+                                                                      batch, batch_axis)
                 # batch end
 
                 batch_end_result = []
@@ -417,23 +431,17 @@ class Estimator(object):
             added_default_handlers.append(GradientUpdateHandler())
 
         if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
-            added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
+            added_default_handlers.append(MetricHandler(metrics=self.train_metrics))
 
         if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
             # no validation handler
             if val_data:
-                val_metrics = self.val_metrics
                 # add default validation handler if validation data found
                 added_default_handlers.append(ValidationHandler(val_data=val_data,
-                                                                eval_fn=self.evaluate,
-                                                                val_metrics=val_metrics))
-            else:
-                # set validation metrics to None if no validation data and no validation handler
-                val_metrics = []
+                                                                eval_fn=self.evaluate))
 
         if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            added_default_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
-                                                         val_metrics=val_metrics))
+            added_default_handlers.append(LoggingHandler(metrics=self.train_metrics))
 
         # if there is a mix of user defined event handlers and default event handlers
         # they should have the same set of metrics
@@ -450,6 +458,29 @@ class Estimator(object):
         event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
         return event_handlers
 
+    def _prepare_default_validation_handlers(self, event_handlers):
+        event_handlers = _check_event_handlers(event_handlers)
+        added_default_handlers = []
+
+        # add default logging handler and metric handler for validation
+        if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
+            added_default_handlers.append(MetricHandler(metrics=self.val_metrics))
+
+        if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
+            added_default_handlers.append(LoggingHandler(metrics=self.val_metrics))
+
+        mixing_handlers = event_handlers and added_default_handlers
+        event_handlers.extend(added_default_handlers)
+
+        # check if all handlers refer to well-defined validation metrics
+        if mixing_handlers:
+            known_metrics = set(self.val_metrics)
+            for handler in event_handlers:
+                _check_handler_metric_ref(handler, known_metrics)
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'priority', 0))
+        return event_handlers
+
     def _categorize_handlers(self, event_handlers):
         """
         categorize handlers into 6 event lists to avoid calling empty methods
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 6477760..c755136 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -128,28 +128,28 @@ class MetricHandler(EpochBegin, BatchEnd):
 
     Parameters
     ----------
-    train_metrics : List of EvalMetrics
-        Training metrics to be updated at batch end.
+    metrics : List of EvalMetrics
+        Metrics to be updated at batch end.
     priority : scalar
         Priority level of the MetricHandler. Priority level is sorted in ascending
         order. The lower the number is, the higher priority level the handler is.
     """
 
-    def __init__(self, train_metrics, priority=-1000):
-        self.train_metrics = _check_metrics(train_metrics)
+    def __init__(self, metrics, priority=-1000):
+        self.metrics = _check_metrics(metrics)
         # order to be called among all callbacks
         # metrics need to be calculated before other callbacks can access them
         self.priority = priority
 
     def epoch_begin(self, estimator, *args, **kwargs):
-        for metric in self.train_metrics:
+        for metric in self.metrics:
             metric.reset()
 
     def batch_end(self, estimator, *args, **kwargs):
         pred = kwargs['pred']
         label = kwargs['label']
         loss = kwargs['loss']
-        for metric in self.train_metrics:
+        for metric in self.metrics:
             if isinstance(metric, metric_loss):
                 # metric wrapper for loss values
                 metric.update(0, loss)
@@ -171,8 +171,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     eval_fn : function
         A function defines how to run evaluation and
         calculate loss and metrics.
-    val_metrics : List of EvalMetrics
-        Validation metrics to be updated.
     epoch_period : int, default 1
         How often to run validation at epoch end, by default
         :py:class:`ValidationHandler` validate every epoch.
@@ -188,7 +186,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     def __init__(self,
                  val_data,
                  eval_fn,
-                 val_metrics=None,
                  epoch_period=1,
                  batch_period=None,
                  priority=-1000):
@@ -196,7 +193,6 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
         self.eval_fn = eval_fn
         self.epoch_period = epoch_period
         self.batch_period = batch_period
-        self.val_metrics = _check_metrics(val_metrics)
         self.current_batch = 0
         self.current_epoch = 0
         # order to be called among all callbacks
@@ -211,20 +207,12 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     def batch_end(self, estimator, *args, **kwargs):
         self.current_batch += 1
         if self.batch_period and self.current_batch % self.batch_period == 0:
-            self.eval_fn(val_data=self.val_data,
-                         val_metrics=self.val_metrics)
-            msg = '[Epoch %d] ValidationHandler: %d batches reached, ' \
-                  % (self.current_epoch, self.current_batch)
-            for monitor in self.val_metrics:
-                name, value = monitor.get()
-                msg += '%s: %.4f, ' % (name, value)
-            estimator.logger.info(msg.rstrip(','))
+            self.eval_fn(val_data=self.val_data)
 
     def epoch_end(self, estimator, *args, **kwargs):
         self.current_epoch += 1
         if self.epoch_period and self.current_epoch % self.epoch_period == 0:
-            self.eval_fn(val_data=self.val_data,
-                         val_metrics=self.val_metrics)
+            self.eval_fn(val_data=self.val_data)
 
 
 class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
@@ -239,10 +227,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         Logging interval during training.
         log_interval='epoch': display metrics every epoch
         log_interval=integer k: display metrics every interval of k batches
-    train_metrics : list of EvalMetrics
-        Training metrics to be logged, logged at batch end, epoch end, train end.
-    val_metrics : list of EvalMetrics
-        Validation metrics to be logged, logged at epoch end, train end.
+    metrics : list of EvalMetrics
+        Metrics to be logged, logged at batch end, epoch end, train end.
     priority : scalar, default np.Inf
         Priority level of the LoggingHandler. Priority level is sorted in
         ascending order. The lower the number is, the higher priority level the
@@ -250,14 +236,12 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
     """
 
     def __init__(self, log_interval='epoch',
-                 train_metrics=None,
-                 val_metrics=None,
+                 metrics=None,
                  priority=np.Inf):
         super(LoggingHandler, self).__init__()
         if not isinstance(log_interval, int) and log_interval != 'epoch':
             raise ValueError("log_interval must be either an integer or string 'epoch'")
-        self.train_metrics = _check_metrics(train_metrics)
-        self.val_metrics = _check_metrics(val_metrics)
+        self.metrics = _check_metrics(metrics)
         self.batch_index = 0
         self.current_epoch = 0
         self.processed_samples = 0
@@ -265,6 +249,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         # it will also shut down logging at train end
         self.priority = priority
         self.log_interval = log_interval
+        self.log_interval_time = 0
 
     def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
@@ -288,7 +273,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         train_time = time.time() - self.train_start
         msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
         # log every result in train stats including train/validation loss & metrics
-        for metric in self.train_metrics + self.val_metrics:
+        for metric in self.metrics:
             name, value = metric.get()
             msg += '%s: %.4f, ' % (name, value)
         estimator.logger.info(msg.rstrip(', '))
@@ -307,7 +292,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
             if self.batch_index % self.log_interval == 0:
                 msg += 'time/interval: %.3fs ' % self.log_interval_time
                 self.log_interval_time = 0
-                for metric in self.train_metrics:
+                for metric in self.metrics:
                     # only log current training loss & metric after each interval
                     name, value = metric.get()
                     msg += '%s: %.4f, ' % (name, value)
@@ -316,15 +301,23 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
 
     def epoch_begin(self, estimator, *args, **kwargs):
         if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
+            is_training = False
+            # use the name hack defined in __init__() of estimator class
+            for metric in self.metrics:
+                if 'training' in metric.name:
+                    is_training = True
             self.epoch_start = time.time()
-            estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
-                                  self.current_epoch, estimator.trainer.learning_rate)
+            if is_training:
+                estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
+                                      self.current_epoch, estimator.trainer.learning_rate)
+            else:
+                estimator.logger.info("Validation Begin")
 
     def epoch_end(self, estimator, *args, **kwargs):
         if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
             epoch_time = time.time() - self.epoch_start
             msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
-            for monitor in self.train_metrics + self.val_metrics:
+            for monitor in self.metrics:
                 name, value = monitor.get()
                 msg += '%s: %.4f, ' % (name, value)
             estimator.logger.info(msg.rstrip(', '))
diff --git a/tests/nightly/estimator/test_estimator_cnn.py b/tests/nightly/estimator/test_estimator_cnn.py
index 4a3bb20..af51953 100644
--- a/tests/nightly/estimator/test_estimator_cnn.py
+++ b/tests/nightly/estimator/test_estimator_cnn.py
@@ -116,7 +116,7 @@ def test_estimator_cpu():
         # Define estimator
         est = estimator.Estimator(net=net,
                                   loss=loss,
-                                  metrics=mx.metric.Accuracy(),
+                                  train_metrics=mx.metric.Accuracy(),
                                   trainer=trainer,
                                   context=context)
         # Call fit()
@@ -145,7 +145,7 @@ def test_estimator_gpu():
     # Define estimator
     est = estimator.Estimator(net=net,
                               loss=loss,
-                              metrics=acc,
+                              train_metrics=acc,
                               trainer=trainer,
                               context=context)
     # Call fit()
diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py
index 233355b..ab124ba 100644
--- a/tests/nightly/estimator/test_sentiment_rnn.py
+++ b/tests/nightly/estimator/test_sentiment_rnn.py
@@ -197,7 +197,7 @@ def run(net, train_dataloader, test_dataloader, num_epochs, ctx, lr):
     nested_metrics.add([metrics, mx.metric.Accuracy()])
 
     # Define estimator
-    est = estimator.Estimator(net=net, loss=loss, metrics=nested_metrics,
+    est = estimator.Estimator(net=net, loss=loss, train_metrics=nested_metrics,
                               trainer=trainer, context=ctx)
     # Begin training
     est.fit(train_data=train_dataloader, val_data=test_dataloader,
diff --git a/tests/python/unittest/test_gluon_batch_processor.py b/tests/python/unittest/test_gluon_batch_processor.py
new file mode 100644
index 0000000..4bd6f76
--- /dev/null
+++ b/tests/python/unittest/test_gluon_batch_processor.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+''' Unit tests for Gluon Batch Processor '''
+
+import sys
+import unittest
+import warnings
+
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon.contrib.estimator import *
+from mxnet.gluon.contrib.estimator.event_handler import *
+from mxnet.gluon.contrib.estimator.batch_processor import BatchProcessor
+from nose.tools import assert_raises
+
+def _get_test_network():
+    model = nn.Sequential()  # returned uninitialized; callers invoke initialize() themselves
+    model.add(nn.Dense(4, activation='relu', flatten=False))
+    return model
+
+
+def _get_test_data():
+    ''' Return (DataLoader, NDArrayIter) built from the same random arrays '''
+    features = mx.nd.random.uniform(shape=(10, 3))
+    targets = mx.nd.random.uniform(shape=(10, 4))
+    samples = gluon.data.dataset.ArrayDataset(features, targets)
+    # Same batch size for the Gluon loader and the mx.io iterator.
+    loader = gluon.data.DataLoader(samples, batch_size=4)
+    data_iter = mx.io.NDArrayIter(data=features, label=targets, batch_size=4)
+    return loader, data_iter
+
+def test_batch_processor_fit():
+    ''' fit() with an explicit BatchProcessor: DataLoader works, other inputs raise '''
+    net = _get_test_network()
+    dataloader, dataiter = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+    l2_loss = gluon.loss.L2Loss()
+    accuracy = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    batch_processor = BatchProcessor()
+    sgd = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+
+    est = Estimator(
+        net=net,
+        loss=l2_loss,
+        train_metrics=accuracy,
+        trainer=sgd,
+        context=ctx,
+        batch_processor=batch_processor)
+
+    # A gluon DataLoader is accepted as training data.
+    est.fit(train_data=dataloader, epochs=num_epochs)
+
+    # An mx.io.NDArrayIter and a bare list of NDArrays must both be rejected
+    # with ValueError.
+    invalid_inputs = (dataiter, [mx.nd.ones(shape=(10, 3))])
+    for bad_train_data in invalid_inputs:
+        with assert_raises(ValueError):
+            est.fit(train_data=bad_train_data, epochs=num_epochs)
+
+
+def test_batch_processor_validation():
+    ''' fit() with val_data through a BatchProcessor: DataLoader works, other inputs raise '''
+    net = _get_test_network()
+    dataloader, dataiter = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    evaluation_loss = gluon.loss.L1Loss()  # evaluation uses L1 while training uses L2
+    net.initialize(ctx=ctx)
+    processor = BatchProcessor()
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    train_metrics=acc,
+                    trainer=trainer,
+                    context=ctx,
+                    evaluation_loss=evaluation_loss,
+                    batch_processor=processor)
+    # A DataLoader is accepted for both train_data and val_data.
+    est.fit(train_data=dataloader,
+            val_data=dataloader,
+            epochs=num_epochs)
+
+    # NOTE(review): the three locals below are built but never used later in this test.
+    train_metrics = est.train_metrics
+    val_metrics = est.val_metrics
+    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate)
+
+    with assert_raises(ValueError):  # dataiter is the mx.io.NDArrayIter from _get_test_data
+        est.fit(train_data=dataiter,
+                val_data=dataiter,
+                epochs=num_epochs)
+    # A bare list of NDArrays is expected to raise ValueError as well.
+    with assert_raises(ValueError):
+        est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
+                val_data=[mx.nd.ones(shape=(10, 3))],
+                epochs=num_epochs)
+
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index 21f949a..924dd08 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -29,11 +29,16 @@ from mxnet.gluon.contrib.estimator.event_handler import *
 from nose.tools import assert_raises
 
 
-def _get_test_network():
-    net = nn.Sequential()
+def _get_test_network(params=None):  # params is handed straight to nn.Sequential
+    net = nn.Sequential(params=params)
     net.add(nn.Dense(4, activation='relu', flatten=False))
     return net
 
+def _get_test_network_with_namescope(params=None):
+    model = nn.Sequential(params=params)
+    with model.name_scope():  # the Dense child is created inside the container's name scope
+        model.add(nn.Dense(4, activation='relu', flatten=False))
+    return model
 
 def _get_test_data():
     batch_size = 4
@@ -58,7 +63,7 @@ def test_fit():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     trainer=trainer,
                     context=ctx)
 
@@ -88,7 +93,7 @@ def test_validation():
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     trainer=trainer,
                     context=ctx,
                     evaluation_loss=evaluation_loss)
@@ -100,8 +105,7 @@ def test_validation():
     # using validation handler
     train_metrics = est.train_metrics
     val_metrics = est.val_metrics
-    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate,
-                                           val_metrics=val_metrics)
+    validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate)
 
     with assert_raises(ValueError):
         est.fit(train_data=dataiter,
@@ -127,7 +131,7 @@ def test_initializer():
     # no initializer
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=acc,
+                    train_metrics=acc,
                     context=ctx)
     est.fit(train_data=train_data,
             epochs=num_epochs)
@@ -140,7 +144,7 @@ def test_initializer():
     with warnings.catch_warnings(record=True) as w:
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=acc,
+                        train_metrics=acc,
                         initializer=mx.init.MSRAPrelu(),
                         trainer=trainer,
                         context=ctx)
@@ -148,7 +152,7 @@ def test_initializer():
     # net partially initialized, fine tuning use case
     net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx)
     net.output = gluon.nn.Dense(10) #last layer not initialized
-    est = Estimator(net, loss=loss, metrics=acc, context=ctx)
+    est = Estimator(net, loss=loss, train_metrics=acc, context=ctx)
     dataset =  gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
     train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5)
     est.fit(train_data=train_data,
@@ -170,7 +174,7 @@ def test_trainer():
     with warnings.catch_warnings(record=True) as w:
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=acc,
+                        train_metrics=acc,
                         context=ctx)
         assert 'No trainer specified' in str(w[-1].message)
     est.fit(train_data=train_data,
@@ -181,7 +185,7 @@ def test_trainer():
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=acc,
+                        train_metrics=acc,
                         trainer=trainer,
                         context=ctx)
 
@@ -207,7 +211,7 @@ def test_metric():
     metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=metrics,
+                    train_metrics=metrics,
                     trainer=trainer,
                     context=ctx)
     est.fit(train_data=train_data,
@@ -216,7 +220,7 @@ def test_metric():
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics='acc',
+                        train_metrics='acc',
                         trainer=trainer,
                         context=ctx)
     # test default metric
@@ -239,7 +243,7 @@ def test_loss():
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss='mse',
-                        metrics=acc,
+                        train_metrics=acc,
                         trainer=trainer,
                         context=ctx)
 
@@ -252,26 +256,26 @@ def test_context():
     # input no context
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=metrics)
+                    train_metrics=metrics)
     # input list of context
     gpus = mx.context.num_gpus()
     ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
     net = _get_test_network()
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=metrics,
+                    train_metrics=metrics,
                     context=ctx)
     # input invalid context
     with assert_raises(ValueError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=metrics,
+                        train_metrics=metrics,
                         context='cpu')
 
     with assert_raises(AssertionError):
         est = Estimator(net=net,
                         loss=loss,
-                        metrics=metrics,
+                        train_metrics=metrics,
                         context=[mx.gpu(0), mx.gpu(100)])
 
 
@@ -336,7 +340,7 @@ def test_default_handlers():
 
     est = Estimator(net=net,
                     loss=loss,
-                    metrics=train_acc,
+                    train_metrics=train_acc,
                     trainer=trainer,
                     context=ctx)
     # no handler(all default handlers), no warning
@@ -347,18 +351,18 @@ def test_default_handlers():
     # use mix of default and user defined handlers
     train_metrics = est.train_metrics
     val_metrics = est.val_metrics
-    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics)
+    logging = LoggingHandler(metrics=train_metrics)
     est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
 
     # handler with all user defined metrics
     # use mix of default and user defined handlers
-    metric = MetricHandler(train_metrics=[train_acc])
-    logging = LoggingHandler(train_metrics=[train_acc])
+    metric = MetricHandler(metrics=[train_acc])
+    logging = LoggingHandler(metrics=[train_acc])
     est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging])
 
     # handler with mixed metrics, some handler use metrics prepared by estimator
     # some handler use metrics user prepared
-    logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")])
+    logging = LoggingHandler(metrics=[mx.metric.RMSE("val acc")])
     with assert_raises(ValueError):
         est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging])
 
@@ -371,3 +375,95 @@ def test_default_handlers():
     assert isinstance(handlers[0], GradientUpdateHandler)
     assert isinstance(handlers[1], MetricHandler)
     assert isinstance(handlers[4], LoggingHandler)
+
+def test_eval_net():
+    ''' test estimator with a different evaluation net '''
+    # Case 1: weight sharing of sequential networks without namescope
+    net = _get_test_network()
+    eval_net = _get_test_network(params=net.collect_params())
+    dataloader, _ = _get_test_data()
+    num_epochs = 1
+    ctx = mx.cpu()
+    loss = gluon.loss.L2Loss()
+    evaluation_loss = gluon.loss.L2Loss()
+    acc = mx.metric.Accuracy()
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    train_metrics=acc,
+                    trainer=trainer,
+                    context=ctx,
+                    evaluation_loss=evaluation_loss,
+                    eval_net=eval_net)
+
+    with assert_raises(RuntimeError):  # this sharing setup is expected to fail in fit()
+        est.fit(train_data=dataloader,
+                val_data=dataloader,
+                epochs=num_epochs)
+
+    # Case 2: weight sharing of sequential networks with namescope
+    net = _get_test_network_with_namescope()
+    eval_net = _get_test_network_with_namescope(params=net.collect_params())
+    net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    train_metrics=acc,
+                    trainer=trainer,
+                    context=ctx,
+                    evaluation_loss=evaluation_loss,
+                    eval_net=eval_net)
+
+    est.fit(train_data=dataloader,
+            val_data=dataloader,
+            epochs=num_epochs)
+
+    # Case 3: weight sharing of two resnets
+    net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
+    net.output = gluon.nn.Dense(10)
+    eval_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
+    eval_net.output = gluon.nn.Dense(10, params=net.collect_params())
+    dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
+    dataloader = gluon.data.DataLoader(dataset=dataset, batch_size=5)
+    net.initialize(ctx=ctx)
+    eval_net.initialize(ctx=ctx)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    est = Estimator(net=net,
+                    loss=loss,
+                    train_metrics=acc,
+                    trainer=trainer,
+                    context=ctx,
+                    evaluation_loss=evaluation_loss,
+                    eval_net=eval_net)
+
+    est.fit(train_data=dataloader,
+            val_data=dataloader,
+            epochs=num_epochs)
+
+def test_val_handlers():
+    ''' evaluate() runs both without and with a user-supplied LoggingHandler '''
+    net = _get_test_network()
+    train_data, _ = _get_test_data()
+    val_data, _ = _get_test_data()
+
+    num_epochs = 1
+    ctx = mx.cpu()
+    net.initialize(ctx=ctx)
+    sgd = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
+    rmse = mx.metric.RMSE()
+    l2_loss = gluon.loss.L2Loss()
+    est = Estimator(net=net, loss=l2_loss, train_metrics=rmse,
+                    trainer=sgd, context=ctx)
+
+    # Train, then evaluate without passing any event handlers; warnings
+    # issued along the way are captured into the list instead of being shown.
+    with warnings.catch_warnings(record=True) as recorded:
+        est.fit(train_data=train_data, epochs=num_epochs)
+        est.evaluate(val_data=val_data)
+
+    # Evaluate again, this time with an explicit LoggingHandler bound to the
+    # estimator's validation metrics (log_interval=1).
+    val_logger = LoggingHandler(log_interval=1, metrics=est.val_metrics)
+    est.evaluate(val_data=val_data, event_handlers=[val_logger])
+
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index 658fb88..41b7901 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -54,7 +54,7 @@ def test_checkpoint_handler():
         net = _get_test_network()
         ce_loss = loss.SoftmaxCrossEntropyLoss()
         acc = mx.metric.Accuracy()
-        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
         checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                              model_prefix=model_prefix,
                                                              monitor=acc,
@@ -72,7 +72,7 @@ def test_checkpoint_handler():
         file_path = os.path.join(tmpdir, model_prefix)
         net = _get_test_network(nn.HybridSequential())
         net.hybridize()
-        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
         checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                              model_prefix=model_prefix,
                                                              epoch_period=None,
@@ -100,7 +100,7 @@ def test_resume_checkpoint():
         net = _get_test_network()
         ce_loss = loss.SoftmaxCrossEntropyLoss()
         acc = mx.metric.Accuracy()
-        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
         checkpoint_handler = event_handler.CheckpointHandler(model_dir=tmpdir,
                                                              model_prefix=model_prefix,
                                                              monitor=acc,
@@ -125,7 +125,7 @@ def test_early_stopping():
     net = _get_test_network()
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
-    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
     early_stopping = event_handler.EarlyStoppingHandler(monitor=acc,
                                                         patience=0,
                                                         mode='min')
@@ -149,14 +149,13 @@ def test_logging():
         net = _get_test_network()
         ce_loss = loss.SoftmaxCrossEntropyLoss()
         acc = mx.metric.Accuracy()
-        est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
 
         est.logger.addHandler(logging.FileHandler(output_dir))
 
         train_metrics = est.train_metrics
         val_metrics = est.val_metrics
-        logging_handler = event_handler.LoggingHandler(train_metrics=train_metrics,
-                                                       val_metrics=val_metrics)
+        logging_handler = event_handler.LoggingHandler(metrics=train_metrics)
         est.fit(test_data, event_handlers=[logging_handler], epochs=3)
         assert logging_handler.batch_index == 0
         assert logging_handler.current_epoch == 3
@@ -197,7 +196,7 @@ def test_custom_handler():
     net = _get_test_network()
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
-    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
     custom_handler = CustomStopHandler(3, 2)
     est.fit(test_data, event_handlers=[custom_handler], epochs=3)
     assert custom_handler.num_batch == 3
@@ -220,10 +219,10 @@ def test_logging_interval():
     num_epochs = 1
     ce_loss = loss.SoftmaxCrossEntropyLoss()
     acc = mx.metric.Accuracy()
-    logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+    logging = LoggingHandler(metrics=[acc], log_interval=log_interval)
     est = estimator.Estimator(net=net,
                               loss=ce_loss,
-                              metrics=acc)
+                              train_metrics=acc)
 
     est.fit(train_data=dataloader,
             epochs=num_epochs,
@@ -245,10 +244,10 @@ def test_logging_interval():
     sys.stdout = mystdout = StringIO()
     acc = mx.metric.Accuracy()
     log_interval = 5
-    logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+    logging = LoggingHandler(metrics=[acc], log_interval=log_interval)
     est = estimator.Estimator(net=net,
                               loss=ce_loss,
-                              metrics=acc)
+                              train_metrics=acc)
     est.fit(train_data=dataloader,
             epochs=num_epochs,
             event_handlers=[logging])