You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2019/12/10 17:27:06 UTC
[incubator-mxnet] branch v1.6.x updated: Backport #16895, #16922,
#16878, #16979 and #16900 to 1.6 (#17029)
This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.6.x by this push:
new c973f01 Backport #16895, #16922, #16878, #16979 and #16900 to 1.6 (#17029)
c973f01 is described below
commit c973f01f95ad43dc3f93b322f00c7c8bf1648f0d
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Tue Dec 10 09:26:26 2019 -0800
Backport #16895, #16922, #16878, #16979 and #16900 to 1.6 (#17029)
* Fix ndarray indexing bug (#16895)
* Fix indexing bug
* More test cases
* Add test from 16647
* [Gluon] Update contrib.Estimator LoggingHandler to support logging per batch interval (#16922)
* Update LoggingHandler to support logging per interval
* Fix the constant variable issue in the logging handler
* Remove the constant variable hack in the logging handler.
* 1) replace LOG_PER_BATCH with LOG_PER_INTERVAL; 2) add test case
* Improve the test script for LoggingHandler
* Small fix to the test script
* logging handler test case bug fix
* remove parameter verbose from LoggingHandler
* move log_interval to the first argument
* resolve unittest mistakes
* Add micro averaging strategy to pearsonr metric (#16878)
Strategy to be used for aggregating across mini-batches.
"macro": average the pearsonr scores for each batch.
"micro": compute a single pearsonr score across all batches.
* [Bugfix] [Numpy] Add `kAddTo` and `kNullOp` to Transpose (#16979)
* update
Check for repeated axes
enable addto to transpose
fix
fix
fix
fix
remove unused ndim
Update pseudo2DTranspose_op-inl.cuh
Update pseudo2DTranspose_op-inl.cuh
Update pseudo2DTranspose_op-inl.cuh
fix
Update pseudo2DTranspose_op-inl.cuh
try to fix
Update pseudo2DTranspose_op-inl.cuh
Update pseudo2DTranspose_op-inl.cuh
Update pseudo2DTranspose_op-inl.cuh
fix
Update np_matrix_op.cc
Update test_numpy_op.py
update test case
fix implementation
fix bug
update
fix bug
Update pseudo2DTranspose_op-inl.cuh
fix
fix
Update test_numpy_op.py
* Fix bug
* fix docstring
* try to address comment
* no need to change this line
* Fix bug
* address comments
* address comment
* introduce gradient update handler to the base estimator (#16900)
* introduce gradient update handler to the base estimator
* Modify the gradient update handler to include the batch size
* Remove unrelated gradient update handler.
* Modify gradient update handler to take the current batch size.
* Remove white space to avoid the sanity check failure
* add small tweak to the handler code
* Modify the documentation of priority parameter of relevant handlers.
* Small modification to the documentation.
* Add a small modification to the documentation.
* Remove unnecessary list check
---
python/mxnet/gluon/contrib/estimator/estimator.py | 8 +-
.../mxnet/gluon/contrib/estimator/event_handler.py | 100 +++++++++----
python/mxnet/metric.py | 79 ++++++++--
python/mxnet/ndarray/ndarray.py | 42 +++---
src/ndarray/ndarray.cc | 10 +-
src/operator/numpy/np_matrix_op-inl.h | 16 ++-
src/operator/numpy/np_matrix_op.cc | 12 +-
src/operator/tensor/matrix_op-inl.h | 106 +++++++++-----
src/operator/tensor/matrix_op.cc | 5 +-
src/operator/tensor/pseudo2DTranspose_op-inl.cuh | 160 ++++++++-------------
tests/python/unittest/test_gluon_estimator.py | 7 +-
tests/python/unittest/test_gluon_event_handler.py | 72 +++++++++-
tests/python/unittest/test_metric.py | 42 +++++-
tests/python/unittest/test_ndarray.py | 16 +++
tests/python/unittest/test_numpy_ndarray.py | 63 +++++---
tests/python/unittest/test_numpy_op.py | 84 ++++++-----
16 files changed, 539 insertions(+), 283 deletions(-)
diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 54a0b16..ab7018f 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -24,7 +24,7 @@ import logging
import sys
import warnings
-from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .event_handler import _check_event_handlers
from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
@@ -307,8 +307,6 @@ class Estimator(object):
for l in loss:
l.backward()
- self.trainer.step(batch_size)
-
return data, label, pred, loss
def fit(self, train_data,
@@ -360,6 +358,7 @@ class Estimator(object):
self.max_epoch = epochs
self.max_batch = batches
+ self.batch_axis = batch_axis
# provide default handlers
event_handlers = self._prepare_default_handlers(val_data, event_handlers)
@@ -414,6 +413,9 @@ class Estimator(object):
# no need to add to default handler check as StoppingHandler does not use metrics
added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
+ if not any(isinstance(handler, GradientUpdateHandler) for handler in event_handlers):
+ added_default_handlers.append(GradientUpdateHandler())
+
if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))
diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py
index 3cdc407..6477760 100644
--- a/python/mxnet/gluon/contrib/estimator/event_handler.py
+++ b/python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -31,7 +31,7 @@ from .utils import _check_metrics
__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
- 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
+ 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler', 'GradientUpdateHandler']
class EventHandler(object):
@@ -130,13 +130,16 @@ class MetricHandler(EpochBegin, BatchEnd):
----------
train_metrics : List of EvalMetrics
Training metrics to be updated at batch end.
+ priority : scalar
+ Priority level of the MetricHandler. Priority level is sorted in ascending
+ order. The lower the number is, the higher priority level the handler is.
"""
- def __init__(self, train_metrics):
+ def __init__(self, train_metrics, priority=-1000):
self.train_metrics = _check_metrics(train_metrics)
# order to be called among all callbacks
# metrics need to be calculated before other callbacks can access them
- self.priority = -np.Inf
+ self.priority = priority
def epoch_begin(self, estimator, *args, **kwargs):
for metric in self.train_metrics:
@@ -176,6 +179,10 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
batch_period : int, default None
How often to run validation at batch end, by default
:py:class:`ValidationHandler` does not validate at batch end.
+ priority: scalar, default -1000
+ Priority level of the ValidationHandler. Priority level is sorted in
+ ascending order. The lower the number is, the higher priority level the
+ handler is.
"""
def __init__(self,
@@ -183,7 +190,8 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
eval_fn,
val_metrics=None,
epoch_period=1,
- batch_period=None):
+ batch_period=None,
+ priority=-1000):
self.val_data = val_data
self.eval_fn = eval_fn
self.epoch_period = epoch_period
@@ -193,7 +201,7 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
self.current_epoch = 0
# order to be called among all callbacks
# validation metrics need to be calculated before other callbacks can access them
- self.priority = -np.Inf
+ self.priority = priority
def train_begin(self, estimator, *args, **kwargs):
# reset epoch and batch counter
@@ -227,29 +235,27 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
Parameters
----------
- verbose : int, default LOG_PER_EPOCH
- Limit the granularity of metrics displayed during training process.
- verbose=LOG_PER_EPOCH: display metrics every epoch
- verbose=LOG_PER_BATCH: display metrics every batch
+ log_interval: int or str, default 'epoch'
+ Logging interval during training.
+ log_interval='epoch': display metrics every epoch
+ log_interval=integer k: display metrics every interval of k batches
train_metrics : list of EvalMetrics
Training metrics to be logged, logged at batch end, epoch end, train end.
val_metrics : list of EvalMetrics
Validation metrics to be logged, logged at epoch end, train end.
+ priority : scalar, default np.Inf
+ Priority level of the LoggingHandler. Priority level is sorted in
+ ascending order. The lower the number is, the higher priority level the
+ handler is.
"""
- LOG_PER_EPOCH = 1
- LOG_PER_BATCH = 2
-
- def __init__(self, verbose=LOG_PER_EPOCH,
+ def __init__(self, log_interval='epoch',
train_metrics=None,
- val_metrics=None):
+ val_metrics=None,
+ priority=np.Inf):
super(LoggingHandler, self).__init__()
- if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
- raise ValueError("verbose level must be either LOG_PER_EPOCH or "
- "LOG_PER_BATCH, received %s. "
- "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
- % verbose)
- self.verbose = verbose
+ if not isinstance(log_interval, int) and log_interval != 'epoch':
+ raise ValueError("log_interval must be either an integer or string 'epoch'")
self.train_metrics = _check_metrics(train_metrics)
self.val_metrics = _check_metrics(val_metrics)
self.batch_index = 0
@@ -257,7 +263,8 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
self.processed_samples = 0
# logging handler need to be called at last to make sure all states are updated
# it will also shut down logging at train end
- self.priority = np.Inf
+ self.priority = priority
+ self.log_interval = log_interval
def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
@@ -275,6 +282,7 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
self.current_epoch = 0
self.batch_index = 0
self.processed_samples = 0
+ self.log_interval_time = 0
def train_end(self, estimator, *args, **kwargs):
train_time = time.time() - self.train_start
@@ -286,31 +294,34 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat
estimator.logger.info(msg.rstrip(', '))
def batch_begin(self, estimator, *args, **kwargs):
- if self.verbose == self.LOG_PER_BATCH:
+ if isinstance(self.log_interval, int):
self.batch_start = time.time()
def batch_end(self, estimator, *args, **kwargs):
- if self.verbose == self.LOG_PER_BATCH:
+ if isinstance(self.log_interval, int):
batch_time = time.time() - self.batch_start
msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
self.processed_samples += kwargs['batch'][0].shape[0]
msg += '[Samples %s] ' % (self.processed_samples)
- msg += 'time/batch: %.3fs ' % batch_time
- for metric in self.train_metrics:
- # only log current training loss & metric after each batch
- name, value = metric.get()
- msg += '%s: %.4f, ' % (name, value)
- estimator.logger.info(msg.rstrip(', '))
+ self.log_interval_time += batch_time
+ if self.batch_index % self.log_interval == 0:
+ msg += 'time/interval: %.3fs ' % self.log_interval_time
+ self.log_interval_time = 0
+ for metric in self.train_metrics:
+ # only log current training loss & metric after each interval
+ name, value = metric.get()
+ msg += '%s: %.4f, ' % (name, value)
+ estimator.logger.info(msg.rstrip(', '))
self.batch_index += 1
def epoch_begin(self, estimator, *args, **kwargs):
- if self.verbose >= self.LOG_PER_EPOCH:
+ if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
self.epoch_start = time.time()
estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)
def epoch_end(self, estimator, *args, **kwargs):
- if self.verbose >= self.LOG_PER_EPOCH:
+ if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
epoch_time = time.time() - self.epoch_start
msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
for monitor in self.train_metrics + self.val_metrics:
@@ -706,3 +717,30 @@ class EarlyStoppingHandler(TrainBegin, EpochEnd, TrainEnd):
estimator.logger.info('[Epoch %d] EarlyStoppingHanlder: '
'early stopping due to %s not improving',
self.stopped_epoch, self.monitor.get()[0])
+
+class GradientUpdateHandler(BatchEnd):
+ """Gradient Update Handler that apply gradients on network weights
+
+ :py:class:`GradientUpdateHandler` takes the priority level. It updates weight parameters
+ at the end of each batch
+
+ Parameters
+ ----------
+ priority : scalar, default -2000
+ priority level of the gradient update handler. Priority level is sorted in ascending
+ order. The lower the number is, the higher priority level the handler is.
+ ----------
+ """
+ def __init__(self, priority=-2000):
+ self.priority = priority
+
+ def batch_end(self, estimator, *args, **kwargs):
+ loss = kwargs['loss']
+ batch_size = 0
+ if not isinstance(loss, list):
+ loss = [loss]
+ if isinstance(loss, list):
+ for l in loss:
+ batch_size += l.shape[estimator.batch_axis]
+
+ estimator.trainer.step(batch_size)
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 6e2d66c..d1074c9 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -590,8 +590,9 @@ class TopKAccuracy(EvalMetric):
class _BinaryClassificationMetrics(object):
- """Private container class for classification metric statistics. True/false positive and
- true/false negative counts are sufficient statistics for various classification metrics.
+ """Private container class for classification metric statistics.
+
+ True/false positive and true/false negative counts are sufficient statistics for various classification metrics.
This class provides the machinery to track those statistics across mini-batches of
(label, prediction) pairs.
"""
@@ -1430,6 +1431,10 @@ class PearsonCorrelation(EvalMetric):
label_names : list of str, or None
Name of labels that should be used when updating with update_dict.
By default include all labels.
+ average : str, default 'macro'
+ Strategy to be used for aggregating across mini-batches.
+ "macro": average the pearsonr scores for each batch.
+ "micro": compute a single pearsonr score across all batches.
Examples
--------
@@ -1438,13 +1443,46 @@ class PearsonCorrelation(EvalMetric):
>>> pr = mx.metric.PearsonCorrelation()
>>> pr.update(labels, predicts)
>>> print pr.get()
- ('pearson-correlation', 0.42163704544016178)
+ ('pearsonr', 0.42163704544016178)
"""
def __init__(self, name='pearsonr',
- output_names=None, label_names=None):
+ output_names=None, label_names=None, average='macro'):
+ self.average = average
super(PearsonCorrelation, self).__init__(
name, output_names=output_names, label_names=label_names,
has_global_stats=True)
+ if self.average == 'micro':
+ self.reset_micro()
+
+ def reset_micro(self):
+ self._sse_p = 0
+ self._mean_p = 0
+ self._sse_l = 0
+ self._mean_l = 0
+ self._pred_nums = 0
+ self._label_nums = 0
+ self._conv = 0
+
+ def reset(self):
+ self.num_inst = 0
+ self.sum_metric = 0.0
+ self.global_num_inst = 0
+ self.global_sum_metric = 0.0
+ if self.average == 'micro':
+ self.reset_micro()
+
+ def update_variance(self, new_values, *aggregate):
+ #Welford's online algorithm for variance update
+ count, mean, m_2 = aggregate
+ count += len(new_values)
+ delta = new_values - mean
+ mean += numpy.sum(delta / count)
+ delta_2 = new_values - mean
+ m_2 += numpy.sum(delta * delta_2)
+ return count, mean, m_2
+
+ def update_cov(self, label, pred):
+ self._conv = self._conv + numpy.sum((label - self._mean_l) * (pred - self._mean_p))
def update(self, labels, preds):
"""Updates the internal evaluation result.
@@ -1457,17 +1495,34 @@ class PearsonCorrelation(EvalMetric):
Predicted values.
"""
labels, preds = check_label_shapes(labels, preds, True)
-
for label, pred in zip(labels, preds):
check_label_shapes(label, pred, False, True)
- label = label.asnumpy()
- pred = pred.asnumpy()
- pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1]
- self.sum_metric += pearson_corr
- self.global_sum_metric += pearson_corr
- self.num_inst += 1
- self.global_num_inst += 1
+ label = label.asnumpy().ravel().astype(numpy.float64)
+ pred = pred.asnumpy().ravel().astype(numpy.float64)
+ if self.average == 'macro':
+ pearson_corr = numpy.corrcoef(pred, label)[0, 1]
+ self.sum_metric += pearson_corr
+ self.global_sum_metric += pearson_corr
+ self.num_inst += 1
+ self.global_num_inst += 1
+ else:
+ self.global_num_inst += 1
+ self.num_inst += 1
+ self._label_nums, self._mean_l, self._sse_l = \
+ self.update_variance(label, self._label_nums, self._mean_l, self._sse_l)
+ self.update_cov(label, pred)
+ self._pred_nums, self._mean_p, self._sse_p = \
+ self.update_variance(pred, self._pred_nums, self._mean_p, self._sse_p)
+ def get(self):
+ if self.num_inst == 0:
+ return (self.name, float('nan'))
+ if self.average == 'macro':
+ return (self.name, self.sum_metric / self.num_inst)
+ else:
+ n = self._label_nums
+ pearsonr = self._conv / ((n-1) * numpy.sqrt(self._sse_p / (n - 1)) * numpy.sqrt(self._sse_l / (n - 1)))
+ return (self.name, pearsonr)
@register
class PCC(EvalMetric):
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index a7ad8e6..bc1bbe4 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -905,6 +905,17 @@ fixed-size items.
return flat_begin, flat_end + 1
# pylint: enable=invalid-name
+ @staticmethod
+ def _drop_int_axes(indexed_shape, int_axes):
+ """drop the axis of indexed_shape corresponding to int axes"""
+ bcast_shape = []
+ for i, size in enumerate(indexed_shape):
+ if i not in int_axes:
+ bcast_shape.append(size)
+ if not bcast_shape:
+ bcast_shape = [1]
+ return tuple(bcast_shape)
+
def _set_nd_basic_indexing(self, key, value):
"""This function indexes ``self`` with a tuple of ``slice`` objects only."""
for idx in key:
@@ -946,14 +957,10 @@ fixed-size items.
if type(value) == self.__class__: # pylint: disable=unidiomatic-typecheck
if value.handle is not self.handle:
# Need to do this before `broadcast_to`.
- tmp_shape = _shape_for_bcast(
- value.shape, target_ndim=self.ndim, new_axes=int_axes
- )
- value = value.reshape(tmp_shape)
-
- if value.shape != self.shape:
- value = value.broadcast_to(self.shape)
- value.copyto(self)
+ bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
+ value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
+ value_nd = value_nd.reshape(indexed_shape)
+ value_nd.copyto(self)
elif isinstance(value, numeric_types):
self._full(value)
@@ -969,9 +976,10 @@ fixed-size items.
else:
# Other array-like
- value_nd = self._prepare_value_nd(
- value, bcast_shape=self.shape
- )
+ # drop the axis of indexed_shape corresponding to int axes
+ bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
+ value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
+ value_nd = value_nd.reshape(indexed_shape)
value_nd.copyto(self)
elif isinstance(value, numeric_types):
@@ -979,16 +987,8 @@ fixed-size items.
else:
# drop the axis of indexed_shape corresponding to int axes
- bcast_shape = []
- for i, size in enumerate(indexed_shape):
- if i not in int_axes:
- bcast_shape.append(size)
- if bcast_shape == []:
- bcast_shape = [1]
- bcast_shape = tuple(bcast_shape)
- value_nd = self._prepare_value_nd(
- value, bcast_shape=bcast_shape, squeeze_axes=new_axes
- )
+ bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
+ value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
value_nd = value_nd.reshape(indexed_shape)
self.slice_assign(value_nd, begin, end, step)
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 9375bed..ba3c334 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -243,8 +243,14 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const {
NDArray NDArray::Reshape(const mxnet::TShape &shape) const {
CHECK(!is_none()) << "NDArray is not initialized";
- CHECK_GE(shape_.Size(), shape.Size())
- << "NDArray.Reshape: target shape size is larger current shape";
+ if (Imperative::Get()->is_np_shape()) {
+ CHECK_EQ(shape_.Size(), shape.Size())
+ << "NDArray.Reshape: target shape must have the same size as "
+ << "current shape.";
+ } else {
+ CHECK_GE(shape_.Size(), shape.Size())
+ << "NDArray.Reshape: target shape size is larger than the current shape";
+ }
NDArray ret = this->Detach();
// If the shape doesn't change, we can just return it now.
if (ret.shape_ == shape)
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index a9828f4..42cc99a 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -119,16 +119,22 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
- CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
+ if (req[0] == kNullOp) return;
+ CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+ << "Transpose only supports kWriteTo, kNullOp and kAddTo";
+ mxnet::TShape axes;
if (ndim_is_known(param.axes)) {
- mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
- TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+ axes = common::CanonicalizeAxes(param.axes);
} else {
- mxnet::TShape axes(inputs[0].ndim(), -1);
+ axes = mxnet::TShape(inputs[0].ndim(), -1);
for (int i = 0; i < axes.ndim(); ++i) {
axes[i] = axes.ndim() - 1 - i;
}
- TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+ }
+ if (req[0] == kAddTo) {
+ TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+ } else {
+ TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
}
}
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 3967cde..41d8d02 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -24,6 +24,7 @@
*/
#include <vector>
+#include <set>
#include "./np_matrix_op-inl.h"
#include "../nn/concat-inl.h"
@@ -65,8 +66,13 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape ret(ndim, -1);
if (ndim_is_known(param.axes)) {
- CHECK_EQ(ndim, param.axes.ndim());
+ CHECK_EQ(ndim, param.axes.ndim())
+ << "The number of axes does not match the dimension of the tensor. axes = "
+ << param.axes << ", input tensor shape = " << shp;
mxnet::TShape axes = common::CanonicalizeAxes(param.axes);
+ std::set<dim_t> axes_set(axes.begin(), axes.end());
+ CHECK_EQ(axes_set.size(), axes.ndim()) << "Repeated axis in transpose. param.axes = "
+ << param.axes;
if (ndim_is_known(shp)) {
for (int i = 0; i < ndim; ++i) {
ret[i] = shp[axes[i]];
@@ -115,9 +121,9 @@ NNVM_REGISTER_OP(_np_transpose)
}
std::ostringstream os;
os << axes;
- return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}});
+ return MakeNonlossGradNode("_np_transpose", n, ograds, {}, {{"axes", os.str()}});
} else {
- return MakeNonlossGradNode("transpose", n, ograds, {},
+ return MakeNonlossGradNode("_np_transpose", n, ograds, {},
std::unordered_map<std::string, std::string>());
}
})
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 0fee2a2..4bd059a 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -269,8 +269,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
* \param out output tensor
* \param row shape of dim 0 of input
* \param col shape of dim 1 of input
+ * \tparam DType Data type
+ * \tparam is_addto
*/
-template<typename DType>
+template<typename DType, bool is_addto>
MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
// ensure cache line hits and prevent cache miss for any configuration
// L1 cache size to be utilized = 32kb = 2^15
@@ -282,7 +284,7 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
// Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled
// blocksize * blocksize * num_threads = cache_size / dtype_size
// Instead of explicit unroll, let compiler figure out optimal unroll factor
- index_t blocksize = 32;
+ const index_t blocksize = 32;
// collapse 2 parallelizes 2 for loops
// inner 2 for loops aren't parallelized to prevent cache miss
@@ -299,14 +301,25 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index
// transpose the block
for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
- out[a * row + b] = in[b * col + a];
+ if (!is_addto) {
+ out[a * row + b] = in[b * col + a];
+ } else {
+ out[a * row + b] += in[b * col + a];
+ }
}
}
}
}
}
-template<typename xpu>
+inline bool IsIdentityTranspose(const TShape& axes) {
+ for (dim_t i = 0; i < axes.ndim(); i++) {
+ if (axes[i] != i) return false;
+ }
+ return true;
+}
+
+template<typename xpu, bool is_addto = false>
void TransposeImpl(RunContext ctx,
const TBlob& src,
const TBlob& ret,
@@ -323,62 +336,79 @@ void TransposeImpl(RunContext ctx,
// Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3).
if (isPseudo2DTranspose(axes)) {
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
- transpose_pseudo2D<DType>(ret, src, axes, s);
+ transpose_pseudo2D<DType, is_addto>(ret, src, axes, s);
});
return;
}
#endif
+ // Special handle the identity case
+ if (IsIdentityTranspose(axes)) {
+ MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
+ Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(src.Size()), s);
+ Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(ret.Size()), s);
+ if (!is_addto) {
+ // Use memcpy to accelerate the speed
+ Copy(out, in, s);
+ } else {
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo>, xpu>::Launch(
+ s, ret.Size(), out.dptr_, in.dptr_);
+ }
+ });
+ return;
+ }
+ // Handle the general transpose case
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
switch (axes.ndim()) {
- case 0: {
- Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
- Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
- Copy(out, in, s);
- break;
- }
- case 1: {
- Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s);
- Tensor<xpu, 1, DType> out = ret.get<xpu, 1, DType>(s);
- Copy(out, in, s);
- break;
- }
case 2: {
- mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
- mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);
-
- if (axes[0] == 1 && axes[1] == 0) {
- if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
- Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
- } else {
- out = in.T();
- }
+ Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s);
+ Tensor<xpu, 2, DType> out = ret.get<xpu, 2, DType>(s);
+ if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
+ Transpose2D<DType, is_addto>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
} else {
- Copy(out, in, s);
+ LOG(FATAL) << "Not Implemented. We should never reach here because the 2D case "
+ "in GPU has been covered by transpose_pseudo2D."
+ " Report an issue in Github.";
}
break;
}
case 3: {
Tensor<xpu, 3, DType> in = src.get<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> out = ret.get<xpu, 3, DType>(s);
- out = transpose(in, axes.get<3>());
+ if (!is_addto) {
+ out = transpose(in, axes.get<3>());
+ } else {
+ out += transpose(in, axes.get<3>());
+ }
break;
}
case 4: {
Tensor<xpu, 4, DType> in = src.get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> out = ret.get<xpu, 4, DType>(s);
- out = transpose(in, axes.get<4>());
+ if (!is_addto) {
+ out = transpose(in, axes.get<4>());
+ } else {
+ out += transpose(in, axes.get<4>());
+ }
break;
}
case 5: {
Tensor<xpu, 5, DType> in = src.get<xpu, 5, DType>(s);
Tensor<xpu, 5, DType> out = ret.get<xpu, 5, DType>(s);
- out = transpose(in, axes.get<5>());
+ if (!is_addto) {
+ out = transpose(in, axes.get<5>());
+ } else {
+ out += transpose(in, axes.get<5>());
+ }
break;
}
case 6: {
Tensor<xpu, 6, DType> in = src.get<xpu, 6, DType>(s);
Tensor<xpu, 6, DType> out = ret.get<xpu, 6, DType>(s);
- out = transpose(in, axes.get<6>());
+ if (!is_addto) {
+ out = transpose(in, axes.get<6>());
+ } else {
+ out += transpose(in, axes.get<6>());
+ }
break;
}
default:
@@ -399,15 +429,21 @@ void Transpose(const nnvm::NodeAttrs& attrs,
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
- CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+ CHECK(req[0] == kWriteTo || req[0] == kAddTo)
+ << "Transpose only supports kNullOp, kWriteTo and kAddTo";
+ mxnet::TShape axes;
if (param.axes.ndim() == 0) {
- mxnet::TShape axes(inputs[0].ndim(), -1);
+ axes = mxnet::TShape(inputs[0].ndim(), -1);
for (int i = 0; i < axes.ndim(); ++i) {
axes[i] = axes.ndim() - 1 - i;
}
- TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
} else {
- TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], param.axes);
+ axes = common::CanonicalizeAxes(param.axes);
+ }
+ if (req[0] == kAddTo) {
+ TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
+ } else {
+ TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
}
}
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b09b332..1e69f72 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -283,11 +283,12 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
- CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
+ CHECK(req[0] == kWriteTo || req[0] == kAddTo) <<
+ "Transpose only supports kNullOp, kWriteTo and kAddTo";
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
- if (SupportMKLDNNTranspose(param, inputs[0])) {
+ if (SupportMKLDNNTranspose(param, inputs[0]) && req[0] == kWriteTo) {
MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
diff --git a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
index 5b7cf04..b3ca9fb 100644
--- a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
+++ b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh
@@ -39,22 +39,31 @@ namespace mxnet {
namespace op {
namespace cuda {
-
-template <typename DType, typename CType>
+/*!
+ * \brief The `transpose_pseudo2D` based on chosen vectorized types. It transposes an array of
+ * shape (k, m, n) to (k, n, m)
+ * \param out Pointer to output memory.
+ * \param inp Pointer to input memory.
+ * \param m First of tensor dimensions.
+ * \param n Second of tensor dimensions.
+ * \param nIterY The number of iterations in the y-dim of the thread to cover all rows. (1-->m)
+ * \param nIterZ The number of iterations in the z-dim of the thread to cover all rows. (1-->k)
+ * \tparam DType Data type
+ * \tparam CType The type to load the data.
+ * \tparam is_addto Whether to perform out += transpose(data) or out = transpose(data)
+ */
+template <typename DType, typename CType, bool is_addto>
__global__ void transpose_pseudo2D(DType* out, DType* inp,
const index_t m, const index_t n,
const index_t nIterY, const index_t nIterZ) {
- const index_t TSR = sizeof(CType)/sizeof(DType); // TypeSizeRatio
+ // Calculate the TypeSizeRatio
+ const index_t TSR = sizeof(CType) / sizeof(DType) > 0 ? sizeof(CType) / sizeof(DType) : 1;
const index_t chunked_n = n/TSR;
const index_t chunked_m = m/TSR;
- union transp_t {
- CType valChunk;
- DType values[TSR];
- };
-
- __shared__ DType d_shm[1024*TSR*TSR];
- CType* c_shm = reinterpret_cast<CType*>(d_shm);
+ extern __shared__ char buf[];
+ DType* d_shm = reinterpret_cast<DType*>(buf);
+ CType* c_shm = reinterpret_cast<CType*>(buf);
CType* cInp = reinterpret_cast<CType*>(inp);
CType* cOut = reinterpret_cast<CType*>(out);
@@ -78,23 +87,34 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
}
__syncthreads();
- // read from shared to registers
- transp_t tmp[TSR];
+ // read from shared to local registers
+ CType tmp[TSR];
#pragma unroll
for (index_t i = 0; i < TSR; i++) {
+ DType* tmp_dptr = reinterpret_cast<DType*>(&tmp[i]);
#pragma unroll
for (int j = 0; j < TSR; j++) {
index_t shmIdx = (TSR*threadIdx.y + j)*blockDim.x*TSR + TSR*threadIdx.x + i;
- tmp[i].values[j] = d_shm[shmIdx];
+ tmp_dptr[j] = d_shm[shmIdx];
}
}
__syncthreads();
// write back to global output
- offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m + blockIdx_y*blockDim.y;
+ offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m
+ + blockIdx_y*blockDim.y;
#pragma unroll
for (index_t i = 0; i < TSR; i++) {
- cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i].valChunk;
+ if (is_addto) {
+ DType* tmp_dptr = reinterpret_cast<DType*>(&tmp[i]);
+ #pragma unroll
+ for (int j = 0; j < TSR; j++) {
+ out[TSR * (offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y) + j]
+ += tmp_dptr[j];
+ }
+ } else {
+ cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i];
+ }
}
}
}
@@ -107,7 +127,6 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
/*!
* \brief Calls proper version of kernel `transpose_pseudo2D`
* basing on chosen type sizes.
- * \param dTypeSize Size of data type.
* \param cTypeSize Size of type that should be use to copy.
* \param grid Grid dimensions for the kernel.
* \param block Block dimensions for the kernel.
@@ -116,92 +135,39 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp,
* \param inp Pointer to input memory.
* \param m First of tensor dimensions.
* \param n Second of tensor dimensions.
+ * \tparam DType Data type
+ * \tparam is_addto Whether to trigger add the transpose result to the output tensor.
*/
-inline void call_transpose_pseudo2D(index_t dTypeSize, index_t cTypeSize,
- dim3 grid, dim3 block, cudaStream_t stream,
- void* out, void* inp, const index_t m, const index_t n,
- const index_t nIterY, const index_t nIterZ) {
- switch (dTypeSize) {
- case (1): {
- uint8_t* d_outPtr = reinterpret_cast<uint8_t*>(out);
- uint8_t* d_inpPtr = reinterpret_cast<uint8_t*>(inp);
- switch (cTypeSize) {
- case (1):
- cuda::transpose_pseudo2D<uint8_t, uint8_t><<<grid, block, 0, stream>>>
+template <typename DType, bool is_addto>
+inline void call_transpose_pseudo2D(index_t cTypeSize,
+ dim3 grid, dim3 block, cudaStream_t stream,
+ DType* d_outPtr, DType* d_inpPtr,
+ const index_t m, const index_t n,
+ const index_t nIterY, const index_t nIterZ) {
+ const int nshared = 1024 * cTypeSize / sizeof(DType) * cTypeSize;
+ switch (cTypeSize) {
+ case (1):
+ cuda::transpose_pseudo2D<DType, uint8_t, is_addto><<<grid, block, nshared, stream>>>
(d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
break;
- case (2):
- cuda::transpose_pseudo2D<uint8_t, uint16_t><<<grid, block, 0, stream>>>
+ case (2):
+ cuda::transpose_pseudo2D<DType, uint16_t, is_addto><<<grid, block, nshared, stream>>>
(d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
break;
- case (4):
- cuda::transpose_pseudo2D<uint8_t, uint32_t><<<grid, block, 0, stream>>>
+ case (4):
+ cuda::transpose_pseudo2D<DType, uint32_t, is_addto><<<grid, block, nshared, stream>>>
(d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
break;
- case (8):
- // case guarded against in function getBestCopyTypeSize
- LOG(FATAL) << "cuda::transpose_pseudo2D<uint8_t, uint64_t> would take too much shared memory";
- default:
- LOG(FATAL) << "Unsupported type combination";
- }
- break;
- }
- case (2): {
- uint16_t* d_outPtr = reinterpret_cast<uint16_t*>(out);
- uint16_t* d_inpPtr = reinterpret_cast<uint16_t*>(inp);
- switch (cTypeSize) {
- case (2):
- cuda::transpose_pseudo2D<uint16_t, uint16_t><<<grid, block, 0, stream>>>
+ case (8):
+ cuda::transpose_pseudo2D<DType, uint64_t, is_addto><<<grid, block, nshared, stream>>>
(d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
break;
- case (4):
- cuda::transpose_pseudo2D<uint16_t, uint32_t><<<grid, block, 0, stream>>>
- (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
- break;
- case (8):
- cuda::transpose_pseudo2D<uint16_t, uint64_t><<<grid, block, 0, stream>>>
- (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
- break;
- default:
- LOG(FATAL) << "Unsupported type combination";
- }
- break;
- }
- case (4): {
- uint32_t* d_outPtr = reinterpret_cast<uint32_t*>(out);
- uint32_t* d_inpPtr = reinterpret_cast<uint32_t*>(inp);
- switch (cTypeSize) {
- case (4):
- cuda::transpose_pseudo2D<uint32_t, uint32_t><<<grid, block, 0, stream>>>
- (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
- break;
- case (8):
- cuda::transpose_pseudo2D<uint32_t, uint64_t><<<grid, block, 0, stream>>>
- (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
- break;
- default:
- LOG(FATAL) << "Unsupported type combination";
- }
- break;
- }
- case (8): {
- uint64_t* d_outPtr = reinterpret_cast<uint64_t*>(out);
- uint64_t* d_inpPtr = reinterpret_cast<uint64_t*>(inp);
- switch (cTypeSize) {
- case (8):
- cuda::transpose_pseudo2D<uint64_t, uint64_t><<<grid, block, 0, stream>>>
- (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ);
- break;
- default:
- LOG(FATAL) << "Unsupported type combination";
- }
- break;
- }
- default:
- LOG(FATAL) << "Unsupported type combination";
+ default:
+ LOG(FATAL) << "Unsupported type combination. " << "Copy type size = " << cTypeSize;
}
auto cuErr = cudaPeekAtLastError();
- CHECK_EQ(cuErr, cudaSuccess) << "Transpose kernel failure: " << cudaGetErrorString(cuErr) << ". "
+ CHECK_EQ(cuErr, cudaSuccess) << "TransposePseudo2D kernel failure: "
+ << cudaGetErrorString(cuErr) << ". "
<< "block: (" << block.x << "," << block.y << "," << block.z << ")"
<< " grid: (" << grid.x << "," << grid.y << "," << grid.z << ")";
}
@@ -225,7 +191,6 @@ inline bool isPseudo2DTranspose(const TShape& params) {
return n_swpDims == 2;
}
-
struct pseudo2DSizes {
index_t leadDimS;
index_t M;
@@ -306,15 +271,14 @@ inline std::pair<dim3, dim3> calculateKernelParams(pseudo2DSizes sizes, const in
* \param outBlob Tensor blob to store result.
* \param inpBlob Tensor blob with input data.
* \param params Parameters (axes) of the transpose.
+ * \param is_addto Whether to add the transpose result to the outBlob
* \param s Pointer to GPU stream.
*/
-template <typename DType, typename gpu>
+template <typename DType, bool is_addto>
void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob,
const TShape& params, mshadow::Stream<gpu>* s) {
const TShape& shape = inpBlob.shape_;
CHECK_EQ(shape.ndim(), params.ndim());
- auto ndim = params.ndim();
-
auto sizes = getPackedTransposeDimensions(shape, params);
index_t cTypeSize = getBestCopyTypeSize(sizeof(DType), sizes.M, sizes.N);
@@ -337,8 +301,10 @@ void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob,
}
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
- call_transpose_pseudo2D(sizeof(DType), cTypeSize, grid, block, stream,
- outBlob.dptr_, inpBlob.dptr_, sizes.M, sizes.N, nIterY, nIterZ);
+ call_transpose_pseudo2D<DType, is_addto>
+ (cTypeSize, grid, block, stream,
+ outBlob.dptr<DType>(), inpBlob.dptr<DType>(),
+ sizes.M, sizes.N, nIterY, nIterZ);
}
} // namespace op
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index cf913a6..21f949a 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -367,6 +367,7 @@ def test_default_handlers():
val_metrics = est.val_metrics
early_stopping = EarlyStoppingHandler(monitor=val_metrics[0])
handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping])
- assert len(handlers) == 4
- assert isinstance(handlers[0], MetricHandler)
- assert isinstance(handlers[3], LoggingHandler)
+ assert len(handlers) == 5
+ assert isinstance(handlers[0], GradientUpdateHandler)
+ assert isinstance(handlers[1], MetricHandler)
+ assert isinstance(handlers[4], LoggingHandler)
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
index 17c7581..658fb88 100644
--- a/tests/python/unittest/test_gluon_event_handler.py
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -17,13 +17,19 @@
import os
import logging
+import sys
+import re
import mxnet as mx
from common import TemporaryDirectory
from mxnet import nd
from mxnet.gluon import nn, loss
from mxnet.gluon.contrib.estimator import estimator, event_handler
-
+from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler
+try:
+ from StringIO import StringIO
+except ImportError:
+ from io import StringIO
def _get_test_network(net=nn.Sequential()):
net.add(nn.Dense(128, activation='relu', flatten=False),
@@ -32,9 +38,9 @@ def _get_test_network(net=nn.Sequential()):
return net
-def _get_test_data():
- data = nd.ones((32, 100))
- label = nd.zeros((32, 1))
+def _get_test_data(in_size=32):
+ data = nd.ones((in_size, 100))
+ label = nd.zeros((in_size, 1))
data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
return mx.gluon.data.DataLoader(data_arr, batch_size=8)
@@ -200,3 +206,42 @@ def test_custom_handler():
     est.fit(test_data, event_handlers=[custom_handler], epochs=10)
     assert custom_handler.num_batch == 5 * 4
     assert custom_handler.num_epoch == 5
+
+
+# Matches one per-batch line printed by LoggingHandler (raw strings, literal dots).
+_BATCH_LOG_PATTERN = re.compile(
+    r'\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time/interval: \d+\.\d+s'
+    r' training accuracy: \d+\.\d+')
+
+
+def _fit_and_count_batch_logs(net, dataloader, log_interval, num_epochs=1):
+    ''' fit `net` with a LoggingHandler and return the number of per-batch log lines '''
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    old_stdout = sys.stdout
+    sys.stdout = captured = StringIO()
+    try:
+        # create the handler and the estimator while stdout is redirected,
+        # as the test did originally, so their stdout output lands in `captured`
+        logging_handler = LoggingHandler(train_metrics=[acc], log_interval=log_interval)
+        est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc)
+        est.fit(train_data=dataloader, epochs=num_epochs,
+                event_handlers=[logging_handler])
+    finally:
+        # restore stdout even if fit() raises, so later tests still print normally
+        sys.stdout = old_stdout
+    return sum(1 for line in captured.getvalue().splitlines()
+               if _BATCH_LOG_PATTERN.match(line))
+
+
+def test_logging_interval():
+    ''' test per-batch logging of LoggingHandler with different log intervals '''
+    batch_size = 8  # batch size hard-coded in _get_test_data
+    data_size = 100
+    # pass a fresh Sequential: the default argument of _get_test_network is
+    # evaluated once at definition time and would be shared with other tests
+    net = _get_test_network(nn.Sequential())
+    dataloader = _get_test_data(in_size=data_size)
+    for log_interval in (1, 5):
+        info_len = _fit_and_count_batch_logs(net, dataloader, log_interval)
+        assert info_len == int(data_size / batch_size / log_interval) + 1
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index 0ae8aea..a1e5128 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -17,6 +17,7 @@
 
 import mxnet as mx
 import numpy as np
+import scipy.stats
 import json
 import math
 from common import with_seed
@@ -263,13 +264,45 @@ def test_perplexity():
     assert perplexity == perplexity_expected
 
 def test_pearsonr():
-    pred = mx.nd.array([[0.7, 0.3], [0.1, 0.9], [1., 0]])
-    label = mx.nd.array([[0, 1], [1, 0], [1, 0]])
-    pearsonr_expected = np.corrcoef(pred.asnumpy().ravel(), label.asnumpy().ravel())[0, 1]
-    metric = mx.metric.create('pearsonr')
-    metric.update([label], [pred])
-    _, pearsonr = metric.get()
-    assert pearsonr == pearsonr_expected
+    def reference_pearsonr(pred, label):
+        ''' Pearson r of the flattened arrays, computed by NumPy and by SciPy '''
+        x, y = pred.asnumpy().ravel(), label.asnumpy().ravel()
+        return np.corrcoef(x, y)[0, 1], scipy.stats.pearsonr(x, y)[0]
+
+    def assert_metric_close(metric, expected_values):
+        for expected in expected_values:
+            np.testing.assert_almost_equal(metric.get()[1], expected)
+
+    pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
+    label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
+    macro_pr = mx.metric.create('pearsonr', average='macro')
+    micro_pr = mx.metric.create('pearsonr', average='micro')
+
+    # no data seen yet: both strategies report NaN
+    assert np.isnan(macro_pr.get()[1])
+    assert np.isnan(micro_pr.get()[1])
+
+    # a single batch: macro and micro averaging give the same score
+    macro_pr.update([label1], [pred1])
+    micro_pr.update([label1], [pred1])
+    expected = reference_pearsonr(pred1, label1)
+    assert_metric_close(macro_pr, expected)
+    assert_metric_close(micro_pr, expected)
+
+    pred2 = mx.nd.array([[1, 2], [3, 2], [4, 6]])
+    label2 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
+    # rows of batch 1 followed by rows of batch 2 (concatenation, not a sum)
+    pred12 = mx.nd.concat(pred1, pred2, dim=0)
+    label12 = mx.nd.concat(label1, label2, dim=0)
+    expected = reference_pearsonr(pred12, label12)
+
+    # micro: adding batch 2 must yield the score over all rows seen so far;
+    # macro: after a reset, one update with the concatenated rows yields the same value
+    micro_pr.update([label2], [pred2])
+    macro_pr.reset()
+    macro_pr.update([label12], [pred12])
+    assert_metric_close(macro_pr, expected)
+    assert_metric_close(micro_pr, expected)
 
 def cm_batch(cm):
     # generate a batch yielding a given confusion matrix
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 4c6d9f7..d097799 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -154,6 +154,22 @@ def test_ndarray_setitem():
assert x.shape == trivial_shape
assert same(x.asnumpy(), x_np)
+ # test https://github.com/apache/incubator-mxnet/issues/16647
+ dst = mx.nd.zeros((1, 3, 1)) # destination array
+ src = [1, 2, 3]
+ dst[0, :len(src), 0] = src
+ assert same(dst.asnumpy(), np.array([1, 2, 3], dtype=dst.dtype).reshape(dst.shape))
+
+ dst = mx.nd.zeros((1, 3, 1)) # destination array
+ src = [1, 2, 3]
+ dst[0, :len(src), 0] = mx.nd.array(src)
+ assert same(dst.asnumpy(), np.array([1, 2, 3], dtype=dst.dtype).reshape(dst.shape))
+
+ dst = mx.nd.zeros((1, 3, 1)) # destination array
+ src = [1, 2]
+ dst[0, :len(src), 0] = src
+ assert same(dst.asnumpy(), np.array([1, 2, 0], dtype=dst.dtype).reshape(dst.shape))
+
@with_seed()
def test_ndarray_elementwise():
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 9f4e62c..0bd620e 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -696,26 +696,37 @@ def test_np_ndarray_indexing():
np_indexed_array = _np.random.randint(low=-10000, high=0, size=indexed_array_shape)
# test value is a native numpy array without broadcast
assert_same(np_array, np_index, mx_array, index, np_indexed_array)
+ # test value is a list without broadcast
+ assert_same(np_array, np_index, mx_array, index, np_indexed_array.tolist())
# test value is a mxnet numpy array without broadcast
assert_same(np_array, np_index, mx_array, index, np.array(np_indexed_array))
# test value is an numeric_type
assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0))
- if len(indexed_array_shape) > 1:
- np_value = _np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],))
- # test mxnet ndarray with broadcast
- assert_same(np_array, np_index, mx_array, index, np.array(np_value))
- # test native numpy array with broadcast
- assert_same(np_array, np_index, mx_array, index, np_value)
-
- # test value shape are expanded to be longer than index array's shape
- # this is currently only supported in basic indexing
- if _is_basic_index(index):
- expanded_value_shape = (1, 1, 1) + np_value.shape
- assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape)))
- assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape))
- # test list with broadcast
- assert_same(np_array, np_index, mx_array, index,
- [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1])
+
+ np_value = _np.random.randint(low=-10000, high=0,
+ size=(indexed_array_shape[-1],) if len(indexed_array_shape) > 0 else ())
+ # test mxnet ndarray with broadcast
+ assert_same(np_array, np_index, mx_array, index, np.array(np_value))
+ # test native numpy array with broadcast
+ assert_same(np_array, np_index, mx_array, index, np_value)
+ # test python list with broadcast
+ assert_same(np_array, np_index, mx_array, index, np_value.tolist())
+
+ # test value shape are expanded to be longer than index array's shape
+ # this is currently only supported in basic indexing
+ if _is_basic_index(index):
+ expanded_value_shape = (1, 1) + np_value.shape
+ assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape)))
+ assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape))
+ if len(expanded_value_shape) <= np_array[index].ndim:
+ # NumPy does not allow value.ndim > np_array[index].ndim when value is a python list.
+ # It may be a bug of NumPy.
+ assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape).tolist())
+
+ # test list with broadcast
+ assert_same(np_array, np_index, mx_array, index,
+ [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1] if len(indexed_array_shape) > 0
+ else _np.random.randint(low=-10000, high=0))
def test_getitem_autograd(np_array, index):
"""
@@ -905,6 +916,9 @@ def test_np_ndarray_indexing():
range(4),
range(3, 0, -1),
(range(4,), [1]),
+ (1, 1, slice(None), 1),
+ (1, 1, slice(None, 3), 1),
+ (1, 1, slice(None, 8, 3), 1),
]
for index in index_list:
test_getitem(np_array, index)
@@ -925,8 +939,8 @@ def test_np_ndarray_indexing():
# test zero-size tensors get and setitem
shapes_indices = [
- ((0), [slice(None, None, None)]),
- ((3, 0), [2, (slice(None, None, None)), (slice(None, None, None), None)]),
+ ((0), [slice(None, None, None)]),
+ ((3, 0), [2, (slice(None, None, None)), (slice(None, None, None), None)]),
]
for shape, indices in shapes_indices:
np_array = _np.zeros(shape)
@@ -1198,11 +1212,14 @@ def test_np_ndarray_pickle():
     a = np.random.uniform(size=(4, 5))
     a_copy = a.copy()
     import pickle
