You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:01:05 UTC
[incubator-mxnet] 29/42: [numpy] Change d2l chapters cv and gan to
use numpy (#15368)
This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit a216878bf022020b208faa5d05f615f3720a21c6
Author: reminisce <wu...@gmail.com>
AuthorDate: Wed Jun 26 20:35:06 2019 -0700
[numpy] Change d2l chapters cv and gan to use numpy (#15368)
* Change op name style to lower case underscore
* Add ops under image to npx
* Add image submodule to npx
* Fix split_and_load use np
* Fix fine tuning
* Fix bbox and anchor
* Fix odd
* Fix ssd and rcnn
* Remove restriction on binary element-wise scalar
* Fix gan
* Fix sanity
* Try to fix website build failure
* Add npx.random.seed
* Fix doc
---
python/mxnet/_numpy_op_doc.py | 5 +-
python/mxnet/base.py | 3 +-
python/mxnet/gluon/block.py | 23 ++++++-
python/mxnet/gluon/data/vision/datasets.py | 5 +-
python/mxnet/gluon/data/vision/transforms.py | 28 +++++++-
python/mxnet/gluon/loss.py | 39 ++++++++----
python/mxnet/gluon/model_zoo/vision/resnet.py | 19 ++++--
python/mxnet/gluon/nn/activations.py | 8 +--
python/mxnet/gluon/nn/basic_layers.py | 26 ++++----
python/mxnet/gluon/nn/conv_layers.py | 47 ++++++++++----
python/mxnet/gluon/rnn/rnn_layer.py | 2 +-
python/mxnet/gluon/utils.py | 25 ++++----
python/mxnet/image/detection.py | 17 +++--
python/mxnet/image/image.py | 44 +++++++++----
python/mxnet/ndarray/numpy_extension/__init__.py | 1 +
.../numpy_extension/image.py} | 8 +--
python/mxnet/numpy/__init__.py | 1 +
python/mxnet/numpy/arrayprint.py | 62 ++++++++++++++++++
python/mxnet/numpy/multiarray.py | 53 ++++++++++++++--
python/mxnet/numpy_extension/__init__.py | 2 +
.../__init__.py => numpy_extension/image.py} | 8 +--
python/mxnet/numpy_extension/random.py | 74 ++++++++++++++++++++++
python/mxnet/symbol/numpy_extension/__init__.py | 1 +
.../numpy_extension/{__init__.py => image.py} | 8 +--
src/io/image_io.cc | 3 +
src/ndarray/ndarray.cc | 2 +-
src/operator/contrib/multibox_detection.cc | 4 ++
src/operator/contrib/multibox_prior.cc | 3 +
src/operator/contrib/multibox_target.cc | 4 ++
src/operator/image/crop.cc | 1 +
src/operator/image/image_random.cc | 13 ++++
src/operator/image/resize.cc | 1 +
src/operator/leaky_relu.cc | 1 +
src/operator/nn/activation.cc | 2 +-
src/operator/nn/batch_norm.cc | 2 +-
src/operator/nn/convolution.cc | 2 +-
src/operator/nn/deconvolution.cc | 1 +
src/operator/nn/dropout.cc | 2 +-
src/operator/nn/fully_connected.cc | 2 +-
src/operator/nn/layer_norm.cc | 2 +-
src/operator/nn/pooling.cc | 2 +-
src/operator/numpy/np_elemwise_broadcast_op.cc | 11 +---
src/operator/rnn.cc | 2 +-
src/operator/roi_pooling.cc | 4 ++
src/operator/sequence_mask.cc | 2 +-
.../tensor/elemwise_binary_scalar_op_extended.cc | 3 +-
src/operator/tensor/elemwise_unary_op_basic.cc | 1 +
src/operator/tensor/indexing_op.cc | 2 +-
48 files changed, 451 insertions(+), 130 deletions(-)
diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 995a65c..ca8636c 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -21,7 +21,10 @@
def _np_reshape(a, newshape, order='C'):
- """Gives a new shape to an array without changing its data.
+ """
+ reshape(a, newshape, order='C')
+
+ Gives a new shape to an array without changing its data.
Parameters
----------
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index a4f75c6..545c2ea 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -744,6 +744,7 @@ _NP_OP_PREFIX = '_np_'
_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
_NP_EXT_OP_PREFIX = '_npx_'
+_NP_EXT_OP_SUBMODULE_LIST = ['_image_']
_NP_INTERNAL_OP_PREFIX = '_npi_'
@@ -784,7 +785,7 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
submodule_name_list = _NP_OP_SUBMODULE_LIST
elif np_module_name == 'numpy_extension':
op_name_prefix = _NP_EXT_OP_PREFIX
- submodule_name_list = []
+ submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST
elif np_module_name == 'numpy._internal':
op_name_prefix = _NP_INTERNAL_OP_PREFIX
submodule_name_list = []
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 09a2e2a..4516952 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -26,7 +26,6 @@ import warnings
import re
from collections import OrderedDict
-
from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
@@ -37,7 +36,7 @@ from .utils import _indent, _brief_print_list, HookHandle
from .utils import _check_same_symbol_type, _check_all_np_ndarrays
from .. import numpy_extension as _mx_npx
from .. import numpy as _mx_np, numpy_extension as _mx_npx
-from .. util import is_np_array
+from .. util import is_np_array, np_shape, np_array
class _BlockScope(object):
@@ -387,7 +386,25 @@ class Block(object):
<https://mxnet.incubator.apache.org/tutorials/gluon/save_load_params.html>`_
"""
if is_np_array():
- loaded = _mx_npx.load(filename)
+ # Loading may fail when the parameters were saved as legacy NDArrays (without NumPy
+ # semantics) but are loaded under NumPy semantics. Check the failure type and recover from it.
+ try:
+ loaded = _mx_npx.load(filename)
+ except MXNetError as e:
+ err_msg = str(e)
+ if 'is_np_shape' in err_msg:
+ # Loading failure due to parameters saved without numpy semantics.
+ # Temporarily disable numpy semantics and load parameters. After it's
+ # done, resume the numpy semantics. This is fine because the cases
+ # numpy ndarray covers are a superset of the legacy ndarray's.
+ with np_array(False):
+ with np_shape(False):
+ loaded_nds = ndarray.load(filename)
+ assert isinstance(loaded_nds, dict),\
+ 'expecting a dict type, got {}'.format(str(type(loaded_nds)))
+ loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds}
+ else:
+ raise ValueError(err_msg)
else:
loaded = ndarray.load(filename)
params = self._collect_params_with_prefix()
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index 362cc9e..bdcaff5 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -188,8 +188,9 @@ class CIFAR10(dataset._DownloadedDataset):
data = np.concatenate(data)
label = np.concatenate(label)
- self._data = nd.array(data, dtype=data.dtype)
- self._label = label
+ array_fn = _mx_np.array if is_np_array() else nd.array
+ self._data = array_fn(data, dtype=data.dtype)
+ self._label = array_fn(label, dtype=label.dtype) if is_np_array() else label
class CIFAR100(CIFAR10):
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 54af87e..ab8f8ab 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -23,7 +23,7 @@ from ...block import Block, HybridBlock
from ...nn import Sequential, HybridSequential
from .... import image
from ....base import numeric_types
-from ...utils import _adapt_np_array
+from ....util import is_np_array
class Compose(Sequential):
@@ -93,6 +93,8 @@ class Cast(HybridBlock):
self._dtype = dtype
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.cast(x, self._dtype)
@@ -134,8 +136,9 @@ class ToTensor(HybridBlock):
def __init__(self):
super(ToTensor, self).__init__()
- @_adapt_np_array
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.to_tensor(x)
@@ -189,6 +192,8 @@ class Normalize(HybridBlock):
self._std = std
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.normalize(x, self._mean, self._std)
@@ -370,8 +375,9 @@ class Resize(HybridBlock):
self._size = size
self._interpolation = interpolation
- @_adapt_np_array
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.resize(x, self._size, self._keep, self._interpolation)
class RandomFlipLeftRight(HybridBlock):
@@ -388,6 +394,8 @@ class RandomFlipLeftRight(HybridBlock):
super(RandomFlipLeftRight, self).__init__()
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_flip_left_right(x)
@@ -405,6 +413,8 @@ class RandomFlipTopBottom(HybridBlock):
super(RandomFlipTopBottom, self).__init__()
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_flip_top_bottom(x)
@@ -430,6 +440,8 @@ class RandomBrightness(HybridBlock):
self._args = (max(0, 1-brightness), 1+brightness)
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_brightness(x, *self._args)
@@ -455,6 +467,8 @@ class RandomContrast(HybridBlock):
self._args = (max(0, 1-contrast), 1+contrast)
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_contrast(x, *self._args)
@@ -480,6 +494,8 @@ class RandomSaturation(HybridBlock):
self._args = (max(0, 1-saturation), 1+saturation)
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_saturation(x, *self._args)
@@ -505,6 +521,8 @@ class RandomHue(HybridBlock):
self._args = (max(0, 1-hue), 1+hue)
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_hue(x, *self._args)
@@ -539,6 +557,8 @@ class RandomColorJitter(HybridBlock):
self._args = (brightness, contrast, saturation, hue)
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_color_jitter(x, *self._args)
@@ -562,4 +582,6 @@ class RandomLighting(HybridBlock):
self._alpha = alpha
def hybrid_forward(self, F, x):
+ if is_np_array():
+ F = F.npx
return F.image.random_lighting(x, self._alpha)
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 6c66d4c..d634e79 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -258,30 +258,47 @@ class SigmoidBinaryCrossEntropyLoss(Loss):
weight, batch_axis, **kwargs)
self._from_sigmoid = from_sigmoid
- @_adapt_np_array
def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
label = _reshape_like(F, label, pred)
+ if is_np_array():
+ relu_fn = F.npx.relu
+ act_fn = F.npx.activation
+ abs_fn = F.np.abs
+ mul_fn = F.np.multiply
+ log_fn = F.np.log
+ else:
+ relu_fn = F.relu
+ act_fn = F.Activation
+ abs_fn = F.abs
+ mul_fn = F.broadcast_mul
+ log_fn = F.log
if not self._from_sigmoid:
if pos_weight is None:
# We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
- loss = F.relu(pred) - pred * label + \
- F.Activation(-F.abs(pred), act_type='softrelu')
+ loss = relu_fn(pred) - pred * label + \
+ act_fn(-abs_fn(pred), act_type='softrelu')
else:
# We use the stable formula: x - x * z + (1 + z * pos_weight - z) * \
# (log(1 + exp(-abs(x))) + max(-x, 0))
- log_weight = 1 + F.broadcast_mul(pos_weight - 1, label)
- loss = pred - pred * label + log_weight * \
- (F.Activation(-F.abs(pred), act_type='softrelu') + F.relu(-pred))
+ log_weight = 1 + mul_fn(pos_weight - 1, label)
+ loss = pred - pred * label + log_weight *\
+ (act_fn(-abs_fn(pred), act_type='softrelu') + relu_fn(-pred))
else:
eps = 1e-12
if pos_weight is None:
- loss = -(F.log(pred + eps) * label
- + F.log(1. - pred + eps) * (1. - label))
+ loss = -(log_fn(pred + eps) * label
+ + log_fn(1. - pred + eps) * (1. - label))
else:
- loss = -(F.broadcast_mul(F.log(pred + eps) * label, pos_weight)
- + F.log(1. - pred + eps) * (1. - label))
+ loss = -(mul_fn(log_fn(pred + eps) * label, pos_weight)
+ + log_fn(1. - pred + eps) * (1. - label))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
- return F.mean(loss, axis=self._batch_axis, exclude=True)
+ if is_np_array():
+ if F is ndarray:
+ return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
+ else:
+ return F.npx.batch_flatten(loss).mean(axis=1)
+ else:
+ return F.mean(loss, axis=self._batch_axis, exclude=True)
SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py
index 48390de..50a65ec 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -33,6 +33,7 @@ from ....context import cpu
from ...block import HybridBlock
from ... import nn
from .... import base
+from .... util import is_np_array
# Helpers
def _conv3x3(channels, stride, in_channels):
@@ -81,7 +82,8 @@ class BasicBlockV1(HybridBlock):
if self.downsample:
residual = self.downsample(residual)
- x = F.Activation(residual+x, act_type='relu')
+ act = F.npx.activation if is_np_array() else F.Activation
+ x = act(residual+x, act_type='relu')
return x
@@ -129,7 +131,8 @@ class BottleneckV1(HybridBlock):
if self.downsample:
residual = self.downsample(residual)
- x = F.Activation(x + residual, act_type='relu')
+ act = F.npx.activation if is_np_array() else F.Activation
+ x = act(x + residual, act_type='relu')
return x
@@ -165,13 +168,14 @@ class BasicBlockV2(HybridBlock):
def hybrid_forward(self, F, x):
residual = x
x = self.bn1(x)
- x = F.Activation(x, act_type='relu')
+ act = F.npx.activation if is_np_array() else F.Activation
+ x = act(x, act_type='relu')
if self.downsample:
residual = self.downsample(x)
x = self.conv1(x)
x = self.bn2(x)
- x = F.Activation(x, act_type='relu')
+ x = act(x, act_type='relu')
x = self.conv2(x)
return x + residual
@@ -211,17 +215,18 @@ class BottleneckV2(HybridBlock):
def hybrid_forward(self, F, x):
residual = x
x = self.bn1(x)
- x = F.Activation(x, act_type='relu')
+ act = F.npx.activation if is_np_array() else F.Activation
+ x = act(x, act_type='relu')
if self.downsample:
residual = self.downsample(x)
x = self.conv1(x)
x = self.bn2(x)
- x = F.Activation(x, act_type='relu')
+ x = act(x, act_type='relu')
x = self.conv2(x)
x = self.bn3(x)
- x = F.Activation(x, act_type='relu')
+ x = act(x, act_type='relu')
x = self.conv3(x)
return x + residual
diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index 6e0e7ca..a3baae0 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -49,9 +49,8 @@ class Activation(HybridBlock):
return self._act_type
def hybrid_forward(self, F, x):
- if is_np_array():
- F = F.npx
- return F.Activation(x, act_type=self._act_type, name='fwd')
+ act = F.npx.activation if is_np_array() else F.Activation
+ return act(x, act_type=self._act_type, name='fwd')
def __repr__(self):
s = '{name}({_act_type})'
@@ -91,7 +90,8 @@ class LeakyReLU(HybridBlock):
self._alpha = alpha
def hybrid_forward(self, F, x):
- return F.LeakyReLU(x, act_type='leaky', slope=self._alpha, name='fwd')
+ leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
+ return leaky_relu(x, act_type='leaky', slope=self._alpha, name='fwd')
def __repr__(self):
s = '{name}({alpha})'
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 3c43ac3..a726727 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -218,10 +218,9 @@ class Dense(HybridBlock):
self.act = None
def hybrid_forward(self, F, x, weight, bias=None):
- if is_np_array():
- F = F.npx
- act = F.FullyConnected(x, weight, bias, no_bias=bias is None, num_hidden=self._units,
- flatten=self._flatten, name='fwd')
+ fc = F.npx.fully_connected if is_np_array() else F.FullyConnected
+ act = fc(x, weight, bias, no_bias=bias is None, num_hidden=self._units,
+ flatten=self._flatten, name='fwd')
if self.act is not None:
act = self.act(act)
return act
@@ -266,7 +265,7 @@ class Dropout(HybridBlock):
def hybrid_forward(self, F, x):
if self._rate > 0:
- dropout = F.npx.Dropout if is_np_array() else F.Dropout
+ dropout = F.npx.dropout if is_np_array() else F.Dropout
return dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
else:
copy = F.np.copy if is_np_array() else F.identity
@@ -361,10 +360,9 @@ class BatchNorm(HybridBlock):
super(BatchNorm, self).cast(dtype)
def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
- if is_np_array():
- F = F.npx
- return F.BatchNorm(x, gamma, beta, running_mean, running_var,
- name='fwd', **self._kwargs)
+ batch_norm = F.npx.batch_norm if is_np_array() else F.BatchNorm
+ return batch_norm(x, gamma, beta, running_mean, running_var,
+ name='fwd', **self._kwargs)
def __repr__(self):
s = '{name}({content}'
@@ -416,9 +414,8 @@ class Embedding(HybridBlock):
allow_deferred_init=True, grad_stype=grad_stype)
def hybrid_forward(self, F, x, weight):
- if is_np_array():
- F = F.npx
- return F.Embedding(x, weight, name='fwd', **self._kwargs)
+ embedding = F.npx.embedding if is_np_array() else F.Embedding
+ return embedding(x, weight, name='fwd', **self._kwargs)
def __repr__(self):
s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
@@ -614,9 +611,8 @@ class LayerNorm(HybridBlock):
allow_deferred_init=True)
def hybrid_forward(self, F, data, gamma, beta):
- if is_np_array():
- F = F.npx
- return F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
+ layer_norm = F.npx.layer_norm if is_np_array() else F.LayerNorm
+ return layer_norm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
def __repr__(self):
s = '{name}({content}'
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 3e8516b..4682684 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -34,8 +34,13 @@ from ...util import is_np_array
def _infer_weight_shape(op_name, data_shape, kwargs):
- op = getattr(symbol, op_name)
- sym = op(symbol.var('data', shape=data_shape), **kwargs)
+ data = symbol.var('data', shape=data_shape)
+ if is_np_array():
+ op = getattr(symbol.npx, op_name)
+ data = data.as_np_ndarray()
+ else:
+ op = getattr(symbol, op_name)
+ sym = op(data, **kwargs)
return sym.infer_shape_partial()[0]
@@ -242,9 +247,13 @@ class Conv1D(_Conv):
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,)
assert len(kernel_size) == 1, "kernel_size must be a number or a list of 1 ints"
+ op_name = kwargs.pop('op_name', 'Convolution')
+ if is_np_array():
+ op_name = 'convolution'
super(Conv1D, self).__init__(
channels, kernel_size, strides, padding, dilation, groups, layout,
- in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs)
+ in_channels, activation, use_bias, weight_initializer, bias_initializer,
+ op_name, **kwargs)
class Conv2D(_Conv):
@@ -322,9 +331,13 @@ class Conv2D(_Conv):
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,)*2
assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
+ op_name = kwargs.pop('op_name', 'Convolution')
+ if is_np_array():
+ op_name = 'convolution'
super(Conv2D, self).__init__(
channels, kernel_size, strides, padding, dilation, groups, layout,
- in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs)
+ in_channels, activation, use_bias, weight_initializer, bias_initializer,
+ op_name, **kwargs)
class Conv3D(_Conv):
@@ -403,9 +416,13 @@ class Conv3D(_Conv):
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,)*3
assert len(kernel_size) == 3, "kernel_size must be a number or a list of 3 ints"
+ op_name = kwargs.pop('op_name', 'Convolution')
+ if is_np_array():
+ op_name = 'convolution'
super(Conv3D, self).__init__(
channels, kernel_size, strides, padding, dilation, groups, layout,
- in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs)
+ in_channels, activation, use_bias, weight_initializer, bias_initializer,
+ op_name, **kwargs)
class Conv1DTranspose(_Conv):
@@ -487,10 +504,13 @@ class Conv1DTranspose(_Conv):
output_padding = (output_padding,)
assert len(kernel_size) == 1, "kernel_size must be a number or a list of 1 ints"
assert len(output_padding) == 1, "output_padding must be a number or a list of 1 ints"
+ op_name = kwargs.pop('op_name', 'Deconvolution')
+ if is_np_array():
+ op_name = 'deconvolution'
super(Conv1DTranspose, self).__init__(
channels, kernel_size, strides, padding, dilation, groups, layout,
in_channels, activation, use_bias, weight_initializer,
- bias_initializer, op_name='Deconvolution', adj=output_padding, **kwargs)
+ bias_initializer, op_name=op_name, adj=output_padding, **kwargs)
self.outpad = output_padding
@@ -578,10 +598,13 @@ class Conv2DTranspose(_Conv):
output_padding = (output_padding,)*2
assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
assert len(output_padding) == 2, "output_padding must be a number or a list of 2 ints"
+ op_name = kwargs.pop('op_name', 'Deconvolution')
+ if is_np_array():
+ op_name = 'deconvolution'
super(Conv2DTranspose, self).__init__(
channels, kernel_size, strides, padding, dilation, groups, layout,
in_channels, activation, use_bias, weight_initializer,
- bias_initializer, op_name='Deconvolution', adj=output_padding, **kwargs)
+ bias_initializer, op_name=op_name, adj=output_padding, **kwargs)
self.outpad = output_padding
@@ -670,10 +693,13 @@ class Conv3DTranspose(_Conv):
output_padding = (output_padding,)*3
assert len(kernel_size) == 3, "kernel_size must be a number or a list of 3 ints"
assert len(output_padding) == 3, "output_padding must be a number or a list of 3 ints"
+ op_name = kwargs.pop('op_name', 'Deconvolution')
+ if is_np_array():
+ op_name = 'deconvolution'
super(Conv3DTranspose, self).__init__(
channels, kernel_size, strides, padding, dilation, groups, layout,
in_channels, activation, use_bias, weight_initializer, bias_initializer,
- op_name='Deconvolution', adj=output_padding, **kwargs)
+ op_name=op_name, adj=output_padding, **kwargs)
self.outpad = output_padding
@@ -700,9 +726,8 @@ class _Pooling(HybridBlock):
return 'pool'
def hybrid_forward(self, F, x):
- if is_np_array():
- F = F.npx
- return F.Pooling(x, name='fwd', **self._kwargs)
+ pooling = F.npx.pooling if is_np_array() else F.Pooling
+ return pooling(x, name='fwd', **self._kwargs)
def __repr__(self):
s = '{name}(size={kernel}, stride={stride}, padding={pad}, ceil_mode={ceil_mode}'
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 1104b1e..9807c5e 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -284,7 +284,7 @@ class _RNNLayer(HybridBlock):
else:
rnn_args = states
- rnn_fn = F.npx.RNN if is_np_array() else F.RNN
+ rnn_fn = F.npx.rnn if is_np_array() else F.RNN
rnn = rnn_fn(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
state_size=self._hidden_size, projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 542a3c6..2822c70 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -86,12 +86,19 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
for i in range(num_slice)]
elif even_split:
- slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
+ if is_np_array():
+ slices = _mx_np.split(data, indices_or_sections=num_slice, axis=batch_axis)
+ else:
+ slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
else:
- slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
- if i < num_slice - 1 else
- ndarray.slice_axis(data, batch_axis, i*step, size)
- for i in range(num_slice)]
+ if is_np_array():
+ indices = [step * i for i in range(1, num_slice)]
+ slices = _mx_np.split(data, indices_or_sections=indices, axis=batch_axis)
+ else:
+ slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
+ if i < num_slice - 1 else
+ ndarray.slice_axis(data, batch_axis, i*step, size)
+ for i in range(num_slice)]
return slices
@@ -101,7 +108,7 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
Parameters
----------
- data : NDArray
+ data : NDArray or ndarray
A batch of data.
ctx_list : list of Context
A list of Contexts.
@@ -112,7 +119,7 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
Returns
-------
- list of NDArray
+ list of NDArrays or ndarrays
Each corresponds to a context in `ctx_list`.
"""
array_fn = _mx_np.array if is_np_array() else ndarray.array
@@ -121,11 +128,7 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
if len(ctx_list) == 1:
return [data.as_in_context(ctx_list[0])]
- # TODO(junwu): temp solution for supporting np.ndarray
- # rewrite this using np ops
slices = split_data(data, len(ctx_list), batch_axis, even_split)
- if is_np_array():
- slices = [i.as_np_ndarray() for i in slices]
return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]
diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index a70e572..f3b551b 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=unused-import
+# pylint: disable=unused-import, too-many-lines
"""Read images and perform augmentations for object detection."""
from __future__ import absolute_import, print_function
@@ -34,6 +34,8 @@ from .. import io
from .image import RandomOrderAug, ColorJitterAug, LightingAug, ColorNormalizeAug
from .image import ResizeAug, ForceResizeAug, CastAug, HueJitterAug, RandomGrayAug
from .image import fixed_crop, ImageIter, Augmenter
+from ..util import is_np_array
+from .. import numpy as _mx_np # pylint: disable=reimported
class DetAugmenter(object):
@@ -762,6 +764,7 @@ class ImageDetIter(ImageIter):
"""Override the helper function for batchifying data"""
i = start
batch_size = self.batch_size
+ array_fn = _mx_np.array if is_np_array() else nd.array
try:
while i < batch_size:
label, s = self.next_sample()
@@ -778,7 +781,7 @@ class ImageDetIter(ImageIter):
assert i < batch_size, 'Batch size must be multiples of augmenter output length'
batch_data[i] = self.postprocess_data(datum)
num_object = label.shape[0]
- batch_label[i][0:num_object] = nd.array(label)
+ batch_label[i][0:num_object] = array_fn(label)
if num_object < batch_label[i].shape[0]:
batch_label[i][num_object:] = -1
i += 1
@@ -801,8 +804,14 @@ class ImageDetIter(ImageIter):
batch_label = self._cache_label
i = self._cache_idx
else:
- batch_data = nd.zeros((batch_size, c, h, w))
- batch_label = nd.empty(self.provide_label[0][1])
+ if is_np_array():
+ zeros_fn = _mx_np.zeros
+ empty_fn = _mx_np.empty
+ else:
+ zeros_fn = nd.zeros
+ empty_fn = nd.empty
+ batch_data = zeros_fn((batch_size, c, h, w))
+ batch_label = empty_fn(self.provide_label[0][1])
batch_label[:] = -1
i = self._batchify(batch_data, batch_label)
# calculate the padding
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index a142282..c48e2df 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -28,6 +28,7 @@ import logging
import json
import warnings
import numpy as np
+from .. import numpy as _mx_np # pylint: disable=reimported
try:
@@ -40,6 +41,8 @@ from .. import ndarray as nd
from ..ndarray import _internal
from .. import io
from .. import recordio
+from .. util import is_np_array
+from ..ndarray.numpy import _internal as _npi
def imread(filename, *args, **kwargs):
@@ -80,7 +83,11 @@ def imread(filename, *args, **kwargs):
>>> mx.img.imread("flower.jpg", to_rgb=0)
<NDArray 224x224x3 @cpu(0)>
"""
- return _internal._cvimread(filename, *args, **kwargs)
+ if is_np_array():
+ read_fn = _npi.cvimread
+ else:
+ read_fn = _internal._cvimread
+ return read_fn(filename, *args, **kwargs)
def imresize(src, w, h, *args, **kwargs):
@@ -137,7 +144,8 @@ def imresize(src, w, h, *args, **kwargs):
>>> new_image
<NDArray 240x360x3 @cpu(0)>
"""
- return _internal._cvimresize(src, w, h, *args, **kwargs)
+ resize_fn = _npi.cvimresize if is_np_array() else _internal._cvimresize
+ return resize_fn(src, w, h, *args, **kwargs)
def imdecode(buf, *args, **kwargs):
@@ -193,9 +201,11 @@ def imdecode(buf, *args, **kwargs):
if sys.version_info[0] == 3 and not isinstance(buf, (bytes, bytearray, np.ndarray)):
raise ValueError('buf must be of type bytes, bytearray or numpy.ndarray,'
'if you would like to input type str, please convert to bytes')
- buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)
+ array_fn = _mx_np.array if is_np_array() else nd.array
+ buf = array_fn(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)
- return _internal._cvimdecode(buf, *args, **kwargs)
+ cvimdecode = _npi.cvimdecode if is_np_array() else _internal._cvimdecode
+ return cvimdecode(buf, *args, **kwargs)
def scale_down(src_size, size):
@@ -428,7 +438,7 @@ def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
NDArray
An `NDArray` containing the cropped image.
"""
- out = nd.slice(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2])))
+ out = src[y0:y0+h, x0:x0+w]
if size is not None and (w, h) != size:
sizes = (h, w, size[1], size[0])
out = imresize(out, *size, interp=_get_interp_method(interp, sizes))
@@ -1206,6 +1216,7 @@ class ImageIter(io.DataIter):
else:
self.imgrec = None
+ array_fn = _mx_np.array if is_np_array() else nd.array
if path_imglist:
logging.info('%s: loading image list %s...', class_name, path_imglist)
with open(path_imglist) as fin:
@@ -1213,7 +1224,7 @@ class ImageIter(io.DataIter):
imgkeys = []
for line in iter(fin.readline, ''):
line = line.strip().split('\t')
- label = nd.array(line[1:-1], dtype=dtype)
+ label = array_fn(line[1:-1], dtype=dtype)
key = int(line[0])
imglist[key] = (label, line[-1])
imgkeys.append(key)
@@ -1227,11 +1238,11 @@ class ImageIter(io.DataIter):
key = str(index) # pylint: disable=redefined-variable-type
index += 1
if len(img) > 2:
- label = nd.array(img[:-1], dtype=dtype)
+ label = array_fn(img[:-1], dtype=dtype)
elif isinstance(img[0], numeric_types):
- label = nd.array([img[0]], dtype=dtype)
+ label = array_fn([img[0]], dtype=dtype)
else:
- label = nd.array(img[0], dtype=dtype)
+ label = array_fn(img[0], dtype=dtype)
result[key] = (label, img[-1])
imgkeys.append(str(key))
self.imglist = result
@@ -1367,8 +1378,14 @@ class ImageIter(io.DataIter):
i = self._cache_idx
# clear the cache data
else:
- batch_data = nd.zeros((batch_size, c, h, w))
- batch_label = nd.empty(self.provide_label[0][1])
+ if is_np_array():
+ zeros_fn = _mx_np.zeros
+ empty_fn = _mx_np.empty
+ else:
+ zeros_fn = nd.zeros
+ empty_fn = nd.empty
+ batch_data = zeros_fn((batch_size, c, h, w))
+ batch_label = empty_fn(self.provide_label[0][1])
i = self._batchify(batch_data, batch_label)
# calculate the padding
pad = batch_size - i
@@ -1445,4 +1462,7 @@ class ImageIter(io.DataIter):
def postprocess_data(self, datum):
"""Final postprocessing step before image is loaded into the batch."""
- return nd.transpose(datum, axes=(2, 0, 1))
+ if is_np_array():
+ return datum.transpose(2, 0, 1)
+ else:
+ return nd.transpose(datum, axes=(2, 0, 1))
diff --git a/python/mxnet/ndarray/numpy_extension/__init__.py b/python/mxnet/ndarray/numpy_extension/__init__.py
index a718274..5be34ac 100644
--- a/python/mxnet/ndarray/numpy_extension/__init__.py
+++ b/python/mxnet/ndarray/numpy_extension/__init__.py
@@ -18,6 +18,7 @@
"""Module for the ops not belonging to the official numpy package."""
from . import _op
+from . import image
from . import _register
from ._op import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/symbol/numpy_extension/__init__.py b/python/mxnet/ndarray/numpy_extension/image.py
similarity index 80%
copy from python/mxnet/symbol/numpy_extension/__init__.py
copy to python/mxnet/ndarray/numpy_extension/image.py
index a718274..b3bd27f 100644
--- a/python/mxnet/symbol/numpy_extension/__init__.py
+++ b/python/mxnet/ndarray/numpy_extension/image.py
@@ -15,10 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""Module for the ops not belonging to the official numpy package."""
+"""Image pre-processing operators."""
-from . import _op
-from . import _register
-from ._op import * # pylint: disable=wildcard-import
-
-__all__ = _op.__all__
+__all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index 7a9a2f6..1994148 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -29,5 +29,6 @@ from .utils import * # pylint: disable=wildcard-import
from .function_base import * # pylint: disable=wildcard-import
from .stride_tricks import * # pylint: disable=wildcard-import
from .io import * # pylint: disable=wildcard-import
+from .arrayprint import * # pylint: disable=wildcard-import
__all__ = []
diff --git a/python/mxnet/numpy/arrayprint.py b/python/mxnet/numpy/arrayprint.py
new file mode 100644
index 0000000..9be7faf
--- /dev/null
+++ b/python/mxnet/numpy/arrayprint.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""ndarray print format controller."""
+
+from __future__ import absolute_import, print_function
+
+import numpy as onp
+from ..util import set_module
+
+__all__ = ['set_printoptions']
+
+
+@set_module('mxnet.numpy')
+def set_printoptions(precision=None, threshold=None, **kwarg):
+ """
+ Set printing options.
+
+ These options determine the way floating point numbers and arrays are displayed.
+
+ Parameters
+ ----------
+ precision : int or None, optional
+ Number of digits of precision for floating point output (default 8).
+ May be `None` if `floatmode` is not `fixed`, to print as many digits as
+ necessary to uniquely specify the value.
+ threshold : int, optional
+ Total number of array elements which trigger summarization
+ rather than full repr (default 1000).
+
+ Examples
+ --------
+ Floating point precision can be set:
+
+ >>> np.set_printoptions(precision=4)
+ >>> print(np.array([1.123456789]))
+ [ 1.1235]
+
+ Long arrays can be summarised:
+
+ >>> np.set_printoptions(threshold=5)
+ >>> print(np.arange(10))
+ [0. 1. 2. ... 7. 8. 9.]
+ """
+ if kwarg:
+ raise NotImplementedError('mxnet.numpy.set_printoptions only supports parameters'
+ ' precision and threshold for now.')
+ onp.set_printoptions(precision=precision, threshold=threshold, **kwarg)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 2a37af7..9d9966b 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -423,8 +423,53 @@ class ndarray(NDArray):
return self
def __repr__(self):
- """Returns a string representation of the array."""
+ """
+ Returns a string representation of the array. The dtype of the ndarray will not
+ be appended to the string if it is `float32`. The context of the ndarray will
+ be appended for devices other than CPU.
+
+ Examples
+ --------
+ >>> from mxnet import np, npx
+ >>> a = np.random.uniform(size=(2, 3))
+ >>> a
+ array([[0.5488135 , 0.5928446 , 0.71518934],
+ [0.84426576, 0.60276335, 0.8579456 ]])
+ >>> print(a)
+ [[0.5488135 0.5928446 0.71518934]
+ [0.84426576 0.60276335 0.8579456 ]]
+ >>> a.dtype
+ <class 'numpy.float32'>
+ >>> b = a.astype(np.float64)
+ >>> b
+ array([[0.54881352, 0.59284461, 0.71518934],
+ [0.84426576, 0.60276335, 0.85794562]], dtype=float64)
+ >>> print(b)
+ [[0.54881352 0.59284461 0.71518934]
+ [0.84426576 0.60276335 0.85794562]]
+ >>> b.dtype
+ <class 'numpy.float64'>
+ >>> c = a.copyto(npx.gpu(0))
+ >>> c
+ array([[0.5488135 , 0.5928446 , 0.71518934],
+ [0.84426576, 0.60276335, 0.8579456 ]], ctx=gpu(0))
+ >>> print(c)
+ [[0.5488135 0.5928446 0.71518934]
+ [0.84426576 0.60276335 0.8579456 ]] @gpu(0)
+ >>> d = b.copyto(npx.gpu(0))
+ >>> d
+ array([[0.54881352, 0.59284461, 0.71518934],
+ [0.84426576, 0.60276335, 0.85794562]], dtype=float64, ctx=gpu(0))
+ >>> print(d)
+ [[0.54881352 0.59284461 0.71518934]
+ [0.84426576 0.60276335 0.85794562]] @gpu(0)
+ """
array_str = self.asnumpy().__repr__()
+ dtype = self.dtype
+ if dtype == _np.float64:
+ array_str = array_str[:-1] + ', dtype=float64)'
+ elif dtype == _np.float32:
+ array_str = array_str[:array_str.rindex(', dtype=')] + ')'
context = self.context
if context.device_type == 'cpu':
return array_str
@@ -814,11 +859,7 @@ class ndarray(NDArray):
raise AttributeError('mxnet.numpy.ndarray object has no attribute tile')
def transpose(self, *axes): # pylint: disable=arguments-differ
- """Convenience fluent method for :py:func:`transpose`.
-
- The arguments are the same as for :py:func:`transpose`, with
- this array as data.
- """
+ """Permute the dimensions of an array."""
return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None)
def flip(self, *args, **kwargs):
diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py
index d80f0cc..6e89c00 100644
--- a/python/mxnet/numpy_extension/__init__.py
+++ b/python/mxnet/numpy_extension/__init__.py
@@ -21,6 +21,7 @@
from __future__ import absolute_import
from . import _op
+from . import image
from . import _register
from ._op import * # pylint: disable=wildcard-import
from ..context import * # pylint: disable=wildcard-import
@@ -30,5 +31,6 @@ from ..util import use_np_array, np_array, is_np_array
from ..util import set_np, use_np, reset_np
from ..ndarray import waitall
from .utils import * # pylint: disable=wildcard-import
+from .random import * # pylint: disable=wildcard-import
__all__ = []
diff --git a/python/mxnet/symbol/numpy_extension/__init__.py b/python/mxnet/numpy_extension/image.py
similarity index 80%
copy from python/mxnet/symbol/numpy_extension/__init__.py
copy to python/mxnet/numpy_extension/image.py
index a718274..b3bd27f 100644
--- a/python/mxnet/symbol/numpy_extension/__init__.py
+++ b/python/mxnet/numpy_extension/image.py
@@ -15,10 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""Module for the ops not belonging to the official numpy package."""
+"""Image pre-processing operators."""
-from . import _op
-from . import _register
-from ._op import * # pylint: disable=wildcard-import
-
-__all__ = _op.__all__
+__all__ = []
diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py
new file mode 100644
index 0000000..bfe2270
--- /dev/null
+++ b/python/mxnet/numpy_extension/random.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Namespace for ops used in imperative programming."""
+
+from __future__ import absolute_import
+from .. import random as _mx_rand
+
+
+__all__ = ['seed']
+
+
+def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name
+ """Seeds the random number generators in MXNet.
+
+ This affects the behavior of modules in MXNet that use random number generators,
+ like the dropout operator and `ndarray`'s random sampling operators.
+
+ Parameters
+ ----------
+ seed : int
+ The random number seed.
+
+ ctx : Context
+ The device context of the generator. The default is "all" which means seeding random
+ number generators of all devices.
+
+ Notes
+ -----
+ Random number generators in MXNet are device specific.
+ `mx.random.seed(seed_state)` sets the state of each generator using `seed_state` and the
+ device id. Therefore, random numbers generated from different devices can be different
+ even if they are seeded using the same seed.
+
+ To produce identical random number sequences independent of the device id,
+ set the optional `ctx` argument. This produces the same sequence of random numbers independent
+ of the device id, but the sequence can be different on different kinds of devices as MXNet's
+ random number generators for CPU and GPU use different algorithms.
+
+ Example
+ -------
+ >>> from mxnet import np, npx
+ >>> npx.set_np()
+ >>> npx.random.seed(0)
+ >>> np.random.uniform()
+ array(0.5488135)
+ >>> npx.random.seed(128)
+ >>> np.random.uniform()
+ array(0.03812965)
+ >>> npx.random.seed(128)
+ >>> np.random.uniform()
+ array(0.03812965)
+ >>> npx.random.seed(128)
+ >>> np.random.uniform(ctx=npx.gpu(0))
+ array(0.9894903, ctx=gpu(0))
+ >>> npx.random.seed(128)
+ >>> np.random.uniform(ctx=npx.gpu(0))
+ array(0.9894903, ctx=gpu(0))
+ """
+ _mx_rand.seed(seed_state=seed, ctx=ctx)
diff --git a/python/mxnet/symbol/numpy_extension/__init__.py b/python/mxnet/symbol/numpy_extension/__init__.py
index a718274..5be34ac 100644
--- a/python/mxnet/symbol/numpy_extension/__init__.py
+++ b/python/mxnet/symbol/numpy_extension/__init__.py
@@ -18,6 +18,7 @@
"""Module for the ops not belonging to the official numpy package."""
from . import _op
+from . import image
from . import _register
from ._op import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/symbol/numpy_extension/__init__.py b/python/mxnet/symbol/numpy_extension/image.py
similarity index 80%
copy from python/mxnet/symbol/numpy_extension/__init__.py
copy to python/mxnet/symbol/numpy_extension/image.py
index a718274..b3bd27f 100644
--- a/python/mxnet/symbol/numpy_extension/__init__.py
+++ b/python/mxnet/symbol/numpy_extension/image.py
@@ -15,10 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""Module for the ops not belonging to the official numpy package."""
+"""Image pre-processing operators."""
-from . import _op
-from . import _register
-from ._op import * # pylint: disable=wildcard-import
-
-__all__ = _op.__all__
+__all__ = []
diff --git a/src/io/image_io.cc b/src/io/image_io.cc
index c035799..db9ac76 100644
--- a/src/io/image_io.cc
+++ b/src/io/image_io.cc
@@ -357,6 +357,7 @@ inline void copyMakeBorder(const nnvm::NodeAttrs& attrs,
}
NNVM_REGISTER_OP(_cvimdecode)
+.add_alias("_npi_cvimdecode")
.describe("Decode image with OpenCV. \n"
"Note: return image in RGB by default, "
"instead of OpenCV's default BGR.")
@@ -368,6 +369,7 @@ NNVM_REGISTER_OP(_cvimdecode)
.add_arguments(ImdecodeParam::__FIELDS__());
NNVM_REGISTER_OP(_cvimread)
+.add_alias("_npi_cvimread")
.describe("Read and decode image with OpenCV. \n"
"Note: return image in RGB by default, "
"instead of OpenCV's default BGR.")
@@ -378,6 +380,7 @@ NNVM_REGISTER_OP(_cvimread)
.add_arguments(ImreadParam::__FIELDS__());
NNVM_REGISTER_OP(_cvimresize)
+.add_alias("_npi_cvimresize")
.describe("Resize image with OpenCV. \n")
.set_num_inputs(1)
.set_num_outputs(1)
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index f10f5db..d8cb931 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1728,7 +1728,7 @@ bool NDArray::Load(dmlc::Stream *strm) {
CHECK(!Imperative::Get()->is_np_shape())
<< "ndarray was not saved in np shape semantics, but being loaded in np shape semantics."
" Please turn off np shape semantics in Python using `with np_shape(False)`"
- " to scope of the code of loading the ndarray.";
+ " to scope the code of loading the ndarray.";
}
if (magic != NDARRAY_V2_MAGIC && magic != NDARRAY_V3_MAGIC) {
return LegacyLoad(strm, magic);
diff --git a/src/operator/contrib/multibox_detection.cc b/src/operator/contrib/multibox_detection.cc
index 37bb5a5..cb2dfe3 100644
--- a/src/operator/contrib/multibox_detection.cc
+++ b/src/operator/contrib/multibox_detection.cc
@@ -220,5 +220,9 @@ MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxDetection, MultiBoxDetectionProp)
.add_argument("loc_pred", "NDArray-or-Symbol", "Location regression predictions.")
.add_argument("anchor", "NDArray-or-Symbol", "Multibox prior anchor boxes")
.add_arguments(MultiBoxDetectionParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_contrib_MultiBoxDetection)
+.add_alias("_npx_multibox_detection");
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/multibox_prior.cc b/src/operator/contrib/multibox_prior.cc
index 2ad173a2..66fd2c1 100644
--- a/src/operator/contrib/multibox_prior.cc
+++ b/src/operator/contrib/multibox_prior.cc
@@ -100,5 +100,8 @@ MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxPrior, MultiBoxPriorProp)
.add_arguments(MultiBoxPriorParam::__FIELDS__())
.describe("Generate prior(anchor) boxes from data, sizes and ratios.");
+NNVM_REGISTER_OP(_contrib_MultiBoxPrior)
+.add_alias("_npx_multibox_prior");
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/multibox_target.cc b/src/operator/contrib/multibox_target.cc
index a1808c5..feab397 100644
--- a/src/operator/contrib/multibox_target.cc
+++ b/src/operator/contrib/multibox_target.cc
@@ -307,5 +307,9 @@ MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxTarget, MultiBoxTargetProp)
.add_argument("label", "NDArray-or-Symbol", "Object detection labels.")
.add_argument("cls_pred", "NDArray-or-Symbol", "Class predictions.")
.add_arguments(MultiBoxTargetParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_contrib_MultiBoxTarget)
+.add_alias("_npx_multibox_target");
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc
index 52d2f11..6067f89 100644
--- a/src/operator/image/crop.cc
+++ b/src/operator/image/crop.cc
@@ -35,6 +35,7 @@ namespace image {
DMLC_REGISTER_PARAMETER(CropParam);
NNVM_REGISTER_OP(_image_crop)
+.add_alias("_npx__image_crop")
.describe(R"code(Crop an image NDArray of shape (H x W x C) or (N x H x W x C)
to the given size.
Example:
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 34f4cb4..0c4603e 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -39,6 +39,7 @@ DMLC_REGISTER_PARAMETER(RandomLightingParam);
DMLC_REGISTER_PARAMETER(RandomColorJitterParam);
NNVM_REGISTER_OP(_image_to_tensor)
+.add_alias("_npx__image_to_tensor")
.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C)
with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W)
with values in the range [0, 1]
@@ -102,6 +103,7 @@ Example:
.add_argument("data", "NDArray-or-Symbol", "Input ndarray");
NNVM_REGISTER_OP(_image_normalize)
+.add_alias("_npx__image_normalize")
.describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
standard deviation.
@@ -189,28 +191,34 @@ NNVM_REGISTER_OP(_backward_image_normalize)
.set_attr<FCompute>("FCompute<cpu>", NormalizeOpBackward<cpu>);
MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_left_right)
+.add_alias("_npx__image_flip_left_right")
.describe(R"code()code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", FlipLeftRight);
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_flip_left_right)
+.add_alias("_npx__image_random_flip_left_right")
.describe(R"code()code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", RandomFlipLeftRight);
MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_top_bottom)
+.add_alias("_npx__image_flip_top_bottom")
.describe(R"code()code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", FlipTopBottom);
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_flip_top_bottom)
+.add_alias("_npx__image_random_flip_top_bottom")
.describe(R"code()code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", RandomFlipTopBottom);
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_brightness)
+.add_alias("_npx__image_random_brightness")
.describe(R"code()code" ADD_FILELINE)
.set_attr_parser(ParamParser<RandomEnhanceParam>)
.set_attr<FCompute>("FCompute<cpu>", RandomBrightness)
.add_arguments(RandomEnhanceParam::__FIELDS__());
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_contrast)
+.add_alias("_npx__image_random_contrast")
.describe(R"code()code" ADD_FILELINE)
.set_attr_parser(ParamParser<RandomEnhanceParam>)
.set_attr<FCompute>("FCompute<cpu>", RandomContrast)
@@ -218,6 +226,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_contrast)
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_saturation)
+.add_alias("_npx__image_random_saturation")
.describe(R"code()code" ADD_FILELINE)
.set_attr_parser(ParamParser<RandomEnhanceParam>)
.set_attr<FCompute>("FCompute<cpu>", RandomSaturation)
@@ -225,6 +234,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_saturation)
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_hue)
+.add_alias("_npx__image_random_hue")
.describe(R"code()code" ADD_FILELINE)
.set_attr_parser(ParamParser<RandomEnhanceParam>)
.set_attr<FCompute>("FCompute<cpu>", RandomHue)
@@ -232,6 +242,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_hue)
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_color_jitter)
+.add_alias("_npx__image_random_color_jitter")
.describe(R"code()code" ADD_FILELINE)
.set_attr_parser(ParamParser<RandomColorJitterParam>)
.set_attr<FCompute>("FCompute<cpu>", RandomColorJitter)
@@ -239,6 +250,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_color_jitter)
MXNET_REGISTER_IMAGE_AUG_OP(_image_adjust_lighting)
+.add_alias("_npx__image_adjust_lighting")
.describe(R"code(Adjust the lighting level of the input. Follow the AlexNet style.)code" ADD_FILELINE)
.set_attr_parser(ParamParser<AdjustLightingParam>)
.set_attr<FCompute>("FCompute<cpu>", AdjustLighting)
@@ -246,6 +258,7 @@ MXNET_REGISTER_IMAGE_AUG_OP(_image_adjust_lighting)
MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_lighting)
+.add_alias("_npx__image_random_lighting")
.describe(R"code(Randomly add PCA noise. Follow the AlexNet style.)code" ADD_FILELINE)
.set_attr_parser(ParamParser<RandomLightingParam>)
.set_attr<FCompute>("FCompute<cpu>", RandomLighting)
diff --git a/src/operator/image/resize.cc b/src/operator/image/resize.cc
index d93769f..d2397ea 100644
--- a/src/operator/image/resize.cc
+++ b/src/operator/image/resize.cc
@@ -34,6 +34,7 @@ namespace image {
DMLC_REGISTER_PARAMETER(ResizeParam);
NNVM_REGISTER_OP(_image_resize)
+.add_alias("_npx__image_resize")
.describe(R"code(Resize an image NDArray of shape (H x W x C) or (N x H x W x C)
to the given size
Example:
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index 214e41a..c25833b 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -71,6 +71,7 @@ The following modified ReLU Activation functions are supported:
.add_arguments(LeakyReLUParam::__FIELDS__());
NNVM_REGISTER_OP(LeakyReLU)
+.add_alias("_npx_leaky_relu")
.set_attr<nnvm::FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose",
[](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) {
if (index == 1 && var->attrs.dict.find("__init__") == var->attrs.dict.end()) {
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 3d668c8..5abb667 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -154,7 +154,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
MXNET_OPERATOR_REGISTER_UNARY(Activation)
-.add_alias("_npx_Activation")
+.add_alias("_npx_activation")
.describe(R"code(Applies an activation function element-wise to the input.
The following activation functions are supported:
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 030f589..6382d46 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -520,7 +520,7 @@ std::vector<nnvm::NodeEntry> BatchNormGrad(const nnvm::NodePtr& n,
}
NNVM_REGISTER_OP(BatchNorm)
-.add_alias("_npx_BatchNorm")
+.add_alias("_npx_batch_norm")
.describe(R"code(Batch normalization.
Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 6ab388a..32ed93e 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -397,7 +397,7 @@ struct ConvolutionGrad {
};
NNVM_REGISTER_OP(Convolution)
-.add_alias("_npx_Convolution")
+.add_alias("_npx_convolution")
.describe(R"code(Compute *N*-D convolution on *(N+2)*-D input.
In the 2-D convolution, given input data with shape *(batch_size,
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 09b255d..9f461f4e 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -408,6 +408,7 @@ struct DeconvolutionGrad {
DMLC_REGISTER_PARAMETER(DeconvolutionParam);
NNVM_REGISTER_OP(Deconvolution)
+.add_alias("_npx_deconvolution")
.describe("Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the "
"input tensor. This operation can be seen as the gradient of Convolution operation with "
"respect to its input. Convolution usually reduces the size of the input. Transposed "
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index 72ba422..29f13a4 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -65,7 +65,7 @@ struct DropoutGrad {
DMLC_REGISTER_PARAMETER(DropoutParam);
NNVM_REGISTER_OP(Dropout)
-.add_alias("_npx_Dropout")
+.add_alias("_npx_dropout")
.describe(R"(Applies dropout operation to input array.
- During training, each element of the input is set to zero with probability p.
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 9f30ed2..06ad6d0 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -244,7 +244,7 @@ DMLC_REGISTER_PARAMETER(FullyConnectedParam);
NNVM_REGISTER_OP(FullyConnected)
MXNET_ADD_SPARSE_OP_ALIAS(FullyConnected)
-.add_alias("_npx_FullyConnected")
+.add_alias("_npx_fully_connected")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`.
If ``flatten`` is set to be true, then the shapes are:
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 7c6ddcb..0b53d50 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -127,7 +127,7 @@ void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
}
NNVM_REGISTER_OP(LayerNorm)
-.add_alias("_npx_LayerNorm")
+.add_alias("_npx_layer_norm")
.describe(R"code(Layer normalization.
Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 0df5827..485fc13 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -364,7 +364,7 @@ inline static bool BackwardPoolingStorageType(const nnvm::NodeAttrs &attrs,
DMLC_REGISTER_PARAMETER(PoolingParam);
NNVM_REGISTER_OP(Pooling)
-.add_alias("_npx_Pooling")
+.add_alias("_npx_pooling")
.describe(R"code(Performs pooling on the input.
The shapes for 1-D pooling are
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 2ffa3b8..fe5aeb0 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -34,14 +34,9 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- const int itype = in_attrs->at(0);
- if (itype == -1) return false;
- auto is_float = [](const int dtype) {
- return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
- };
- CHECK(is_float(itype)) << "numpy binary scalar op currently only supports float dtype";
- TYPE_ASSIGN_CHECK(*out_attrs, 0, itype);
- return true;
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+ return in_attrs->at(0) != -1;
}
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 58f190a..244e393 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -634,7 +634,7 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
#endif
NNVM_REGISTER_OP(RNN)
-.add_alias("_npx_RNN")
+.add_alias("_npx_rnn")
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc
index bba3bea..56c8725 100644
--- a/src/operator/roi_pooling.cc
+++ b/src/operator/roi_pooling.cc
@@ -230,5 +230,9 @@ Example::
"corners of designated region of interest. `batch_index` indicates the index of corresponding "
"image in the input array")
.add_arguments(ROIPoolingParam::__FIELDS__());
+
+NNVM_REGISTER_OP(ROIPooling)
+.add_alias("_npx_roi_pooling");
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc
index ca58be1..d773102 100644
--- a/src/operator/sequence_mask.cc
+++ b/src/operator/sequence_mask.cc
@@ -192,7 +192,7 @@ Example::
.add_arguments(SequenceMaskParam::__FIELDS__());
NNVM_REGISTER_OP(SequenceMask)
-.add_alias("_npx_SequenceMask");
+.add_alias("_npx_sequence_mask");
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
index f027665..3a687c2 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
@@ -84,7 +84,8 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_hypot_scalar)
cpu, mshadow_op::hypot_grad_left>);
NNVM_REGISTER_OP(smooth_l1)
- .describe(R"code(Calculate Smooth L1 Loss(lhs, scalar) by summing
+.add_alias("_npx_smooth_l1")
+.describe(R"code(Calculate Smooth L1 Loss(lhs, scalar) by summing
.. math::
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index a955508..3dffc73 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -650,6 +650,7 @@ Example::
DMLC_REGISTER_PARAMETER(CastParam);
NNVM_REGISTER_OP(Cast)
.add_alias("cast")
+.add_alias("_npx_cast")
.describe(R"code(Casts all elements of the input to a new type.
.. note:: ``Cast`` is deprecated. Use ``cast`` instead.
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index f229fef..ad4e54d 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -466,7 +466,7 @@ DMLC_REGISTER_PARAMETER(ScatterNDParam);
NNVM_REGISTER_OP(Embedding)
MXNET_ADD_SPARSE_OP_ALIAS(Embedding)
-.add_alias("_npx_Embedding")
+.add_alias("_npx_embedding")
.describe(R"code(Maps integer indices to vector representations (embeddings).
This operator maps words to real-valued vectors in a high-dimensional space,