You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:42 UTC
[incubator-mxnet] 06/42: [numpy] Some np ops for d2l (#14924)
This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 2d0c8043d020a8a72a5b6cb4cc4c709a9a372026
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu May 9 20:15:18 2019 -0700
[numpy] Some np ops for d2l (#14924)
* Add np transpose
More ops and namespaces for submodules
Add relu and sigmoid
Add reshape
Fix symbolic name mismatch
Add maximum and minimum
* Add convenience fluent method
* Add ndarray.item()
* Fix CI
* Fix lint
* Fix lint
* Fix reshape gpu
* Add example
* Remove python notebook outputs
* Remove notebook output
* Add one more example
---
example/numpy/demo.ipynb | 415 +++++++++++++++++++++
include/mxnet/tuple.h | 8 +
python/mxnet/base.py | 9 +-
python/mxnet/ndarray/numpy/__init__.py | 3 +
python/mxnet/ndarray/numpy/_op.py | 90 ++++-
.../{numpy/linalg.py => ndarray/numpy/ext.py} | 2 +-
python/mxnet/{ => ndarray}/numpy/linalg.py | 2 +-
python/mxnet/{ => ndarray}/numpy/random.py | 2 +-
python/mxnet/numpy/__init__.py | 5 +-
python/mxnet/numpy/{linalg.py => ext.py} | 2 +-
python/mxnet/numpy/linalg.py | 2 +-
python/mxnet/numpy/multiarray.py | 112 +++++-
python/mxnet/numpy/random.py | 2 +-
python/mxnet/symbol/numpy/__init__.py | 3 +
python/mxnet/symbol/numpy/_symbol.py | 92 ++++-
.../mxnet/{numpy/linalg.py => symbol/numpy/ext.py} | 2 +-
python/mxnet/{ => symbol}/numpy/linalg.py | 2 +-
python/mxnet/{ => symbol}/numpy/random.py | 2 +-
src/c_api/c_api_common.h | 6 +-
src/operator/numpy/np_elemwise_broadcast_op.cc | 18 +
src/operator/numpy/np_elemwise_broadcast_op.cu | 15 +-
src/operator/numpy/np_elemwise_unary_op_basic.cc | 63 ++++
src/operator/numpy/np_elemwise_unary_op_basic.cu | 39 ++
src/operator/numpy/np_matrix_op-inl.h | 65 ++++
src/operator/numpy/np_matrix_op.cc | 218 +++++++++++
src/operator/numpy/np_matrix_op.cu | 37 ++
src/operator/tensor/elemwise_binary_broadcast_op.h | 1 +
src/operator/tensor/matrix_op-inl.h | 8 +-
tests/python/unittest/test_numpy_ndarray.py | 1 -
tests/python/unittest/test_numpy_op.py | 120 ++++++
30 files changed, 1295 insertions(+), 51 deletions(-)
diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb
new file mode 100644
index 0000000..d8e6e06
--- /dev/null
+++ b/example/numpy/demo.ipynb
@@ -0,0 +1,415 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fundamentals of MXNet Numpy Module\n",
+ "\n",
+ "## Operator Namespaces for Imperative Programming\n",
+ "- `mxnet.numpy`: Regular NumPy operators\n",
+ "- `mxnet.numpy.random`: NumPy random operators\n",
+ "- `mxnet.numpy.linalg`: NumPy linear algebra operators\n",
+ "- `mxnet.numpy.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
+ "\n",
+ "## Operator Namespaces for Gluon\n",
+ "`F` can be either `mxnet.ndarray` or `mxnet.symbol`.\n",
+ "- `F.np`: Regular NumPy operators\n",
+ "- `F.np.random`: NumPy random operators\n",
+ "- `F.np.linalg`: NumPy linear algebra operators\n",
+ "- `F.np.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
+ "\n",
+ "## New `ndarray` and `symbol`\n",
+ "`mxnet.numpy.ndarray` and `mxnet.symbol.numpy._NumpySymbol` (not visible to users)\n",
+ "- Same name as in the official NumPy package\n",
+ "- Dispatch convenience fluent method calls to MXNet Numpy operators\n",
+ "- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n",
+ "- Make the behavior of built-in methods consistent with the official NumPy\n",
+ " - Indexing: `__getitem__` and `__setitem__`\n",
+ " - Binary element-wise operators with broadcasting, which `mxnet.symbol.Symbol` does not support\n",
+ " \n",
+ "## Examples of ndarray and symbol Basics\n",
+ "### Scalar and zero-size tensors"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mxnet as mx\n",
+ "from mxnet import numpy as np\n",
+ "\n",
+ "# use numpy-compatible semantics\n",
+ "mx.set_np_compat(True)\n",
+ "\n",
+ "# create a scalar tensor\n",
+ "x = np.array(3.14)\n",
+ "print(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "s = x.item() # copy the element from the scalar tensor to a python scalar\n",
+ "print('s = {}'.format(str(s)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create a scalar tensor holding the single element 1.0\n",
+ "y = np.ones(())\n",
+ "print(y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create a zero-size tensor\n",
+ "x = np.ones((5, 4, 0, 6))\n",
+ "print(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# transpose the zero-size tensor\n",
+ "y = np.transpose(x)\n",
+ "print(y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Conversion between classic and numpy ndarrays"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create a classic MXNet NDArray\n",
+ "x = mx.nd.random.uniform(shape=(2, 3))\n",
+ "print(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# convert classic NDArray type to mxnet.numpy.ndarray with zero-copy\n",
+ "y = x.as_np_ndarray()\n",
+ "print(y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# changing y's content changes x's content too\n",
+ "y[:] = 1\n",
+ "print(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# convert mxnet.numpy.ndarray to classic NDArray with zero-copy\n",
+ "z = y.as_classic_ndarray()\n",
+ "print(z)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# changing z's content changes y's content too\n",
+ "z[:] = 2\n",
+ "print(y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Binary element-wise operations with broadcasting in new and old symbols"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mxnet import gluon\n",
+ "class TestBinaryBroadcast(gluon.HybridBlock):\n",
+ " def hybrid_forward(self, F, x1, x2):\n",
+ " print(\"x1 type:\", str(type(x1)))\n",
+ " print(\"x2 type:\", str(type(x2)))\n",
+ " return x1 + x2\n",
+ "\n",
+ "net = TestBinaryBroadcast()\n",
+ "x1 = mx.nd.ones((2, 1))\n",
+ "x2 = mx.nd.ones((1, 3))\n",
+ "out = net(x1, x2) # ok: imperative execution supports broadcasting\n",
+ "print(out)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net.hybridize() # mark the block for execution using a computational graph\n",
+ "try:\n",
+ " out = net(x1, x2) # error: old symbol `+` operation does not support broadcasting\n",
+ " assert False # should not reach here\n",
+ "except mx.MXNetError:\n",
+ " print(\"ERROR: cannot perform broadcast add for two symbols of mxnet.sym.Symbol\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TestBinaryBroadcast2(gluon.HybridBlock):\n",
+ " def hybrid_forward(self, F, x1, x2):\n",
+ " print(\"x1 type:\", str(type(x1)))\n",
+ " print(\"x2 type:\", str(type(x2)))\n",
+ " return x1.as_np_ndarray() + x2 # convert x1 to new numpy ndarray/symbol\n",
+ "\n",
+ "net2 = TestBinaryBroadcast2()\n",
+ "net2.hybridize()\n",
+ "\n",
+ "out =net2(x1, x2)\n",
+ "print(out)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = TestBinaryBroadcast() # Create a new block object to clear the graph\n",
+ "net.hybridize() # mark the block for execution using a computational graph\n",
+ "\n",
+ "x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
+ "x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
+ "out = net(x1, x2) # ok: `+` operation supports broadcasting for _NumpySymbol\n",
+ "print(out) # mxnet.numpy.ndarray type, because it's from a np operator"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## A Simple Linear Regression Model\n",
+ "Let's consider the following simple linear regression model.\n",
+ "Given a dataset `{x, y}`, where the `x`s are input examples and the `y`s are the observed targets, find the parameters `w1` and `w2` of the following model.\n",
+ "```\n",
+ "y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## MXNet Numpy Operators in Imperative Programming"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mxnet as mx\n",
+ "from mxnet import numpy as np\n",
+ "from mxnet import autograd\n",
+ "try:\n",
+ " from mxboard import SummaryWriter\n",
+ "except ImportError:\n",
+ " SummaryWriter = None\n",
+ "\n",
+ "# create a summary writer for visualization\n",
+ "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
+ "\n",
+ "# Use numpy-compatible semantics to support scalar tensors\n",
+ "mx.set_np_compat(True)\n",
+ "\n",
+ "# N is number of examples; D_in is input dimension;\n",
+ "# H is hidden dimension; D_out is output dimension.\n",
+ "N, D_in, H, D_out = 64, 1000, 100, 10\n",
+ "\n",
+ "# Create random input and output data\n",
+ "x = mx.nd.random.normal(shape=(N, D_in)).as_np_ndarray() # x is of type mxnet.numpy.ndarray\n",
+ "y = mx.nd.random.normal(shape=(N, D_out)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n",
+ "\n",
+ "# Randomly initialize weights\n",
+ "w1 = mx.nd.random.normal(shape=(D_in, H)).as_np_ndarray() # w1 is of type mxnet.numpy.ndarray\n",
+ "w1.attach_grad() # w1.grad is of type mxnet.numpy.ndarray\n",
+ "w2 = mx.nd.random.normal(shape=(H, D_out)).as_np_ndarray() # w2 is of type mxnet.numpy.ndarray\n",
+ "w2.attach_grad() # w2.grad is of type mxnet.numpy.ndarray\n",
+ "\n",
+ "learning_rate = 1e-6\n",
+ "\n",
+ "\n",
+ "for t in range(1000):\n",
+ " with autograd.record():\n",
+ " # Forward pass: compute predicted y\n",
+ " h = x.dot(w1) # equivalent to np.dot(x, w1)\n",
+ " h_relu = np.ext.relu(h) # equivalent to mx.nd.relu(h)\n",
+ " y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n",
+ "\n",
+ " # Compute loss\n",
+ " # (y_pred - y) ** 2 calls np.ndarray.__pow__\n",
+ " # sum() calls np.sum() which should return a scalar tensor\n",
+ " loss = ((y_pred - y) ** 2).sum()\n",
+ " # Note that the print function will invoke loss.asnumpy()\n",
+ " print(t, loss) # loss is a scalar tensor of type mxnet.numpy.ndarray\n",
+ " loss.backward()\n",
+ "\n",
+ " # Update weights\n",
+ " w1 -= learning_rate * w1.grad\n",
+ " w2 -= learning_rate * w2.grad\n",
+ "\n",
+ " if sw is not None:\n",
+ " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n",
+ " if t % 50 == 0:\n",
+ " sw.add_histogram(tag='w1', values=w1, global_step=t)\n",
+ " sw.add_histogram(tag='w2', values=w2, global_step=t)\n",
+ "\n",
+ "if sw is not None:\n",
+ " sw.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## MXNet Numpy Operators in Gluon `HybridBlock`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mxnet as mx\n",
+ "from mxnet import gluon, autograd\n",
+ "try:\n",
+ " from mxboard import SummaryWriter\n",
+ "except ImportError:\n",
+ " SummaryWriter = None\n",
+ "\n",
+ "# create a summary writer for visualization\n",
+ "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
+ "\n",
+ "# Use numpy-compatible semantics to support scalar tensors\n",
+ "mx.set_np_compat(True)\n",
+ "\n",
+ "\n",
+ "class LinearRegression(gluon.HybridBlock):\n",
+ " def __init__(self, num_input_dim=1000, num_hidden_dim=100, num_output_dim=10):\n",
+ " super(LinearRegression, self).__init__()\n",
+ " with self.name_scope():\n",
+ " self.w1 = self.params.get('w1', shape=(num_input_dim, num_hidden_dim),\n",
+ " allow_deferred_init=True)\n",
+ " self.w2 = self.params.get('w2', shape=(num_hidden_dim, num_output_dim),\n",
+ " allow_deferred_init=True)\n",
+ "\n",
+ " def hybrid_forward(self, F, x, w1, w2):\n",
+ " h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n",
+ " h_relu = F.np.ext.relu(h) # equivalent to F.relu(h)\n",
+ " y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n",
+ " return y_pred\n",
+ "\n",
+ "\n",
+ "class TotalLoss(gluon.HybridBlock):\n",
+ " def hybrid_forward(self, F, pred, label):\n",
+ " return ((pred - label) ** 2).sum() # equivalent to F.np.sum(F.np.square(pred - label))\n",
+ "\n",
+ "\n",
+ "regressor = LinearRegression()\n",
+ "regressor.initialize(mx.init.Normal())\n",
+ "regressor.hybridize()\n",
+ "\n",
+ "# Create random input and output data\n",
+ "x = mx.nd.random.normal(shape=(64, 1000)).as_np_ndarray() # x is of type mxnet.numpy.ndarray\n",
+ "y = mx.nd.random.normal(shape=(64, 10)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n",
+ "\n",
+ "total_loss = TotalLoss()\n",
+ "trainer = gluon.Trainer(regressor.collect_params(), 'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n",
+ "\n",
+ "for t in range(1000):\n",
+ " with autograd.record():\n",
+ " output = regressor(x) # output is of type np.ndarray because np.dot is the last op in the network\n",
+ " loss = total_loss(output, y) # loss is a scalar np.ndarray\n",
+ " loss.backward()\n",
+ " print(t, loss) # note that loss.asnumpy() is called\n",
+ " trainer.step(1)\n",
+ " if sw is not None:\n",
+ " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n",
+ " if t % 50 == 0:\n",
+ " for k, v in regressor.collect_params().items():\n",
+ " sw.add_histogram(tag=k, values=v.data(), global_step=t)\n",
+ "\n",
+ "if sw is not None:\n",
+ " sw.close()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h
index bc630f1..08381e2 100644
--- a/include/mxnet/tuple.h
+++ b/include/mxnet/tuple.h
@@ -272,6 +272,14 @@ class Tuple {
is.get();
if (ch == '(' || ch == '[') break;
if (!isspace(ch)) {
+ if (ch == 'N') {
+ std::string tmp_val;
+ is >> tmp_val;
+ if (tmp_val == "one") { // is stores "None"
+ t.SetDim(-1);
+ return is;
+ }
+ }
is.setstate(std::ios::failbit);
return is;
}
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 0d4bf53..131cb4d 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -743,7 +743,7 @@ def _sanity_check_params(func_name, unsupported_params, param_dict):
.format(func_name, param_name))
-_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
+_NP_OP_SUBMODULE_LIST = ['_ext_', '_random_', '_linalg_']
_NP_OP_PREFIX = '_numpy_'
@@ -792,10 +792,9 @@ def _init_np_op_module(root_namespace, module_name, make_op_func):
submodule_pattern = "%s.%s.numpy.%s"
module_np_op = sys.modules[module_pattern % (root_namespace, module_name)]
submodule_dict = {}
- # TODO(junwu): uncomment the following lines when adding numpy ops in submodules, e.g. np.random
- # for submodule_name in _NP_OP_SUBMODULE_LIST:
- # submodule_dict[submodule_name] = \
- # sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
+ for submodule_name in _NP_OP_SUBMODULE_LIST:
+ submodule_dict[submodule_name] = \
+ sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy/__init__.py
index a714a4b..d97e808 100644
--- a/python/mxnet/ndarray/numpy/__init__.py
+++ b/python/mxnet/ndarray/numpy/__init__.py
@@ -17,6 +17,9 @@
"""numpy module for numpy ops under mxnet.ndarray."""
+from . import ext
+from . import random
+from . import linalg
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 383bf2f..9b32c31 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -19,11 +19,12 @@
from __future__ import absolute_import
import numpy as _np
-from ...base import _sanity_check_params, use_np_compat
+from ...base import _sanity_check_params, use_np_compat, numeric_types
from ...context import current_context
from .. import _internal
+from ..ndarray import NDArray
-__all__ = ['zeros', 'ones']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum']
@use_np_compat
@@ -86,3 +87,88 @@ def ones(shape, dtype=None, **kwargs):
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+
+
+#pylint: disable= too-many-arguments, no-member, protected-access
+def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
+ """Helper function for element-wise binary operations.
+ Dispatches to one of the given functions depending on whether each operand is an
+ ``NDArray`` or a numeric value. Broadcasting, if any, is performed by the dispatched function.
+
+ Parameters
+ ----------
+ lhs : NDArray or numeric value
+ Left-hand side operand.
+
+ rhs : NDArray or numeric value
+ Right-hand side operand.
+
+ fn_array : function
+ Function to be called if both lhs and rhs are of ``NDArray`` type.
+
+ fn_scalar : function
+ Function to be called if both lhs and rhs are numeric values.
+
+ lfn_scalar : function
+ Function to be called if lhs is ``NDArray`` while rhs is a numeric value.
+
+ rfn_scalar : function, optional
+ Function to be called if lhs is a numeric value while rhs is ``NDArray``;
+ if None is provided, the operation is treated as commutative and lfn_scalar is used instead.
+
+ out : NDArray, optional
+ Output buffer, forwarded unchanged to whichever function is called.
+
+ Returns
+ -------
+ mxnet.numpy.ndarray or scalar
+ Result of the dispatched function.
+
+ Raises
+ ------
+ TypeError
+ If either operand is neither an ``NDArray`` nor a numeric value.
+ """
+ lhs_is_scalar = isinstance(lhs, numeric_types)
+ rhs_is_scalar = isinstance(rhs, numeric_types)
+ if lhs_is_scalar and rhs_is_scalar:
+ return fn_scalar(lhs, rhs, out=out)
+ if lhs_is_scalar and isinstance(rhs, NDArray):
+ # rfn_scalar=None marks a commutative op, so lfn_scalar serves both operand orders
+ return (lfn_scalar if rfn_scalar is None else rfn_scalar)(rhs, float(lhs), out=out)
+ if isinstance(lhs, NDArray) and rhs_is_scalar:
+ return lfn_scalar(lhs, float(rhs), out=out)
+ if isinstance(lhs, NDArray) and isinstance(rhs, NDArray):
+ return fn_array(lhs, rhs, out=out)
+ # reject unsupported operands here instead of handing them to a backend op
+ raise TypeError('operand types (%s, %s) not supported'
+ % (str(type(lhs)), str(type(rhs))))
+#pylint: enable= too-many-arguments, no-member, protected-access
+
+
+@use_np_compat
+def maximum(x1, x2, out=None):
+ """Returns element-wise maximum of the input arrays with broadcasting.
+
+ Parameters
+ ----------
+ x1, x2 : scalar or mxnet.numpy.ndarray
+ The arrays holding the elements to be compared. They must have the same shape,
+ or shapes that can be broadcast to a single shape.
+
+ Returns
+ -------
+ out : mxnet.numpy.ndarray or scalar
+ The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
+ return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum,
+ _internal._np_maximum_scalar, None, out)
+
+
+@use_np_compat
+def minimum(x1, x2, out=None):
+ """Returns element-wise minimum of the input arrays with broadcasting.
+
+ Parameters
+ ----------
+ x1, x2 : scalar or mxnet.numpy.ndarray
+ The arrays holding the elements to be compared. They must have the same shape,
+ or shapes that can be broadcast to a single shape.
+
+ Returns
+ -------
+ out : mxnet.numpy.ndarray or scalar
+ The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
+ return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum,
+ _internal._np_minimum_scalar, None, out)
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/ndarray/numpy/ext.py
similarity index 89%
copy from python/mxnet/numpy/linalg.py
copy to python/mxnet/ndarray/numpy/ext.py
index 1527c61..e13423f 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/ndarray/numpy/ext.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops of linear algebra."""
+"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
__all__ = []
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py
similarity index 89%
copy from python/mxnet/numpy/linalg.py
copy to python/mxnet/ndarray/numpy/linalg.py
index 1527c61..b8f10b3 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/ndarray/numpy/linalg.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops of linear algebra."""
+"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
__all__ = []
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
similarity index 89%
copy from python/mxnet/numpy/random.py
copy to python/mxnet/ndarray/numpy/random.py
index 461da66..60908b5 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy random operators."""
+"""numpy.random namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
__all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index c4dea9e..2a58f27 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -20,10 +20,11 @@
"""numpy module for imperative programming."""
from __future__ import absolute_import
-from .multiarray import * # pylint: disable=wildcard-import
-from . import _op
from . import random
from . import linalg
+from . import ext
+from .multiarray import * # pylint: disable=wildcard-import
+from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/ext.py
similarity index 91%
copy from python/mxnet/numpy/linalg.py
copy to python/mxnet/numpy/ext.py
index 1527c61..e4c8251 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/ext.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops of linear algebra."""
+"""namespace for registering numpy.ext ops for imperative programming."""
__all__ = []
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index 1527c61..96c7ddc 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops of linear algebra."""
+"""namespace for registering numpy.linalg ops for imperative programming."""
__all__ = []
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 9f47ce1..6c414b4 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -27,14 +27,14 @@ import ctypes
import numpy as _np
from ..ndarray import NDArray, _DTYPE_NP_TO_MX
from ..ndarray._internal import _set_np_ndarray_class
-from . import _op
+from . import _op as _mx_np_op
from ..base import use_np_compat, check_call, _LIB, NDArrayHandle, _sanity_check_params
from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types
from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
from ..ndarray import _internal as _nd_internal
-__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones']
+__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum']
# This function is copied from ndarray.py since pylint
@@ -73,7 +73,7 @@ def _np_ndarray_cls(handle, writable=True, stype=0):
_set_np_ndarray_class(_np_ndarray_cls)
-class ndarray(NDArray):
+class ndarray(NDArray): # pylint: disable=invalid-name
"""An array object represents a multidimensional, homogeneous array of fixed-size items.
An associated data-type object describes the format of each element in the array
(its byte-order, how many bytes it occupies in memory, whether it is an integer, a
@@ -104,7 +104,15 @@ class ndarray(NDArray):
@use_np_compat
def __iadd__(self, other):
- raise NotImplementedError
+ """x.__iadd__(y) <=> x += y"""
+ if not self.writable:
+ raise ValueError('trying to add to a readonly ndarray')
+ if isinstance(other, NDArray):
+ return _nd_internal._np_add(self, other, out=self)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_add_scalar(self, float(other), out=self)
+ else:
+ raise TypeError('type {} is not supported'.format(str(type(other))))
@use_np_compat
def __sub__(self, other):
@@ -118,7 +126,15 @@ class ndarray(NDArray):
@use_np_compat
def __isub__(self, other):
- raise NotImplementedError
+ """x.__isub__(y) <=> x -= y"""
+ if not self.writable:
+ raise ValueError('trying to subtract from a readonly ndarray')
+ if isinstance(other, NDArray):
+ return _nd_internal._np_subtract(self, other, out=self)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_subtract_scalar(self, float(other), out=self)
+ else:
+ raise TypeError('type {} is not supported'.format(str(type(other))))
@use_np_compat
def __rsub__(self, other):
@@ -285,6 +301,36 @@ class ndarray(NDArray):
def __reduce__(self):
return ndarray, (None,), self.__getstate__()
+ def item(self, *args):
+ """Copy an element of an array to a standard Python scalar and return it.
+
+ Parameters
+ ----------
+ *args : Arguments (variable number and type)
+ none: in this case, the method only works for arrays with one element (a.size == 1),
+ which element is copied into a standard Python scalar object and returned.
+
+ int_type: this argument is interpreted as a flat index into the array, specifying which
+ element to copy and return.
+
+ tuple of int_types: functions as does a single int_type argument, except that the
+ argument is interpreted as an nd-index into the array.
+
+ Returns
+ -------
+ z : Standard Python scalar object
+ A copy of the specified element of the array as a suitable Python scalar.
+ """
+ # TODO(junwu): no need to call asnumpy() on the whole array.
+ return self.asnumpy().item(*args)
+
+ @property
+ # pylint: disable= invalid-name, undefined-variable
+ def T(self):
+ """Same as self.transpose(). This always returns a copy of self."""
+ return self.transpose()
+ # pylint: enable= invalid-name, undefined-variable
+
@use_np_compat
def _slice(self, start, stop):
raise NotImplementedError
@@ -380,9 +426,16 @@ class ndarray(NDArray):
return super(ndarray, self).copy().as_np_ndarray()
@use_np_compat
- def reshape(self, *shape, **kwargs):
+ def dot(self, b, out=None):
+ return _mx_np_op.dot(self, b, out=out)
+
+ @use_np_compat
+ def reshape(self, shape, order='C'): # pylint: disable=arguments-differ
"""Returns an array containing the same data with a new shape."""
- raise NotImplementedError
+ if order != 'C':
+ raise NotImplementedError('reshape only supports C-order,'
+ ' while received {}'.format(order))
+ return _mx_np_op.reshape(self, shape=shape, order=order)
def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
@@ -626,13 +679,13 @@ class ndarray(NDArray):
raise AttributeError('mxnet.numpy.ndarray object has no attribute tile')
@use_np_compat
- def transpose(self, *args, **kwargs):
+ def transpose(self, *axes): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`transpose`.
The arguments are the same as for :py:func:`transpose`, with
this array as data.
"""
- raise NotImplementedError
+ # Like numpy, accept a.transpose(), a.transpose(1, 0), a.transpose((1, 0)) and a.transpose(None)
+ if len(axes) == 1 and (axes[0] is None or isinstance(axes[0], (list, tuple))):
+ axes = axes[0]
+ return _mx_np_op.transpose(self, axes=tuple(axes) if axes else None)
def flip(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flip`.
@@ -667,13 +720,13 @@ class ndarray(NDArray):
raise AttributeError('mxnet.numpy.ndarray object has no attribute diag')
@use_np_compat
- def sum(self, *args, **kwargs):
+ def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`sum`.
The arguments are the same as for :py:func:`sum`, with
this array as data.
"""
- return _op.sum(self, *args, **kwargs)
+ return _mx_np_op.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims)
def nansum(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`nansum`.
@@ -1069,11 +1122,6 @@ class ndarray(NDArray):
def stype(self):
raise AttributeError('mxnet.numpy.ndarray object has no attribute stype')
- @property
- @use_np_compat
- def T(self):
- raise NotImplementedError
-
def tostype(self, stype):
raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype')
@@ -1198,3 +1246,35 @@ def ones(shape, dtype=None, **kwargs):
Array of zeros with the given shape, dtype, and ctx.
"""
return _mx_nd_np.ones(shape, dtype, **kwargs)
+
+
+def maximum(x1, x2, out=None):
+ """Returns element-wise maximum of the input arrays with broadcasting.
+
+ Parameters
+ ----------
+ x1, x2 : scalar or mxnet.numpy.ndarray
+ The arrays holding the elements to be compared. They must have the same shape,
+ or shapes that can be broadcast to a single shape.
+
+ Returns
+ -------
+ out : mxnet.numpy.ndarray or scalar
+ The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
+ return _mx_nd_np.maximum(x1, x2, out=out)
+
+
+def minimum(x1, x2, out=None):
+ """Returns element-wise minimum of the input arrays with broadcasting.
+
+ Parameters
+ ----------
+ x1, x2 : scalar or mxnet.numpy.ndarray
+ The arrays holding the elements to be compared. They must have the same shape,
+ or shapes that can be broadcast to a single shape.
+
+ Returns
+ -------
+ out : mxnet.numpy.ndarray or scalar
+ The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
+ return _mx_nd_np.minimum(x1, x2, out=out)
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index 461da66..b1f4b02 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy random operators."""
+"""namespace for registering numpy.random ops for imperative programming."""
__all__ = []
diff --git a/python/mxnet/symbol/numpy/__init__.py b/python/mxnet/symbol/numpy/__init__.py
index d63daa2..1f20c03 100644
--- a/python/mxnet/symbol/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy/__init__.py
@@ -17,6 +17,9 @@
"""numpy module for numpy ops under mxnet.symbol."""
+from . import random
+from . import linalg
+from . import ext
from . import _op, _symbol
from ._symbol import _NumpySymbol
from . import _register
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 087f118..8cf6e30 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=too-many-lines
"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module."""
from __future__ import absolute_import
import ctypes
import numpy as _np
-from . import _op as _np_op
+from . import _op as _mx_np_op
from ...base import _sanity_check_params, use_np_compat, check_call, _LIB, SymbolHandle
from ...base import numeric_types
from ...context import current_context
@@ -29,7 +30,7 @@ from ..symbol import Symbol
from .._internal import _set_np_symbol_class
from .. import _internal as _sym_internal
-__all__ = ['zeros', 'ones']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum']
class _NumpySymbol(Symbol):
@@ -237,13 +238,27 @@ class _NumpySymbol(Symbol):
check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
return Symbol(handle=hdl)
+ @property
+ # pylint: disable= invalid-name, undefined-variable
+ def T(self):
+ """Same as self.transpose()."""
+ return self.transpose()
+ # pylint: enable= invalid-name, undefined-variable
+
@use_np_compat
def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ
raise NotImplementedError
@use_np_compat
- def reshape(self, *shape, **kwargs):
- raise NotImplementedError
+ def dot(self, b, out=None):
+ return _mx_np_op.dot(self, b, out=out)
+
+ @use_np_compat
+ def reshape(self, shape, order='C'): # pylint: disable=arguments-differ
+ if order != 'C':
+ raise NotImplementedError('ndarray.copy only supports order=\'C\', while '
+ 'received {}'.format(str(order)))
+ return _mx_np_op.reshape(self, shape=shape, order=order)
def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
@@ -487,13 +502,13 @@ class _NumpySymbol(Symbol):
raise AttributeError('_NumpySymbol object has no attribute tile')
@use_np_compat
- def transpose(self, *args, **kwargs):
+ def transpose(self, *axes): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`transpose`.
The arguments are the same as for :py:func:`transpose`, with
this array as data.
"""
- raise NotImplementedError
+ return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None)
def flip(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flip`.
@@ -528,13 +543,13 @@ class _NumpySymbol(Symbol):
raise AttributeError('_NumpySymbol object has no attribute diag')
@use_np_compat
- def sum(self, *args, **kwargs):
+ def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`sum`.
The arguments are the same as for :py:func:`sum`, with
this array as data.
"""
- return _np_op.sum(self, *args, **kwargs)
+ return _mx_np_op.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims)
def nansum(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`nansum`.
@@ -971,4 +986,65 @@ def ones(shape, dtype=None, **kwargs):
return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+#pylint: disable= too-many-arguments, no-member, protected-access
+def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
+ """ Helper function for element-wise operation.
+ The function will perform numpy-like broadcasting if needed and call different functions.
+
+ Parameters
+ --------
+ lhs : Symbol or numeric value
+ Left-hand side operand.
+
+ rhs : Symbol or numeric value
+ Right-hand operand,
+
+ fn_array : function
+ Function to be called if both lhs and rhs are of ``Symbol`` type.
+
+ fn_scalar : function
+ Function to be called if both lhs and rhs are numeric values.
+
+ lfn_scalar : function
+ Function to be called if lhs is ``Symbol`` while rhs is numeric value
+
+ rfn_scalar : function
+ Function to be called if lhs is numeric value while rhs is ``Symbol``;
+ if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar
+
+ Returns
+ --------
+ mxnet.numpy.ndarray
+ result array
+ """
+ if isinstance(lhs, numeric_types):
+ if isinstance(rhs, numeric_types):
+ return fn_scalar(lhs, rhs, out=out)
+ else:
+ if rfn_scalar is None:
+ # commutative function
+ return lfn_scalar(rhs, float(lhs), out=out)
+ else:
+ return rfn_scalar(rhs, float(lhs), out=out)
+ elif isinstance(rhs, numeric_types):
+ return lfn_scalar(lhs, float(rhs), out=out)
+ elif isinstance(rhs, Symbol):
+ return fn_array(lhs, rhs, out=out)
+ else:
+ raise TypeError('type %s not supported' % str(type(rhs)))
+#pylint: enable= too-many-arguments, no-member, protected-access
+
+
+@use_np_compat
+def maximum(x1, x2, out=None):
+ return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum,
+ _internal._np_maximum_scalar, None, out)
+
+
+@use_np_compat
+def minimum(x1, x2, out=None):
+ return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum,
+ _internal._np_minimum_scalar, None, out)
+
+
_set_np_symbol_class(_NumpySymbol)
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/symbol/numpy/ext.py
similarity index 89%
copy from python/mxnet/numpy/linalg.py
copy to python/mxnet/symbol/numpy/ext.py
index 1527c61..12c5f15 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/symbol/numpy/ext.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops of linear algebra."""
+"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=symbol module."""
__all__ = []
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py
similarity index 89%
copy from python/mxnet/numpy/linalg.py
copy to python/mxnet/symbol/numpy/linalg.py
index 1527c61..b8f10b3 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/symbol/numpy/linalg.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops of linear algebra."""
+"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module."""
__all__ = []
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/symbol/numpy/random.py
similarity index 89%
copy from python/mxnet/numpy/random.py
copy to python/mxnet/symbol/numpy/random.py
index 461da66..79c73d8 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy random operators."""
+"""numpy.random namespace for operators used in Gluon APIs dispatched by F=symbol module."""
__all__ = []
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index ab1f5f7..82fe28b 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -177,11 +177,7 @@ extern const std::vector<std::string> kHiddenKeys;
inline bool IsNumpyCompatOp(const nnvm::Op* op) {
static const auto& is_np_compat =
nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible");
- if (is_np_compat.get(op, false)) {
- return true;
- }
- static const std::string prefix = "_numpy_";
- return op->name.find(prefix.c_str(), 0, prefix.size()) != std::string::npos;
+ return is_np_compat.get(op, false);
}
#endif // MXNET_C_API_C_API_COMMON_H_
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index e8988c8..5d36c29 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -161,6 +161,16 @@ Example::
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"})
.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_maximum)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::maximum>)
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_minimum)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::minimum>)
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_add_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
@@ -193,5 +203,13 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rpower_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_maximum_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::maximum>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_maximum_scalar"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_minimum_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::minimum>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_minimum_scalar"});
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index 186bd1b..26e2fce 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -42,6 +42,12 @@ NNVM_REGISTER_OP(_np_mod)
NNVM_REGISTER_OP(_np_power)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);
+NNVM_REGISTER_OP(_np_maximum)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::maximum>);
+
+NNVM_REGISTER_OP(_np_minimum)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::minimum>);
+
NNVM_REGISTER_OP(_np_add_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
@@ -52,8 +58,7 @@ NNVM_REGISTER_OP(_np_rsubtract_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rminus>);
NNVM_REGISTER_OP(_np_multiply_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::mul>)
-.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, op::mshadow_op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::mul>);
NNVM_REGISTER_OP(_np_mod_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mod>);
@@ -67,5 +72,11 @@ NNVM_REGISTER_OP(_np_power_scalar)
NNVM_REGISTER_OP(_np_rpower_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
+NNVM_REGISTER_OP(_np_maximum_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::maximum>);
+
+NNVM_REGISTER_OP(_np_minimum_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::minimum>);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc
new file mode 100644
index 0000000..f31ed5e
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_elemwise_unary_op_basic.cc
+ * \brief CPU Implementation of numpy elementwise unary function.
+ */
+#include <mxnet/base.h>
+#include "../tensor/elemwise_unary_op.h"
+
+namespace mxnet {
+namespace op {
+
+MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu)
+.describe(R"code(Computes rectified linear activation.
+
+.. math::
+ max(features, 0)
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"})
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+
+MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid)
+.describe(R"code(Computes sigmoid of x element-wise.
+
+.. math::
+ y = 1 / (1 + exp(-x))
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"})
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+
+MXNET_OPERATOR_REGISTER_UNARY(_np_copy)
+.MXNET_DESCRIBE("Returns a copy of the input.")
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+ [](const NodeAttrs& attrs){
+ return std::vector<bool>{true};
+ })
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"})
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu
new file mode 100644
index 0000000..9f108f7
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_elemwise_unary_op_basic.cu
+ * \brief GPU Implementation of numpy unary functions.
+ */
+#include "../tensor/elemwise_binary_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_numpy__ext_relu)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::relu>);
+
+NNVM_REGISTER_OP(_numpy__ext_sigmoid)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sigmoid>);
+
+NNVM_REGISTER_OP(_np_copy)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
new file mode 100644
index 0000000..44a6c90
--- /dev/null
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_matrix_op-inl.h
+ * \brief Function definition of matrix related operators
+ */
+#ifndef MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_
+
+#include <vector>
+#include "../tensor/matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct NumpyTransposeParam : public dmlc::Parameter<NumpyTransposeParam> {
+ mxnet::TShape axes;
+ DMLC_DECLARE_PARAMETER(NumpyTransposeParam) {
+ DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(-1, 0))
+ .describe("By default, reverse the dimensions, otherwise permute "
+ "the axes according to the values given.");
+ }
+};
+
+template<typename xpu>
+void NumpyTranspose(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
+ CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
+ if (ndim_is_known(param.axes)) {
+ TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], param.axes);
+ } else {
+ mxnet::TShape axes(inputs[0].ndim(), -1);
+ for (int i = 0; i < axes.ndim(); ++i) {
+ axes[i] = axes.ndim() - 1 - i;
+ }
+ TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+ }
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
new file mode 100644
index 0000000..215b1c5
--- /dev/null
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_matrix_op.cc
+ * \brief CPU Implementation of numpy matrix operations
+ */
+
+#include "./np_matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
+
+bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector *in_attrs,
+ mxnet::ShapeVector *out_attrs) {
+ const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ mxnet::TShape& shp = (*in_attrs)[0];
+ CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
+ mxnet::TShape ret(shp.ndim(), -1);
+ if (ndim_is_known(param.axes)) {
+ CHECK_EQ(shp.ndim(), param.axes.ndim());
+ for (int i = 0; i < shp.ndim(); ++i) {
+ CHECK(param.axes[i] < static_cast<int64_t>(shp.ndim()));
+ ret[i] = shp[param.axes[i]];
+ }
+ } else {
+ for (int i = 0; i < shp.ndim(); ++i) {
+ ret[i] = shp[shp.ndim()-1-i];
+ }
+ }
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
+ return shape_is_known(ret);
+}
+
+NNVM_REGISTER_OP(_numpy_transpose)
+.describe(R"code(Permute the dimensions of an array.
+
+Examples::
+
+ x = [[ 1, 2],
+ [ 3, 4]]
+
+ transpose(x) = [[ 1., 3.],
+ [ 2., 4.]]
+
+ x = [[[ 1., 2.],
+ [ 3., 4.]],
+
+ [[ 5., 6.],
+ [ 7., 8.]]]
+
+ transpose(x) = [[[ 1., 5.],
+ [ 3., 7.]],
+
+ [[ 2., 6.],
+ [ 4., 8.]]]
+
+ transpose(x, axes=(1,0,2)) = [[[ 1., 2.],
+ [ 5., 6.]],
+
+ [[ 3., 4.],
+ [ 7., 8.]]]
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyTransposeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyTransposeShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+ [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+ const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(n->attrs.parsed);
+ if (ndim_is_known(param.axes)) {
+ mxnet::TShape axes = mxnet::TShape(param.axes.ndim(), -1);
+ for (int i = 0; i < axes.ndim(); ++i) {
+ axes[param.axes[i]] = i;
+ }
+ std::ostringstream os;
+ os << axes;
+ return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}});
+ } else {
+ return MakeNonlossGradNode("transpose", n, ograds, {},
+ std::unordered_map<std::string, std::string>());
+ }
+ })
+.set_attr<FCompute>("FCompute<cpu>", NumpyTranspose<cpu>)
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"a"};
+ })
+.add_argument("a", "NDArray-or-Symbol", "Source input")
+.add_arguments(NumpyTransposeParam::__FIELDS__());
+
+struct NumpyReshapeParam : public dmlc::Parameter<NumpyReshapeParam> {
+ mxnet::TShape newshape;
+ std::string order;
+ DMLC_DECLARE_PARAMETER(NumpyReshapeParam) {
+ DMLC_DECLARE_FIELD(newshape)
+ .describe("The new shape should be compatible with the original shape."
+ " If an integer, then the result will be a 1-D array of that length."
+ " One shape dimension can be -1. In this case, the value is inferred"
+ " from the length of the array and remaining dimensions.");
+ DMLC_DECLARE_FIELD(order)
+ .set_default("C")
+ .describe("Read the elements of a using this index order, and place the elements into"
+ " the reshaped array using this index order. 'C' means to read/write the elements"
+ " using C-like index order, with the last axis index changing fastest, back to the"
+ " first axis index changing slowest. Note that currently only C-like order is"
+ " supported");
+ }
+};
+
+DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
+
+bool NumpyReshapeInferShape(const mxnet::TShape& src, mxnet::TShape* dst) {
+ if (shape_is_known(src) && shape_is_known(*dst)) {
+ CHECK_EQ(src.Size(), dst->Size()) << "Cannot reshape array of size "
+ << src.Size() << " into shape " << *dst;
+ return true;
+ } else if (!shape_is_known(src) || !ndim_is_known(*dst)) {
+ return false;
+ } else {
+ int unknown_axis = -1;
+ dim_t known_dim_size_prod = 1;
+ for (int i = 0; i < dst->ndim(); ++i) {
+ if (!dim_size_is_known(*dst, i)) {
+ if (unknown_axis == -1) {
+ unknown_axis = i;
+ } else {
+ return false; // more than one unknown dim
+ }
+ } else {
+ known_dim_size_prod *= (*dst)[i];
+ }
+ }
+ CHECK_NE(known_dim_size_prod, 0) << "Cannot reshape array of size "
+ << src.Size() << " into shape " << *dst;
+ CHECK_EQ(src.Size() % known_dim_size_prod, 0) << "Cannot reshape array of size "
+ << src.Size() << " into shape " << *dst;
+ (*dst)[unknown_axis] = src.Size() / known_dim_size_prod;
+ return true;
+ }
+}
+
+bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector* in_attrs,
+ mxnet::ShapeVector* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
+ CHECK_EQ(out_attrs->size(), 1U);
+ const NumpyReshapeParam& param = nnvm::get<NumpyReshapeParam>(attrs.parsed);
+ // sanity check
+ bool has_unknown_dim_size = false;
+ for (int i = 0; i < param.newshape.ndim(); ++i) {
+ if (param.newshape[i] < 0) {
+ CHECK_EQ(param.newshape[i], -1) << "The shape dimension size to inferred must be -1";
+ CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension";
+ has_unknown_dim_size = true;
+ }
+ }
+
+ mxnet::TShape target_shape = param.newshape;
+ bool success = NumpyReshapeInferShape(in_attrs->at(0), &target_shape);
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
+ if (!success) {
+ success = NumpyReshapeInferShape(out_attrs->at(0), &in_attrs->at(0));
+ }
+ return success;
+}
+
+NNVM_REGISTER_OP(_numpy_reshape)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReshapeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReshapeShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::pair<int, int> >{{0, 0}};
+ })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+ [](const NodeAttrs& attrs){
+ return std::vector<bool>{true};
+ })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"a"};
+ })
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
+.add_arguments(NumpyReshapeParam::__FIELDS__());
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
new file mode 100644
index 0000000..9753566
--- /dev/null
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_matrix_op.cu
+ * \brief GPU Implementation of numpy matrix operations
+ */
+#include "./np_matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_numpy_transpose)
+.set_attr<FCompute>("FCompute<gpu>", NumpyTranspose<gpu>);
+
+NNVM_REGISTER_OP(_numpy_reshape)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index f84767d..8a81bbc 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -292,6 +292,7 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
+ if (outputs[0].shape_.Size() == 0U) return;
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
&new_lshape, &new_rshape, &new_oshape);
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5cd7bf6..4e13354 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -265,11 +265,17 @@ void TransposeImpl(RunContext ctx,
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(src.type_flag_, ret.type_flag_);
+ // zero-size tensor, no need to compute
+ if (src.shape_.Size() == 0U) return;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
switch (axes.ndim()) {
- case 0:
+ case 0: {
+ Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
+ Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s);
+ Copy(out, in, s);
break;
+ }
case 1: {
Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> out = ret.get<xpu, 1, DType>(s);
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 88e56ac..141d153 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -24,7 +24,6 @@ from mxnet import numpy as np
from mxnet.gluon import HybridBlock
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
from common import with_seed
-import random
@with_seed()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 024c893..8c13227 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -192,6 +192,126 @@ def test_np_mean():
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+@with_seed()
+@mx.use_np_compat
+def test_np_transpose():
+ # TODO(junwu): Add more test cases
+ data = mx.sym.var('a')
+ ret = mx.sym.np.transpose(data)
+ assert type(ret) == mx.sym.np._NumpySymbol
+
+ dtypes = ['float32', 'int32']
+ for dtype in dtypes:
+ for ndim in [0, 1, 2, 3, 4, 5, 6]:
+ shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
+ np_data = _np.random.uniform(low=-100, high=100, size=shape).astype(dtype)
+ mx_data = np.array(np_data, dtype=dtype)
+ axes = [None]
+ if ndim == 0:
+ axes += [()]
+ else:
+ axis = [i for i in range(ndim)]
+ axes.append(tuple(axis))
+ random.shuffle(axis)
+ axes.append(tuple(axis))
+ for axis in axes:
+ np_out = _np.transpose(np_data, axes=axis)
+ mx_out = np.transpose(mx_data, axes=axis)
+ assert np_out.dtype == mx_out.dtype
+ assert same(mx_out.asnumpy(), np_out)
+ # TODO(junwu): Add numerical gradient test and Gluon API test.
+
+
+@with_seed()
+@mx.use_np_compat
+def test_relu():
+ # TODO(junwu): Add more test cases
+ data = mx.sym.var('data')
+ ret = mx.sym.np.ext.relu(data)
+ assert type(ret) == mx.sym.np._NumpySymbol
+
+ shapes = [(), (0, 2, 0)]
+ shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
+ for shape in shapes:
+ data = np.array(_np.random.uniform(size=shape).astype('float32'))
+ ret = np.ext.relu(data)
+ assert type(ret) == np.ndarray
+
+
+@with_seed()
+@mx.use_np_compat
+def test_sigmoid():
+ # TODO(junwu): Add more test cases
+ data = mx.sym.var('data')
+ ret = mx.sym.np.ext.sigmoid(data)
+ assert type(ret) == mx.sym.np._NumpySymbol
+
+ shapes = [(), (0, 2, 0)]
+ shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
+ for shape in shapes:
+ data = np.array(_np.random.uniform(size=shape).astype('float32'))
+ ret = np.ext.sigmoid(data)
+ assert type(ret) == np.ndarray
+
+
+@with_seed()
+@mx.use_np_compat
+def test_np_reshape():
+ # TODO(junwu): Add more test cases
+ data = mx.sym.var('a')
+ ret = mx.sym.np.reshape(data, newshape=())
+ assert type(ret) == mx.sym.np._NumpySymbol
+
+ data = np.ones((1, 1, 1))
+ ret = np.reshape(data, ())
+ assert ret.shape == ()
+ ret = np.reshape(ret, (1, 1, 1, 1))
+ assert ret.shape == (1, 1, 1, 1)
+ assert type(ret) == np.ndarray
+
+
+@with_seed()
+@mx.use_np_compat
+def test_np_maximum():
+ # TODO(junwu): Add more test cases
+ x1, x2 = mx.sym.var('x1'), mx.sym.var('x2')
+ ret = mx.sym.np.maximum(x1, x2)
+ assert type(ret) == mx.sym.np._NumpySymbol
+
+ def check_maximum(x1, x2):
+ mx_out = np.maximum(x1, x2)
+ if isinstance(x1, np.ndarray) or isinstance(x2, np.ndarray):
+ assert type(mx_out) == np.ndarray
+ np_out = _np.maximum(x1.asnumpy() if isinstance(x1, np.ndarray) else x1,
+ x2.asnumpy() if isinstance(x2, np.ndarray) else x2)
+ assert same(mx_out.asnumpy() if isinstance(mx_out, np.ndarray) else mx_out, np_out)
+
+ check_maximum(np.zeros((2, 1)), np.ones((5, 1, 4)))
+ check_maximum(np.zeros((2, 0)), np.ones((5, 1, 1)))
+ check_maximum(np.zeros(()), np.ones((5, 1, 4)))
+
+
+@with_seed()
+@mx.use_np_compat
+def test_np_minimum():
+ # TODO(junwu): Add more test cases
+ x1, x2 = mx.sym.var('x1'), mx.sym.var('x2')
+ ret = mx.sym.np.minimum(x1, x2)
+ assert type(ret) == mx.sym.np._NumpySymbol
+
+ def check_minimum(x1, x2):
+ mx_out = np.minimum(x1, x2)
+ if isinstance(x1, np.ndarray) or isinstance(x2, np.ndarray):
+ assert type(mx_out) == np.ndarray
+ np_out = _np.minimum(x1.asnumpy() if isinstance(x1, np.ndarray) else x1,
+ x2.asnumpy() if isinstance(x2, np.ndarray) else x2)
+ assert same(mx_out.asnumpy() if isinstance(mx_out, np.ndarray) else mx_out, np_out)
+
+ check_minimum(np.zeros((2, 1)), np.ones((5, 1, 4)))
+ check_minimum(np.zeros((2, 0)), np.ones((5, 1, 1)))
+ check_minimum(np.zeros(()), np.ones((5, 1, 4)))
+
+
if __name__ == '__main__':
import nose
nose.runmodule()