-    with open("np_ndarray_pickle_test_file", 'wb') as f:
-        pickle.dump(a_copy, f)
-    with open("np_ndarray_pickle_test_file", 'rb') as f:
-        a_load = pickle.load(f)
-    same(a.asnumpy(), a_load.asnumpy())
+
+    with TemporaryDirectory() as work_dir:
+        fname = os.path.join(work_dir, 'np_ndarray_pickle_test_file')
+        with open(fname, 'wb') as f:
+            pickle.dump(a_copy, f)
+        with open(fname, 'rb') as f:
+            a_load = pickle.load(f)
+        assert same(a.asnumpy(), a_load.asnumpy())
 
 
 if __name__ == '__main__':
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 5ec9944..1ff1b61 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1302,7 +1302,9 @@ def test_np_transpose():
if axes is None or axes == ():
return _np.transpose(ograd, axes)
np_axes = _np.array(list(axes))
- return _np.transpose(ograd, tuple(list(_np.argsort(np_axes))))
+ transpose_axes = _np.zeros_like(np_axes)
+ transpose_axes[np_axes] = _np.arange(len(np_axes))
+ return _np.transpose(ograd, tuple(list(transpose_axes)))
class TestTranspose(HybridBlock):
def __init__(self, axes=None):
@@ -1311,45 +1313,57 @@ def test_np_transpose():
def hybrid_forward(self, F, a):
return F.np.transpose(a, self.axes)
+ test_workloads = [[(), [(), None]],
+ [(2,), [(0,), None]],
+ [(0, 2), [(0, 1), (1, 0)]],
+ [(5, 10), [(0, 1), (1, 0), None]],
+ [(8, 2, 3), [(2, 0, 1), (0, 2, 1), (0, 1, 2), (2, 1, 0), (-1, 1, 0), None]],
+ [(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]],
+ [(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]],
+ [(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]],
+ [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]]]
for hybridize in [True, False]:
- for dtype in [_np.int32, _np.float32]:
- for ndim in range(7):
- shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
- axeses = [None]
- if ndim == 0:
- axeses += [()]
- else:
- axes = [i for i in range(ndim)]
- axeses.append(tuple(axes))
- random.shuffle(axes)
- axeses.append(tuple(axes))
- axeses.append([i - len(axes) for i in axes])
- for axes in axeses:
- test_trans = TestTranspose(axes)
- if hybridize:
- test_trans.hybridize()
- x = rand_ndarray(shape).as_np_ndarray()
- x = x.astype(dtype)
- x.attach_grad()
- np_out = _np.transpose(x.asnumpy(), axes)
- with mx.autograd.record():
- mx_out = test_trans(x)
- assert mx_out.shape == np_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
- mx_out.backward()
- np_backward = np_transpose_grad(np_out.shape, dtype, axes)
- assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
-
- mx_out = x.transpose(axes)
- np_out = x.asnumpy().transpose(axes)
- assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+ for dtype in [_np.float32, _np.float16, _np.int32]:
+ for data_shape, axes_workload in test_workloads:
+ for axes in axes_workload:
+ for grad_req in ['write', 'add']:
+ test_trans = TestTranspose(axes)
+ if hybridize:
+ test_trans.hybridize()
+ x = np.random.normal(0, 1, data_shape).astype(dtype)
+ x = x.astype(dtype)
+ x.attach_grad(grad_req=grad_req)
+ if grad_req == 'add':
+ x.grad[()] = np.random.normal(0, 1, x.grad.shape).astype(x.grad.dtype)
+ x_grad_np = x.grad.asnumpy()
+ np_out = _np.transpose(x.asnumpy(), axes)
+ with mx.autograd.record():
+ mx_out = test_trans(x)
+ assert mx_out.shape == np_out.shape
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+ mx_out.backward()
+ np_backward = np_transpose_grad(np_out.shape, dtype, axes)
+ if grad_req == 'add':
+ assert_almost_equal(x.grad.asnumpy(), np_backward + x_grad_np,
+ rtol=1e-3, atol=1e-5, use_broadcast=False)
+ else:
+ assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
- if isinstance(axes, (list, tuple)):
- mx_out = x.transpose(*axes)
- np_out = x.asnumpy().transpose(*axes)
+ mx_out = x.transpose(axes)
+ np_out = x.asnumpy().transpose(axes)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+ if isinstance(axes, (list, tuple)):
+ mx_out = x.transpose(*axes)
+ np_out = x.asnumpy().transpose(*axes)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
+ # Test for error raising
+ dat = np.random.normal(0, 1, (3, 4, 5), dtype=np.float32)
+ assert_raises(MXNetError, lambda: dat.transpose((0, 0, 1)))
+ assert_raises(MXNetError, lambda: dat.transpose((0, 1, 3)))
+
+
@with_seed()
@use_np