Posted to commits@mxnet.apache.org by pt...@apache.org on 2019/12/10 17:27:06 UTC

[incubator-mxnet] branch v1.6.x updated: Backport #16895, #16922, #16878, #16979 and #16900 to 1.6 (#17029)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new c973f01  Backport #16895, #16922, #16878, #16979 and #16900 to 1.6 (#17029)
c973f01 is described below

commit c973f01f95ad43dc3f93b322f00c7c8bf1648f0d
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Tue Dec 10 09:26:26 2019 -0800

    Backport #16895, #16922, #16878, #16979 and #16900 to 1.6 (#17029)
    
    * Fix ndarray indexing bug (#16895)
    
    * Fix indexing bug
    
    * More test cases
    
    * Add test from #16647
    
    * [Gluon] Update contrib.Estimator LoggingHandler to support logging per batch interval (#16922)
    
    * Update LoggingHandler to support logging per interval
    
    * Fix the constant variable issue in the logging handler
    
    * Remove the constant variable hack in the logging handler.
    
    * 1) Replace LOG_PER_BATCH with LOG_PER_INTERVAL; 2) add a test case
    
    * Improve the test script for LoggingHandler
    
    * Small fix to the test script
    
    * Fix a bug in the logging handler test case
    
    * Remove the verbose parameter from LoggingHandler
    
    * Move log_interval to the first argument
    
    * Resolve unittest mistakes
    
    * Add micro averaging strategy to pearsonr metric (#16878)
    
            Strategy to be used for aggregating across mini-batches.
                "macro": average the pearsonr scores for each batch.
                "micro": compute a single pearsonr score across all batches.
    
    * [Bugfix] [Numpy] Add `kAddTo` and `kNullOp` to Transpose (#16979)
    
    * update
    
    Check for repeated axes
    
    Enable addto to transpose
    
    Remove unused ndim
    
    Iterative fixes and updates to pseudo2DTranspose_op-inl.cuh,
    np_matrix_op.cc and test_numpy_op.py; update test cases and fix
    implementation bugs
    
    * Fix bug
    
    * Fix docstring
    
    * Try to address comment
    
    * No need to change this line
    
    * Fix bug
    
    * Address review comments
    
    * Introduce gradient update handler to the base estimator (#16900)
    
    * Introduce gradient update handler to the base estimator
    
    * Modify the gradient update handler to include the batch size
    
    * Remove unrelated gradient update handler.
    
    * Modify gradient update handler to take the current batch size.
    
    * Remove white space to avoid the sanity check failure
    
    * Add a small tweak to the handler code
    
    * Modify the documentation of the priority parameter of relevant handlers.
    
    * Small modification to the documentation.
    
    * Add a small modification to the documentation.
    
    * Remove unnecessary list check
---
 python/mxnet/gluon/contrib/estimator/estimator.py  |   8 +-
 .../mxnet/gluon/contrib/estimator/event_handler.py | 100 +++++++++----
 python/mxnet/metric.py                             |  79 ++++++++--
 python/mxnet/ndarray/ndarray.py                    |  42 +++---
 src/ndarray/ndarray.cc                             |  10 +-
 src/operator/numpy/np_matrix_op-inl.h              |  16 ++-
 src/operator/numpy/np_matrix_op.cc                 |  12 +-
 src/operator/tensor/matrix_op-inl.h                | 106 +++++++++-----
 src/operator/tensor/matrix_op.cc                   |   5 +-
 src/operator/tensor/pseudo2DTranspose_op-inl.cuh   | 160 ++++++++-------------
 tests/python/unittest/test_gluon_estimator.py      |   7 +-
 tests/python/unittest/test_gluon_event_handler.py  |  72 +++++++++-
 tests/python/unittest/test_metric.py               |  42 +++++-
 tests/python/unittest/test_ndarray.py              |  16 +++
 tests/python/unittest/test_numpy_ndarray.py        |  63 +++++---
 tests/python/unittest/test_numpy_op.py             |  84 ++++++-----
 16 files changed, 539 insertions(+), 283 deletions(-)

diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 54a0b16..ab7018f 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -24,7 +24,7 @@ import logging
 import sys
 import warnings
 
-from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler
 from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
 from .event_handler import _check_event_handlers
 from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
@@ -307,8 +307,6 @@ class Estimator(object):
         for l in loss:
             l.backward()
 
-        self.trainer.step(batch_size)
-
         return data, label, pred, loss
 
     def fit(self, train_data,
@@ -360,6 +358,7 @@ class Estimator(object):
 
         self.max_epoch = epochs
         self.max_batch = batches
+        self.batch_axis = batch_axis
 
         # provide default handlers
         event_handlers = self._prepare_default_handlers(val_data, event_handlers)
@@ -414,6 +413,9 @@ class Estimator(object):
         # no need to add to default handler check as StoppingHandler does not use metrics
         added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
 
+        if not any(isinstance(handler, GradientUpdateHandler) for handler in event_handlers):
+            added_default_handlers.append(GradientUpdateHandler())
+
         if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
             added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
 
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 3cdc407..6477760 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -31,7 +31,7 @@ from .utils import _check_metrics
 
 __all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
            'StoppingHandler', 'MetricHandler', 'ValidationHandler',
-           'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
+           'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler', 'GradientUpdateHandler']
 
 
 class EventHandler(object):
@@ -130,13 +130,16 @@ class MetricHandler(EpochBegin, BatchEnd):
     ----------
     train_metrics : List of EvalMetrics
         Training metrics to be updated at batch end.
+    priority : scalar
+        Priority level of the MetricHandler. Priority levels are sorted in
+        ascending order: the lower the number, the higher the handler's priority.
     """
 
-    def __init__(self, train_metrics):
+    def __init__(self, train_metrics, priority=-1000):
         self.train_metrics = _check_metrics(train_metrics)
         # order to be called among all callbacks
         # metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority
 
     def epoch_begin(self, estimator, *args, **kwargs):
         for metric in self.train_metrics:
@@ -176,6 +179,10 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
     batch_period : int, default None
         How often to run validation at batch end, by default
         :py:class:`ValidationHandler` does not validate at batch end.
+    priority : scalar, default -1000
+        Priority level of the ValidationHandler. Priority levels are sorted in
+        ascending order: the lower the number, the higher the handler's priority.
     """
 
     def __init__(self,
@@ -183,7 +190,8 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
                  eval_fn,
                  val_metrics=None,
                  epoch_period=1,
-                 batch_period=None):
+                 batch_period=None,
+                 priority=-1000):
         self.val_data = val_data
         self.eval_fn = eval_fn
         self.epoch_period = epoch_period
@@ -193,7 +201,7 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
         self.current_epoch = 0
         # order to be called among all callbacks
         # validation metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority
 
     def train_begin(self, estimator, *args, **kwargs):
         # reset epoch and batch counter
@@ -227,29 +235,27 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
 
     Parameters
     ----------
-    verbose : int, default LOG_PER_EPOCH
-        Limit the granularity of metrics displayed during training process.
-        verbose=LOG_PER_EPOCH: display metrics every epoch
-        verbose=LOG_PER_BATCH: display metrics every batch
+    log_interval : int or str, default 'epoch'
+        Logging interval during training.
+        log_interval='epoch': display metrics every epoch.
+        log_interval=k (an integer): display metrics every k batches.
     train_metrics : list of EvalMetrics
         Training metrics to be logged, logged at batch end, epoch end, train end.
     val_metrics : list of EvalMetrics
         Validation metrics to be logged, logged at epoch end, train end.
+    priority : scalar, default np.Inf
+        Priority level of the LoggingHandler. Priority levels are sorted in
+        ascending order: the lower the number, the higher the handler's priority.
     """
 
-    LOG_PER_EPOCH = 1
-    LOG_PER_BATCH = 2
-
-    def __init__(self, verbose=LOG_PER_EPOCH,
+    def __init__(self, log_interval='epoch',
                  train_metrics=None,
-                 val_metrics=None):
+                 val_metrics=None,
+                 priority=np.Inf):
         super(LoggingHandler, self).__init__()
-        if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
-            raise ValueError("verbose level must be either LOG_PER_EPOCH or "
-                             "LOG_PER_BATCH, received %s. "
-                             "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
-                             % verbose)
-        self.verbose = verbose
+        if not isinstance(log_interval, int) and log_interval != 'epoch':
+            raise ValueError("log_interval must be either an integer or string 'epoch'")
         self.train_metrics = _check_metrics(train_metrics)
         self.val_metrics = _check_metrics(val_metrics)
         self.batch_index = 0
@@ -257,7 +263,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         self.processed_samples = 0
         # logging handler need to be called at last to make sure all states are updated
         # it will also shut down logging at train end
-        self.priority = np.Inf
+        self.priority = priority
+        self.log_interval = log_interval
 
     def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
@@ -275,6 +282,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         self.current_epoch = 0
         self.batch_index = 0
         self.processed_samples = 0
+        self.log_interval_time = 0
 
     def train_end(self, estimator, *args, **kwargs):
         train_time = time.time() - self.train_start
@@ -286,31 +294,34 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
         estimator.logger.info(msg.rstrip(', '))
 
     def batch_begin(self, estimator, *args, **kwargs):
-        if self.verbose == self.LOG_PER_BATCH:
+        if isinstance(self.log_interval, int):
             self.batch_start = time.time()
 
     def batch_end(self, estimator, *args, **kwargs):
-        if self.verbose == self.LOG_PER_BATCH:
+        if isinstance(self.log_interval, int):
             batch_time = time.time() - self.batch_start
             msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
             self.processed_samples += kwargs['batch'][0].shape[0]
             msg += '[Samples %s] ' % (self.processed_samples)
-            msg += 'time/batch: %.3fs ' % batch_time
-            for metric in self.train_metrics:
-                # only log current training loss & metric after each batch
-                name, value = metric.get()
-                msg += '%s: %.4f, ' % (name, value)
-            estimator.logger.info(msg.rstrip(', '))
+            self.log_interval_time += batch_time
+            if self.batch_index % self.log_interval == 0:
+                msg += 'time/interval: %.3fs ' % self.log_interval_time
+                self.log_interval_time = 0
+                for metric in self.train_metrics:
+                    # only log current training loss & metric after each interval
+                    name, value = metric.get()
+                    msg += '%s: %.4f, ' % (name, value)
+                estimator.logger.info(msg.rstrip(', '))
         self.batch_index += 1
 
     def epoch_begin(self, estimator, *args, **kwargs):
-        if self.verbose >= self.LOG_PER_EPOCH:
+        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
             self.epoch_start = time.time()
             estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
                                   self.current_epoch, estimator.trainer.learning_rate)
 
     def epoch_end(self, estimator, *args, **kwargs):
-        if self.verbose >= self.LOG_PER_EPOCH:
+        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
             epoch_time = time.time() - self.epoch_start
             msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
             for monitor in self.train_metrics + self.val_metrics:
@@ -706,3 +717,30 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd):
             estimator.logger.info('[Epoch %d] EarlyStoppingHanlder: '
                                   'early stopping due to %s not improving',
                                   self.stopped_epoch, self.monitor.get()[0])
+
+class GradientUpdateHandler(BatchEnd):
+    """Gradient update handler that applies the gradients to the network weights.
+
+    :py:class:`GradientUpdateHandler` updates the weight parameters at the end of
+    each batch, ordered among the other handlers by its priority level.
+
+    Parameters
+    ----------
+    priority : scalar, default -2000
+        Priority level of the gradient update handler. Priority levels are sorted
+        in ascending order: the lower the number, the higher the handler's priority.
+    """
+    def __init__(self, priority=-2000):
+        self.priority = priority
+
+    def batch_end(self, estimator, *args, **kwargs):
+        loss = kwargs['loss']
+        batch_size = 0
+        if not isinstance(loss, list):
+            loss = [loss]
+        for l in loss:
+            batch_size += l.shape[estimator.batch_axis]
+
+        estimator.trainer.step(batch_size)
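
For context, a minimal usage sketch of the reworked handlers (hedged: `net`, `loss_fn`
and `train_data` are illustrative names assumed to exist, not part of this patch):

    import mxnet as mx
    from mxnet.gluon.contrib.estimator import estimator
    from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler, GradientUpdateHandler

    # Log metrics every 10 batches instead of once per epoch.
    logging_handler = LoggingHandler(log_interval=10)
    # GradientUpdateHandler is added by default; pass it explicitly only to
    # override its priority relative to the other handlers.
    grad_handler = GradientUpdateHandler(priority=-2000)

    est = estimator.Estimator(net=net, loss=loss_fn, metrics=mx.metric.Accuracy())
    est.fit(train_data=train_data, epochs=2,
            event_handlers=[logging_handler, grad_handler])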
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 6e2d66c..d1074c9 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -590,8 +590,9 @@ class TopKAccuracy(EvalMetric):
 
 
 class _BinaryClassificationMetrics(object):
-    """Private container class for classification metric statistics. True/false positive and
-     true/false negative counts are sufficient statistics for various classification metrics.
+    """Private container class for classification metric statistics.
+
+    True/false positive and true/false negative counts are sufficient
+    statistics for various classification metrics.
     This class provides the machinery to track those statistics across mini-batches of
     (label, prediction) pairs.
     """
@@ -1430,6 +1431,10 @@ class PearsonCorrelation(EvalMetric):
     label_names : list of str, or None
         Name of labels that should be used when updating with update_dict.
         By default include all labels.
+    average : str, default 'macro'
+        Strategy to be used for aggregating across mini-batches.
+            "macro": average the pearsonr scores for each batch.
+            "micro": compute a single pearsonr score across all batches.
 
     Examples
     --------
@@ -1438,13 +1443,46 @@ class PearsonCorrelation(EvalMetric):
     >>> pr = mx.metric.PearsonCorrelation()
     >>> pr.update(labels, predicts)
     >>> print pr.get()
-    ('pearson-correlation', 0.42163704544016178)
+    ('pearsonr', 0.42163704544016178)
     """
     def __init__(self, name='pearsonr',
-                 output_names=None, label_names=None):
+                 output_names=None, label_names=None, average='macro'):
+        self.average = average
         super(PearsonCorrelation, self).__init__(
             name, output_names=output_names, label_names=label_names,
             has_global_stats=True)
+        if self.average == 'micro':
+            self.reset_micro()
+
+    def reset_micro(self):
+        self._sse_p = 0
+        self._mean_p = 0
+        self._sse_l = 0
+        self._mean_l = 0
+        self._pred_nums = 0
+        self._label_nums = 0
+        self._conv = 0
+
+    def reset(self):
+        self.num_inst = 0
+        self.sum_metric = 0.0
+        self.global_num_inst = 0
+        self.global_sum_metric = 0.0
+        if self.average == 'micro':
+            self.reset_micro()
+
+    def update_variance(self, new_values, *aggregate):
+        # Welford's online algorithm for variance update
+        count, mean, m_2 = aggregate
+        count += len(new_values)
+        delta = new_values - mean
+        mean += numpy.sum(delta / count)
+        delta_2 = new_values - mean
+        m_2 += numpy.sum(delta * delta_2)
+        return count, mean, m_2
+
+    def update_cov(self, label, pred):
+        self._conv = self._conv + numpy.sum((label - self._mean_l) * (pred - self._mean_p))
 
     def update(self, labels, preds):
         """Updates the internal evaluation result.
@@ -1457,17 +1495,34 @@ class PearsonCorrelation(EvalMetric):
             Predicted values.
         """
         labels, preds = check_label_shapes(labels, preds, True)
-
         for label, pred in zip(labels, preds):
             check_label_shapes(label, pred, False, True)
-            label = label.asnumpy()
-            pred = pred.asnumpy()
-            pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1]
-            self.sum_metric += pearson_corr
-            self.global_sum_metric += pearson_corr
-            self.num_inst += 1
-            self.global_num_inst += 1
+            label = label.asnumpy().ravel().astype(numpy.float64)
+            pred = pred.asnumpy().ravel().astype(numpy.float64)
+            if self.average == 'macro':
+                pearson_corr = numpy.corrcoef(pred, label)[0, 1]
+                self.sum_metric += pearson_corr
+                self.global_sum_metric += pearson_corr
+                self.num_inst += 1
+                self.global_num_inst += 1
+            else:
+                self.global_num_inst += 1
+                self.num_inst += 1
+                self._label_nums, self._mean_l, self._sse_l = \
+                    self.update_variance(label, self._label_nums, self._mean_l, self._sse_l)
+                self.update_cov(label, pred)
+                self._pred_nums, self._mean_p, self._sse_p = \
+                    self.update_variance(pred, self._pred_nums, self._mean_p, self._sse_p)
 
+    def get(self):
+        if self.num_inst == 0:
+            return (self.name, float('nan'))
+        if self.average == 'macro':
+            return (self.name, self.sum_metric / self.num_inst)
+        else:
+            n = self._label_nums
+            pearsonr = self._conv / ((n-1) * numpy.sqrt(self._sse_p / (n - 1)) * numpy.sqrt(self._sse_l / (n - 1)))
+            return (self.name, pearsonr)
 
 @register
 class PCC(EvalMetric):
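
A quick NumPy sketch of the two averaging strategies (made-up arrays; the streaming
"micro" path above reproduces the pooled correlation via the Welford-style updates):

    import numpy as np

    batches = [(np.array([0., 1., 1.]), np.array([0.3, 0.9, 0.6])),
               (np.array([1., 0., 1.]), np.array([0.8, 0.1, 0.7]))]

    # "macro": average the per-batch pearsonr scores
    macro = np.mean([np.corrcoef(p, l)[0, 1] for l, p in batches])

    # "micro": a single pearsonr score over all samples pooled together
    all_l = np.concatenate([l for l, _ in batches])
    all_p = np.concatenate([p for _, p in batches])
    micro = np.corrcoef(all_p, all_l)[0, 1]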
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index a7ad8e6..bc1bbe4 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -905,6 +905,17 @@ fixed-size items.
         return flat_begin, flat_end + 1
     # pylint: enable=invalid-name
 
+    @staticmethod
+    def _drop_int_axes(indexed_shape, int_axes):
+        """drop the axis of indexed_shape corresponding to int axes"""
+        bcast_shape = []
+        for i, size in enumerate(indexed_shape):
+            if i not in int_axes:
+                bcast_shape.append(size)
+        if not bcast_shape:
+            bcast_shape = [1]
+        return tuple(bcast_shape)
+
     def _set_nd_basic_indexing(self, key, value):
         """This function indexes ``self`` with a tuple of ``slice`` objects only."""
         for idx in key:
@@ -946,14 +957,10 @@ fixed-size items.
             if type(value) == self.__class__:  # pylint: disable=unidiomatic-typecheck
                 if value.handle is not self.handle:
                     # Need to do this before `broadcast_to`.
-                    tmp_shape = _shape_for_bcast(
-                        value.shape, target_ndim=self.ndim, new_axes=int_axes
-                    )
-                    value = value.reshape(tmp_shape)
-
-                    if value.shape != self.shape:
-                        value = value.broadcast_to(self.shape)
-                    value.copyto(self)
+                    bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
+                    value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
+                    value_nd = value_nd.reshape(indexed_shape)
+                    value_nd.copyto(self)
 
             elif isinstance(value, numeric_types):
                 self._full(value)
@@ -969,9 +976,10 @@ fixed-size items.
 
             else:
                 # Other array-like
-                value_nd = self._prepare_value_nd(
-                    value, bcast_shape=self.shape
-                )
+                # drop the axis of indexed_shape corresponding to int axes
+                bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
+                value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
+                value_nd = value_nd.reshape(indexed_shape)
                 value_nd.copyto(self)
 
         elif isinstance(value, numeric_types):
@@ -979,16 +987,8 @@ fixed-size items.
 
         else:
             # drop the axis of indexed_shape corresponding to int axes
-            bcast_shape = []
-            for i, size in enumerate(indexed_shape):
-                if i not in int_axes:
-                    bcast_shape.append(size)
-            if bcast_shape == []:
-                bcast_shape = [1]
-            bcast_shape = tuple(bcast_shape)
-            value_nd = self._prepare_value_nd(
-                value, bcast_shape=bcast_shape, squeeze_axes=new_axes
-            )
+            bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
+            value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
             value_nd = value_nd.reshape(indexed_shape)
             self.slice_assign(value_nd, begin, end, step)
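
To illustrate the `_drop_int_axes` helper above: axes indexed with integers disappear
from the shape that `value` is broadcast to (a standalone re-statement for readability,
not a new API):

    def drop_int_axes(indexed_shape, int_axes):
        # keep only the sizes of axes that were not integer-indexed
        shape = [s for i, s in enumerate(indexed_shape) if i not in int_axes]
        return tuple(shape) if shape else (1,)

    assert drop_int_axes((1, 3, 1), {0, 2}) == (3,)   # e.g. dst[0, :3, 0] = [1, 2, 3]
    assert drop_int_axes((5,), {0}) == (1,)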
 
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 9375bed..ba3c334 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -243,8 +243,14 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const {
 
 NDArray NDArray::Reshape(const mxnet::TShape &shape) const {
   CHECK(!is_none()) << "NDArray is not initialized";
-  CHECK_GE(shape_.Size(), shape.Size())
-    << "NDArray.Reshape: target shape size is larger current shape";
+  if (Imperative::Get()->is_np_shape()) {
+    CHECK_EQ(shape_.Size(), shape.Size())
+        << "NDArray.Reshape: target shape must have the same size as "
+        << "current shape.";
+  } else {
+    CHECK_GE(shape_.Size(), shape.Size())
+        << "NDArray.Reshape: target shape size is larger than the current shape";
+  }
   NDArray ret = this->Detach();
   // If the shape doesn't change, we can just return it now.
   if (ret.shape_ == shape)
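
The user-visible difference, sketched from the Python side (assuming NumPy semantics
are enabled; the exact error text may differ):

    from mxnet import np, npx
    npx.set_np()

    a = np.zeros((2, 3))
    b = a.reshape(3, 2)   # accepted: both shapes hold 6 elements (CHECK_EQ)
    # a.reshape(2, 2) would now raise, whereas the legacy non-np path only
    # requires the target size to be no larger than the current one (CHECK_GE)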
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index a9828f4..42cc99a 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -119,16 +119,22 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
   const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
+  if (req[0] == kNullOp) return;
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+      << "Transpose only supports kWriteTo, kNullOp and kAddTo";
+  mxnet::TShape axes;
   if (ndim_is_known(param.axes)) {
-    mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+    axes = common::CanonicalizeAxes(param.axes);
   } else {
-    mxnet::TShape axes(inputs[0].ndim(), -1);
+    axes = mxnet::TShape(inputs[0].ndim(), -1);
     for (int i = 0; i < axes.ndim(); ++i) {
       axes[i] = axes.ndim() - 1 - i;
     }
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  }
+  if (req[0] == kAddTo) {
+    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  } else {
+    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
   }
 }
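
In NumPy terms, the three write requests now handled correspond to the following
(illustrative only; kNullOp/kWriteTo/kAddTo are engine-level request types, not
user-facing API):

    import numpy as np

    x = np.arange(6).reshape(2, 3)
    out = np.ones((3, 2))

    out_write = x.transpose(1, 0).copy()   # kWriteTo: out = transpose(x)
    out_addto = out + x.transpose(1, 0)    # kAddTo:   out += transpose(x)
    # kNullOp: leave out untouched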
 
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 3967cde..41d8d02 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -24,6 +24,7 @@
  */
 
 #include <vector>
+#include <set>
 #include "./np_matrix_op-inl.h"
 #include "../nn/concat-inl.h"
 
@@ -65,8 +66,13 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape ret(ndim, -1);
 
   if (ndim_is_known(param.axes)) {
-    CHECK_EQ(ndim, param.axes.ndim());
+    CHECK_EQ(ndim, param.axes.ndim())
+        << "The number of axes does not match the dimension of the tensor. axes = "
+        << param.axes << ", input tensor shape = " << shp;
     mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
+    std::set<dim_t> axes_set(axes.begin(), axes.end());
+    CHECK_EQ(axes_set.size(), axes.ndim()) << "Repeated axis in transpose. param.axes = "
+                                                << param.axes;
     if (ndim_is_known(shp)) {
       for (int i = 0; i < ndim; ++i) {
         ret[i] = shp[axes[i]];
@@ -115,9 +121,9 @@ NNVM_REGISTER_OP(_np_transpose)
       }
       std::ostringstream os;
       os << axes;
-      return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}});
+      return MakeNonlossGradNode("_np_transpose", n, ograds, {}, {{"axes", os.str()}});
     } else {
-      return MakeNonlossGradNode("transpose", n, ograds, {},
+      return MakeNonlossGradNode("_np_transpose", n, ograds, {},
                                  std::unordered_map<std::string, std::string>());
     }
   })
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 0fee2a2..4bd059a 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -269,8 +269,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
  * \param out output tensor
  * \param row shape of dim 0 of input
  * \param col shape of dim 1 of input
+ * \tparam DType Data type
+ * \tparam is_addto
  */
-template<typename DType>
+template<typename DType, bool is_addto>
 MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
   // ensure cache line hits and prevent cache miss for any configuration
   // L1 cache size to be utilized = 32kb = 2^15
@@ -282,7 +284,7 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
   // Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled
   // blocksize * blocksize * num_threads = cache_size / dtype_size
   // Instead of explicit unroll, let compiler figure out optimal unroll factor
-  index_t blocksize = 32;
+  const index_t blocksize = 32;
 
   // collapse 2 parallelizes 2 for loops
   // inner 2 for loops aren't parallelized to prevent cache miss
@@ -299,14 +301,25 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
       // transpose the block
       for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
         for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
-          out[a * row + b] = in[b * col + a];
+          if (!is_addto) {
+            out[a * row + b] = in[b * col + a];
+          } else {
+            out[a * row + b] += in[b * col + a];
+          }
         }
       }
     }
   }
 }
 
-template<typename xpu>
+inline bool IsIdentityTranspose(const TShape& axes) {
+  for (dim_t i = 0; i < axes.ndim(); i++) {
+    if (axes[i] != i) return false;
+  }
+  return true;
+}
+
+template<typename xpu, bool is_addto = false>
 void TransposeImpl(RunContext ctx,
                    const TBlob& src,
                    const TBlob& ret,
@@ -323,62 +336,79 @@ void TransposeImpl(RunContext ctx,
   // Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3).
   if (isPseudo2DTranspose(axes)) {
     MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
-      transpose_pseudo2D<DType>(ret, src, axes, s);
+      transpose_pseudo2D<DType, is_addto>(ret, src, axes, s);
     });
     return;
   }
 #endif
+  // Special handle the identity case
+  if (IsIdentityTranspose(axes)) {
+    MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
+      Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(src.Size()), s);
+      Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(ret.Size()), s);
+      if (!is_addto) {
+        // Use memcpy to accelerate the speed
+        Copy(out, in, s);
+      } else {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo>, xpu>::Launch(
+            s, ret.Size(), out.dptr_, in.dptr_);
+      }
+    });
+    return;
+  }
+  // Handle the general transpose case
   MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
     switch (axes.ndim()) {
-     case 0: {
-      Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
-      Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
-      Copy(out, in, s);
-      break;
-     }
-     case 1: {
-      Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s);
-      Tensor<xpu, 1, DType> out = ret.get<xpu, 1, DType>(s);
-      Copy(out, in, s);
-      break;
-     }
      case 2: {
-      mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
-      mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);
-
-      if (axes[0] == 1 && axes[1] == 0) {
-        if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
-          Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
-        } else {
-          out = in.T();
-        }
+      Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s);
+      Tensor<xpu, 2, DType> out = ret.get<xpu, 2, DType>(s);
+      if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
+        Transpose2D<DType, is_addto>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
       } else {
-        Copy(out, in, s);
+        LOG(FATAL) << "Not Implemented. We should never reach here because the 2D case "
+                      "in GPU has been covered by transpose_pseudo2D."
+                      " Report an issue in Github.";
       }
       break;
      }
      case 3: {
       Tensor<xpu, 3, DType> in = src.get<xpu, 3, DType>(s);
       Tensor<xpu, 3, DType> out = ret.get<xpu, 3, DType>(s);
-      out = transpose(in, axes.get<3>());
+      if (!is_addto) {
+        out = transpose(in, axes.get<3>());
+      } else {
+        out += transpose(in, axes.get<3>());
+      }
       break;
      }
      case 4: {
       Tensor<xpu, 4, DType> in = src.get<xpu, 4, DType>(s);
       Tensor<xpu, 4, DType> out = ret.get<xpu, 4, DType>(s);
-      out = transpose(in, axes.get<4>());
+      if (!is_addto) {
+        out = transpose(in, axes.get<4>());
+      } else {
+        out += transpose(in, axes.get<4>());
+      }
       break;
      }
      case 5: {
       Tensor<xpu, 5, DType> in = src.get<xpu, 5, DType>(s);
       Tensor<xpu, 5, DType> out = ret.get<xpu, 5, DType>(s);
-      out = transpose(in, axes.get<5>());
+      if (!is_addto) {
+        out = transpose(in, axes.get<5>());
+      } else {
+        out += transpose(in, axes.get<5>());
+      }
       break;
      }
      case 6: {
       Tensor<xpu, 6, DType> in = src.get<xpu, 6, DType>(s);
       Tensor<xpu, 6, DType> out = ret.get<xpu, 6, DType>(s);
-      out = transpose(in, axes.get<6>());
+      if (!is_addto) {
+        out = transpose(in, axes.get<6>());
+      } else {
+        out += transpose(in, axes.get<6>());
+      }
       break;
      }
      default:
@@ -399,15 +429,21 @@ void Transpose(const nnvm::NodeAttrs& attrs,
     return;
   }
   const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+       << "Transpose only supports kNullOp, kWriteTo and kAddTo";
+  mxnet::TShape axes;
   if (param.axes.ndim() == 0) {
-    mxnet::TShape axes(inputs[0].ndim(), -1);
+    axes = mxnet::TShape(inputs[0].ndim(), -1);
     for (int i = 0; i < axes.ndim(); ++i) {
       axes[i] = axes.ndim() - 1 - i;
     }
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
   } else {
-    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], param.axes);
+    axes = common::CanonicalizeAxes(param.axes);
+  }
+  if (req[0] == kAddTo) {
+    TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  } else {
+    TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
   }
 }
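
The blocking scheme in Transpose2D, restated as a short Python sketch (a readability
aid under the same 32x32 tiling assumption, not the shipped kernel):

    import numpy as np

    def transpose2d_blocked(inp, blocksize=32, addto=False, out=None):
        row, col = inp.shape
        if out is None:
            out = np.zeros((col, row), dtype=inp.dtype)
        for i in range(0, row, blocksize):        # tile over rows
            for j in range(0, col, blocksize):    # tile over columns
                blk = inp[i:i + blocksize, j:j + blocksize]
                if addto:
                    out[j:j + blocksize, i:i + blocksize] += blk.T
                else:
                    out[j:j + blocksize, i:i + blocksize] = blk.T
        return out

    x = np.random.rand(100, 70)
    assert np.allclose(transpose2d_blocked(x), x.T)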
 
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b09b332..1e69f72 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -283,11 +283,12 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
     return;
   }
   const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+  CHECK(req[0] == kWriteTo || req[0] == kAddTo) <<
+      "Transpose only supports kNullOp, kWriteTo and kAddTo";
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
 
-  if (SupportMKLDNNTranspose(param, inputs[0])) {
+  if (SupportMKLDNNTranspose(param, inputs[0]) && req[0] == kWriteTo) {
     MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   }
diff --git a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
index 5b7cf04..b3ca9fb 100644
--- a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
+++ b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
@@ -39,22 +39,31 @@ namespace mxnet {
 namespace op {
 namespace cuda {
 
-
-template <typename DType, typename CType>
+/*!
+ * \brief The `transpose_pseudo2D` based on chosen vectorized types. It transposes an array of
+ *    shape (k, m, n) to (k, n, m)
+ * \param out Pointer to output memory.
+ * \param inp Pointer to input memory.
+ * \param m First of tensor dimensions.
+ * \param n Second of tensor dimensions.
+ * \param nIterY The number of iterations in the y-dim of the thread to cover all rows. (1-->m)
+ * \param nIterZ The number of iterations in the z-dim of the thread to cover all outer slices. (1-->k)
+ * \tparam DType Data type
+ * \tparam CType The type to load the data.
+ * \tparam is_addto Whether to perform out += transpose(data) or out = transpose(data)
+ */
+template <typename DType, typename CType, bool is_addto>
 __global__ void transpose_pseudo2D(DType* out, DType* inp,
                                    const index_t m, const index_t n,
                                    const index_t nIterY, const index_t nIterZ) {
-  const index_t TSR = sizeof(CType)/sizeof(DType);  // TypeSizeRatio
+  // Calculate the TypeSizeRatio
+  const index_t TSR = sizeof(CType) / sizeof(DType) > 0 ? sizeof(CType) / sizeof(DType) : 1;
   const index_t chunked_n = n/TSR;
   const index_t chunked_m = m/TSR;
 
-  union transp_t {
-    CType valChunk;
-    DType values[TSR];
-  };
-
-  __shared__ DType d_shm[1024*TSR*TSR];
-  CType* c_shm = reinterpret_cast<CType*>(d_shm);
+  extern __shared__ char buf[];
+  DType* d_shm = reinterpret_cast<DType*>(buf);
+  CType* c_shm = reinterpret_cast<CType*>(buf);
 
   CType* cInp = reinterpret_cast<CType*>(inp);
   CType* cOut = reinterpret_cast<CType*>(out);
@@ -78,23 +87,34 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
         }
         __syncthreads();
 
-        // read from shared to registers
-        transp_t tmp[TSR];
+        // read from shared to local registers
+        CType tmp[TSR];
         #pragma unroll
         for (index_t i = 0; i < TSR; i++) {
+          DType* tmp_dptr = reinterpret_cast<DType*>(&tmp[i]);
           #pragma unroll
           for (int j = 0; j < TSR; j++) {
             index_t shmIdx = (TSR*threadIdx.y + j)*blockDim.x*TSR + TSR*threadIdx.x + i;
-            tmp[i].values[j] = d_shm[shmIdx];
+            tmp_dptr[j] = d_shm[shmIdx];
           }
         }
         __syncthreads();
 
         // write back to global output
-        offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m + blockIdx_y*blockDim.y;
+        offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m
+                                        + blockIdx_y*blockDim.y;
         #pragma unroll
         for (index_t i = 0; i < TSR; i++) {
-            cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i].valChunk;
+          if (is_addto) {
+            DType* tmp_dptr = reinterpret_cast<DType*>(&tmp[i]);
+            #pragma unroll
+            for (int j = 0; j < TSR; j++) {
+              out[TSR * (offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y) + j]
+                += tmp_dptr[j];
+            }
+          } else {
+            cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i];
+          }
         }
       }
     }
@@ -107,7 +127,6 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
 /*!
  * \brief Calls proper version of kernel `transpose_pseudo2D`
  *        basing on chosen type sizes.
- * \param dTypeSize Size of data type.
  * \param cTypeSize Size of type that should be use to copy.
  * \param grid Grid dimensions for the kernel.
  * \param block Block dimensions for the kernel.
@@ -116,92 +135,39 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
  * \param inp Pointer to input memory.
  * \param m First of tensor dimensions.
  * \param n Second of tensor dimensions.
+ * \tparam DType Data type
+ * \tparam is_addto Whether to add the transpose result to the output tensor.
  */
-inline void call_transpose_pseudo2D(index_t dTypeSize, index_t cTypeSize,
-                                   dim3 grid, dim3 block, cudaStream_t stream,
-                                   void* out, void* inp, const index_t m, const index_t n,
-                                   const index_t nIterY, const index_t nIterZ) {
-  switch (dTypeSize) {
-   case (1): {
-    uint8_t* d_outPtr = reinterpret_cast<uint8_t*>(out);
-    uint8_t* d_inpPtr = reinterpret_cast<uint8_t*>(inp);
-    switch (cTypeSize) {
-     case (1):
-      cuda::transpose_pseudo2D<uint8_t, uint8_t><<<grid, block, 0, stream>>>
+template <typename DType, bool is_addto>
+inline void call_transpose_pseudo2D(index_t cTypeSize,
+                                    dim3 grid, dim3 block, cudaStream_t stream,
+                                    DType* d_outPtr, DType* d_inpPtr,
+                                    const index_t m, const index_t n,
+                                    const index_t nIterY, const index_t nIterZ) {
+  const int nshared = 1024 * cTypeSize / sizeof(DType) * cTypeSize;
+  switch (cTypeSize) {
+    case (1):
+      cuda::transpose_pseudo2D<DType, uint8_t, is_addto><<<grid, block, nshared, stream>>>
                               (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-     case (2):
-      cuda::transpose_pseudo2D<uint8_t, uint16_t><<<grid, block, 0, stream>>>
+    case (2):
+      cuda::transpose_pseudo2D<DType, uint16_t, is_addto><<<grid, block, nshared, stream>>>
                               (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-     case (4):
-      cuda::transpose_pseudo2D<uint8_t, uint32_t><<<grid, block, 0, stream>>>
+    case (4):
+      cuda::transpose_pseudo2D<DType, uint32_t, is_addto><<<grid, block, nshared, stream>>>
                               (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-     case (8):
-      // case guarded against in function getBestCopyTypeSize
-      LOG(FATAL) << "cuda::transpose_pseudo2D<uint8_t, uint64_t> would take too much shared memory";
-     default:
-      LOG(FATAL) << "Unsupported type combination";
-    }
-    break;
-   }
-   case (2): {
-    uint16_t* d_outPtr = reinterpret_cast<uint16_t*>(out);
-    uint16_t* d_inpPtr = reinterpret_cast<uint16_t*>(inp);
-    switch (cTypeSize) {
-     case (2):
-      cuda::transpose_pseudo2D<uint16_t, uint16_t><<<grid, block, 0, stream>>>
+    case (8):
+      cuda::transpose_pseudo2D<DType, uint64_t, is_addto><<<grid, block, nshared, stream>>>
                               (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
       break;
-     case (4):
-      cuda::transpose_pseudo2D<uint16_t, uint32_t><<<grid, block, 0, stream>>>
-                              (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-      break;
-     case (8):
-      cuda::transpose_pseudo2D<uint16_t, uint64_t><<<grid, block, 0, stream>>>
-                              (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-      break;
-     default:
-      LOG(FATAL) << "Unsupported type combination";
-    }
-    break;
-   }
-   case (4): {
-    uint32_t* d_outPtr = reinterpret_cast<uint32_t*>(out);
-    uint32_t* d_inpPtr = reinterpret_cast<uint32_t*>(inp);
-    switch (cTypeSize) {
-     case (4):
-      cuda::transpose_pseudo2D<uint32_t, uint32_t><<<grid, block, 0, stream>>>
-                              (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-      break;
-     case (8):
-      cuda::transpose_pseudo2D<uint32_t, uint64_t><<<grid, block, 0, stream>>>
-                              (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-      break;
-     default:
-      LOG(FATAL) << "Unsupported type combination";
-    }
-    break;
-   }
-   case (8): {
-    uint64_t* d_outPtr = reinterpret_cast<uint64_t*>(out);
-    uint64_t* d_inpPtr = reinterpret_cast<uint64_t*>(inp);
-    switch (cTypeSize) {
-     case (8):
-      cuda::transpose_pseudo2D<uint64_t, uint64_t><<<grid, block, 0, stream>>>
-                              (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
-      break;
-     default:
-      LOG(FATAL) << "Unsupported type combination";
-    }
-    break;
-   }
-   default:
-    LOG(FATAL) << "Unsupported type combination";
+    default:
+      LOG(FATAL) << "Unsupported type combination. " << "Copy type size = " << cTypeSize;
   }
   auto cuErr = cudaPeekAtLastError();
-  CHECK_EQ(cuErr, cudaSuccess) << "Transpose kernel failure: " << cudaGetErrorString(cuErr) << ". "
+  CHECK_EQ(cuErr, cudaSuccess) << "TransposePseudo2D kernel failure: "
+                               << cudaGetErrorString(cuErr) << ". "
                                << "block: (" << block.x << "," << block.y << "," << block.z << ")"
                                << " grid: (" << grid.x << "," << grid.y << "," << grid.z << ")";
 }
@@ -225,7 +191,6 @@ inline bool isPseudo2DTranspose(const TShape& params) {
   return n_swpDims == 2;
 }
 
-
 struct pseudo2DSizes {
   index_t leadDimS;
   index_t M;
@@ -306,15 +271,14 @@ inline std::pair<dim3, dim3> calculateKernelParams(pseudo2DSizes sizes, const in
  * \param outBlob Tensor blob to store result.
  * \param inpBlob Tensor blob with input data.
  * \param params Parameters (axes) of the transpose.
+ * \tparam is_addto Whether to add the transpose result to the outBlob.
  * \param s Pointer to GPU stream.
  */
-template <typename DType, typename gpu>
+template <typename DType, bool is_addto>
 void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob,
                         const TShape& params, mshadow::Stream<gpu>* s) {
   const TShape& shape = inpBlob.shape_;
   CHECK_EQ(shape.ndim(), params.ndim());
-  auto ndim = params.ndim();
-
   auto sizes = getPackedTransposeDimensions(shape, params);
 
   index_t cTypeSize = getBestCopyTypeSize(sizeof(DType), sizes.M, sizes.N);
@@ -337,8 +301,10 @@ void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob,
   }
 
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
-  call_transpose_pseudo2D(sizeof(DType), cTypeSize, grid, block, stream,
-                          outBlob.dptr_, inpBlob.dptr_, sizes.M, sizes.N, nIterY, nIterZ);
+  call_transpose_pseudo2D<DType, is_addto>
+      (cTypeSize, grid, block, stream,
+       outBlob.dptr<DType>(), inpBlob.dptr<DType>(),
+       sizes.M, sizes.N, nIterY, nIterZ);
 }
 
 }  // namespace op
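
The property isPseudo2DTranspose exploits, in NumPy terms: a permutation that swaps
two contiguous groups of trailing dimensions is a batched 2-D transpose over packed
shapes (k, m, n) -> (k, n, m) (a sketch of the idea, not the CUDA code path):

    import numpy as np

    x = np.random.rand(2, 3, 4, 5)
    # axes (0, 2, 3, 1) packs into k=2, m=3, n=4*5=20
    y = x.transpose(0, 2, 3, 1)
    z = x.reshape(2, 3, 20).transpose(0, 2, 1).reshape(2, 4, 5, 3)
    assert np.allclose(y, z)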
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index cf913a6..21f949a 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -367,6 +367,7 @@ def test_default_handlers():
     val_metrics = est.val_metrics
     early_stopping = EarlyStoppingHandler(monitor=val_metrics[0])
     handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping])
-    assert len(handlers) == 4
-    assert isinstance(handlers[0], MetricHandler)
-    assert isinstance(handlers[3], LoggingHandler)
+    assert len(handlers) == 5
+    assert isinstance(handlers[0], GradientUpdateHandler)
+    assert isinstance(handlers[1], MetricHandler)
+    assert isinstance(handlers[4], LoggingHandler)
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index 17c7581..658fb88 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -17,13 +17,19 @@
 
 import os
 import logging
+import sys
+import re
 
 import mxnet as mx
 from common import TemporaryDirectory
 from mxnet import nd
 from mxnet.gluon import nn, loss
 from mxnet.gluon.contrib.estimator import estimator, event_handler
-
+from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler
+try:
+    from StringIO import StringIO
+except ImportError:
+    from io import StringIO
 
 def _get_test_network(net=nn.Sequential()):
     net.add(nn.Dense(128, activation='relu', flatten=False),
@@ -32,9 +38,9 @@ def _get_test_network(net=nn.Sequential()):
     return net
 
 
-def _get_test_data():
-    data = nd.ones((32, 100))
-    label = nd.zeros((32, 1))
+def _get_test_data(in_size=32):
+    data = nd.ones((in_size, 100))
+    label = nd.zeros((in_size, 1))
     data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
     return mx.gluon.data.DataLoader(data_arr, batch_size=8)
 
@@ -200,3 +206,61 @@ def test_custom_handler():
     est.fit(test_data, event_handlers=[custom_handler], epochs=10)
     assert custom_handler.num_batch == 5 * 4
     assert custom_handler.num_epoch == 5
+
+def test_logging_interval():
+    ''' test different options for logging handler '''
+    # test case #1: log interval is 1
+    batch_size = 8
+    data_size = 100
+    old_stdout = sys.stdout
+    sys.stdout = mystdout = StringIO()
+    log_interval = 1
+    net = _get_test_network()
+    dataloader = _get_test_data(in_size=data_size)
+    num_epochs = 1
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+    est = estimator.Estimator(net=net,
+                              loss=ce_loss,
+                              metrics=acc)
+
+    est.fit(train_data=dataloader,
+            epochs=num_epochs,
+            event_handlers=[logging])
+
+    sys.stdout = old_stdout
+    log_info_list = mystdout.getvalue().splitlines()
+    info_len = 0
+    for info in log_info_list:
+        match = re.match(
+            r'(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time/interval: \d+\.\d+s' +
+            r' training accuracy: \d+\.\d+)', info)
+        if match:
+            info_len += 1
+
+    assert(info_len == int(data_size/batch_size/log_interval) + 1)
+    # test case #2: log interval is 5
+    old_stdout = sys.stdout
+    sys.stdout = mystdout = StringIO()
+    acc = mx.metric.Accuracy()
+    log_interval = 5
+    logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+    est = estimator.Estimator(net=net,
+                              loss=ce_loss,
+                              metrics=acc)
+    est.fit(train_data=dataloader,
+            epochs=num_epochs,
+            event_handlers=[logging])
+    sys.stdout = old_stdout
+    log_info_list = mystdout.getvalue().splitlines()
+    info_len = 0
+    for info in log_info_list:
+        match = re.match(
+            r'(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time/interval: \d+\.\d+s' +
+            r' training accuracy: \d+\.\d+)', info)
+        if match:
+            info_len += 1
+
+    assert(info_len == int(data_size/batch_size/log_interval) + 1)
+
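
A note on the expected counts in this test: 100 samples at batch_size 8 give 13
batches (indices 0 through 12), and a line is logged whenever
batch_index % log_interval == 0, so log_interval=1 yields 13 matches and
log_interval=5 yields 3, i.e. int(data_size / batch_size / log_interval) + 1 in
both cases.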
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index 0ae8aea..a1e5128 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -17,6 +17,7 @@
 
 import mxnet as mx
 import numpy as np
+import scipy
 import json
 import math
 from common import with_seed
@@ -263,13 +264,40 @@ def test_perplexity():
     assert perplexity == perplexity_expected
 
 def test_pearsonr():
-    pred = mx.nd.array([[0.7, 0.3], [0.1, 0.9], [1., 0]])
-    label = mx.nd.array([[0, 1], [1, 0], [1, 0]])
-    pearsonr_expected = np.corrcoef(pred.asnumpy().ravel(), label.asnumpy().ravel())[0, 1]
-    metric = mx.metric.create('pearsonr')
-    metric.update([label], [pred])
-    _, pearsonr = metric.get()
-    assert pearsonr == pearsonr_expected
+    pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
+    label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
+    pearsonr_expected_np = np.corrcoef(pred1.asnumpy().ravel(), label1.asnumpy().ravel())[0, 1]
+    pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel())
+    macro_pr = mx.metric.create('pearsonr', average='macro')
+    micro_pr = mx.metric.create('pearsonr', average='micro')
+
+    assert np.isnan(macro_pr.get()[1])
+    assert np.isnan(micro_pr.get()[1])
+
+    macro_pr.update([label1], [pred1])
+    micro_pr.update([label1], [pred1])
+
+    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
+    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
+    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
+    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)
+
+    pred2 = mx.nd.array([[1, 2], [3, 2], [4, 6]])
+    label2 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
+    # Note that pred12 is the concatenation of pred1 and pred2, and label12 of label1 and label2
+    pred12 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6],[1, 2], [3, 2], [4, 6]])
+    label12 = mx.nd.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]])
+
+    pearsonr_expected_np = np.corrcoef(pred12.asnumpy().ravel(), label12.asnumpy().ravel())[0, 1]
+    pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred12.asnumpy().ravel(), label12.asnumpy().ravel())
+
+    macro_pr.reset()
+    micro_pr.update([label2], [pred2])
+    macro_pr.update([label12], [pred12])
+    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
+    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
+    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
+    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)
 
 def cm_batch(cm):
     # generate a batch yielding a given confusion matrix
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 4c6d9f7..d097799 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -154,6 +154,22 @@ def test_ndarray_setitem():
         assert x.shape == trivial_shape
         assert same(x.asnumpy(), x_np)
 
+    # test https://github.com/apache/incubator-mxnet/issues/16647
+    dst = mx.nd.zeros((1, 3, 1))  # destination array
+    src = [1, 2, 3]
+    dst[0, :len(src), 0] = src
+    assert same(dst.asnumpy(), np.array([1, 2, 3], dtype=dst.dtype).reshape(dst.shape))
+
+    dst = mx.nd.zeros((1, 3, 1))  # destination array
+    src = [1, 2, 3]
+    dst[0, :len(src), 0] = mx.nd.array(src)
+    assert same(dst.asnumpy(), np.array([1, 2, 3], dtype=dst.dtype).reshape(dst.shape))
+
+    dst = mx.nd.zeros((1, 3, 1))  # destination array
+    src = [1, 2]
+    dst[0, :len(src), 0] = src
+    assert same(dst.asnumpy(), np.array([1, 2, 0], dtype=dst.dtype).reshape(dst.shape))
+
 
 @with_seed()
 def test_ndarray_elementwise():
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 9f4e62c..0bd620e 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -696,26 +696,37 @@ def test_np_ndarray_indexing():
         np_indexed_array = _np.random.randint(low=-10000, high=0, size=indexed_array_shape)
         # test value is a native numpy array without broadcast
         assert_same(np_array, np_index, mx_array, index, np_indexed_array)
+        # test value is a list without broadcast
+        assert_same(np_array, np_index, mx_array, index, np_indexed_array.tolist())
         # test value is a mxnet numpy array without broadcast
         assert_same(np_array, np_index, mx_array, index, np.array(np_indexed_array))
         # test value is an numeric_type
         assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0))
-        if len(indexed_array_shape) > 1:
-            np_value = _np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],))
-            # test mxnet ndarray with broadcast
-            assert_same(np_array, np_index, mx_array, index, np.array(np_value))
-            # test native numpy array with broadcast
-            assert_same(np_array, np_index, mx_array, index, np_value)
-
-            # test value shape are expanded to be longer than index array's shape
-            # this is currently only supported in basic indexing
-            if _is_basic_index(index):
-                expanded_value_shape = (1, 1, 1) + np_value.shape
-                assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape)))
-                assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape))
-            # test list with broadcast
-            assert_same(np_array, np_index, mx_array, index,
-                        [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1])
+
+        np_value = _np.random.randint(low=-10000, high=0,
+                                      size=(indexed_array_shape[-1],) if len(indexed_array_shape) > 0 else ())
+        # test mxnet ndarray with broadcast
+        assert_same(np_array, np_index, mx_array, index, np.array(np_value))
+        # test native numpy array with broadcast
+        assert_same(np_array, np_index, mx_array, index, np_value)
+        # test python list with broadcast
+        assert_same(np_array, np_index, mx_array, index, np_value.tolist())
+
+        # test value shapes expanded to be longer than the indexed array's shape;
+        # this is currently only supported in basic indexing
+        if _is_basic_index(index):
+            expanded_value_shape = (1, 1) + np_value.shape
+            assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape)))
+            assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape))
+            if len(expanded_value_shape) <= np_array[index].ndim:
+                # NumPy does not allow value.ndim > np_array[index].ndim when value is a Python list;
+                # this may be a NumPy bug.
+                assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape).tolist())
+
+        # test list with broadcast
+        assert_same(np_array, np_index, mx_array, index,
+                    [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1] if len(indexed_array_shape) > 0
+                    else _np.random.randint(low=-10000, high=0))
 
     def test_getitem_autograd(np_array, index):
         """
@@ -905,6 +916,9 @@ def test_np_ndarray_indexing():
         range(4),
         range(3, 0, -1),
         (range(4,), [1]),
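+        # integer indices combined with full, bounded, and strided slices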
+        (1, 1, slice(None), 1),
+        (1, 1, slice(None, 3), 1),
+        (1, 1, slice(None, 8, 3), 1),
     ]
     for index in index_list:
         test_getitem(np_array, index)
@@ -925,8 +939,8 @@ def test_np_ndarray_indexing():
 
     # test zero-size tensors get and setitem
     shapes_indices = [
-                        ((0), [slice(None, None, None)]),
-                        ((3, 0), [2, (slice(None, None, None)), (slice(None, None, None), None)]),
+        ((0), [slice(None, None, None)]),
+        ((3, 0), [2, (slice(None, None, None)), (slice(None, None, None), None)]),
     ]
     for shape, indices in shapes_indices:
         np_array = _np.zeros(shape)
@@ -1198,11 +1212,14 @@ def test_np_ndarray_pickle():
     a = np.random.uniform(size=(4, 5))
     a_copy = a.copy()
     import pickle
-    with open("np_ndarray_pickle_test_file", 'wb') as f:
-        pickle.dump(a_copy, f)
-    with open("np_ndarray_pickle_test_file", 'rb') as f:
-        a_load = pickle.load(f)
-    same(a.asnumpy(), a_load.asnumpy())
+
+    with TemporaryDirectory() as work_dir:
+        fname = os.path.join(work_dir, 'np_ndarray_pickle_test_file')
+        with open(fname, 'wb') as f:
+            pickle.dump(a_copy, f)
+        with open(fname, 'rb') as f:
+            a_load = pickle.load(f)
+        same(a.asnumpy(), a_load.asnumpy())
 
 
 if __name__ == '__main__':
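
The pickle round-trip above now happens inside a scratch directory instead of
the current working directory, so the test no longer leaves a stray file
behind. The underlying pattern in isolation, stdlib-only:

    import os
    import pickle
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as work_dir:   # removed automatically on exit
        fname = os.path.join(work_dir, 'blob.pkl')
        with open(fname, 'wb') as f:
            pickle.dump({'answer': 42}, f)
        with open(fname, 'rb') as f:
            assert pickle.load(f) == {'answer': 42}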
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 5ec9944..1ff1b61 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1302,7 +1302,9 @@ def test_np_transpose():
         if axes is None or axes == ():
             return _np.transpose(ograd, axes)
         np_axes = _np.array(list(axes))
-        return _np.transpose(ograd, tuple(list(_np.argsort(np_axes))))
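+        # Build the inverse permutation by direct scatter: transpose_axes[axes[i]] = i.
+        # Unlike argsort of the raw axes, this stays correct when axes contains
+        # negative entries such as (-1, 1, 0).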
+        transpose_axes = _np.zeros_like(np_axes)
+        transpose_axes[np_axes] = _np.arange(len(np_axes))
+        return _np.transpose(ograd, tuple(list(transpose_axes)))
 
     class TestTranspose(HybridBlock):
         def __init__(self, axes=None):
@@ -1311,45 +1313,57 @@ def test_np_transpose():
 
         def hybrid_forward(self, F, a):
             return F.np.transpose(a, self.axes)
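+
+    # Each workload pairs a data shape with the axes permutations to exercise;
+    # None asks transpose to reverse all axes.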
+    test_workloads = [[(), [(), None]],
+                      [(2,), [(0,), None]],
+                      [(0, 2), [(0, 1), (1, 0)]],
+                      [(5, 10), [(0, 1), (1, 0), None]],
+                      [(8, 2, 3), [(2, 0, 1), (0, 2, 1), (0, 1, 2), (2, 1, 0), (-1, 1, 0), None]],
+                      [(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]],
+                      [(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]],
+                      [(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]],
+                      [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]]]
 
     for hybridize in [True, False]:
-        for dtype in [_np.int32, _np.float32]:
-            for ndim in range(7):
-                shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
-                axeses = [None]
-                if ndim == 0:
-                    axeses += [()]
-                else:
-                    axes = [i for i in range(ndim)]
-                    axeses.append(tuple(axes))
-                    random.shuffle(axes)
-                    axeses.append(tuple(axes))
-                    axeses.append([i - len(axes) for i in axes])
-                for axes in axeses:
-                    test_trans = TestTranspose(axes)
-                    if hybridize:
-                        test_trans.hybridize()
-                    x = rand_ndarray(shape).as_np_ndarray()
-                    x = x.astype(dtype)
-                    x.attach_grad()
-                    np_out = _np.transpose(x.asnumpy(), axes)
-                    with mx.autograd.record():
-                        mx_out = test_trans(x)
-                    assert mx_out.shape == np_out.shape
-                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
-                    mx_out.backward()
-                    np_backward = np_transpose_grad(np_out.shape, dtype, axes)
-                    assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
-
-                    mx_out = x.transpose(axes)
-                    np_out = x.asnumpy().transpose(axes)
-                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+        for dtype in [_np.float32, _np.float16, _np.int32]:
+            for data_shape, axes_workload in test_workloads:
+                for axes in axes_workload:
+                    for grad_req in ['write', 'add']:
+                        test_trans = TestTranspose(axes)
+                        if hybridize:
+                            test_trans.hybridize()
+                        x = np.random.normal(0, 1, data_shape).astype(dtype)
+                        x.attach_grad(grad_req=grad_req)
+                        if grad_req == 'add':
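+                            # Pre-fill the gradient buffer so the kAddTo path must
+                            # accumulate into existing values rather than overwrite them.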
+                            x.grad[()] = np.random.normal(0, 1, x.grad.shape).astype(x.grad.dtype)
+                            x_grad_np = x.grad.asnumpy()
+                        np_out = _np.transpose(x.asnumpy(), axes)
+                        with mx.autograd.record():
+                            mx_out = test_trans(x)
+                        assert mx_out.shape == np_out.shape
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+                        mx_out.backward()
+                        np_backward = np_transpose_grad(np_out.shape, dtype, axes)
+                        if grad_req == 'add':
+                            assert_almost_equal(x.grad.asnumpy(), np_backward + x_grad_np,
+                                                rtol=1e-3, atol=1e-5, use_broadcast=False)
+                        else:
+                            assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
-                    if isinstance(axes, (list, tuple)):
-                        mx_out = x.transpose(*axes)
-                        np_out = x.asnumpy().transpose(*axes)
+                        mx_out = x.transpose(axes)
+                        np_out = x.asnumpy().transpose(axes)
                         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
 
+                        if isinstance(axes, (list, tuple)):
+                            mx_out = x.transpose(*axes)
+                            np_out = x.asnumpy().transpose(*axes)
+                            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+    # Error cases: repeated axes and out-of-range axes must raise MXNetError
+    dat = np.random.normal(0, 1, (3, 4, 5), dtype=np.float32)
+    assert_raises(MXNetError, lambda: dat.transpose((0, 0, 1)))
+    assert_raises(MXNetError, lambda: dat.transpose((0, 1, 3)))
+
 
 @with_seed()
 @use_np