You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) was not preserved in this copy.
Posted to commits@mxnet.apache.org by re...@apache.org on 2019/05/18 20:31:08 UTC
[incubator-mxnet] branch numpy updated: [numpy] Refactor np modules
(#14989)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/numpy by this push:
new 7f7b13d [numpy] Refactor np modules (#14989)
7f7b13d is described below
commit 7f7b13dcd2c6ab2438d8b34b03037aa17438cbb7
Author: reminisce <wu...@gmail.com>
AuthorDate: Sat May 18 13:30:29 2019 -0700
[numpy] Refactor np modules (#14989)
* Refactor
* Initial refactoring
* Fix notebook
* Move numpy op check from backend to frontend
* Add homogeneous ndarray check
* Fix grouping inhomogeneous types of symbols
* Improve error handling of different types of symbols as outputs
* Fix test
* Fix numpy test
* Fix ci
* Try to fix gpu ci failure
---
example/numpy/demo.ipynb | 73 ++----
include/mxnet/c_api.h | 17 --
include/mxnet/op_attr_types.h | 9 -
python/mxnet/__init__.py | 3 +
python/mxnet/_ctypes/ndarray.py | 19 +-
python/mxnet/_ctypes/symbol.py | 10 +-
python/mxnet/base.py | 125 ++++++----
python/mxnet/gluon/block.py | 6 +-
python/mxnet/gluon/utils.py | 23 ++
python/mxnet/ndarray/__init__.py | 3 +-
python/mxnet/ndarray/ndarray.py | 48 +---
python/mxnet/ndarray/numpy/__init__.py | 5 +-
.../{numpy/ext.py => ndarray/numpy/_internal.py} | 2 +-
python/mxnet/ndarray/numpy/_op.py | 20 +-
python/mxnet/ndarray/numpy/_register.py | 8 +-
python/mxnet/ndarray/numpy/linalg.py | 2 +-
python/mxnet/ndarray/numpy/random.py | 2 +-
.../ndarray/{numpy => numpy_extension}/__init__.py | 5 +-
.../{numpy/ext.py => numpy_extension/_op.py} | 3 +-
.../{numpy => numpy_extension}/_register.py | 5 +-
python/mxnet/ndarray/register.py | 66 +++++-
python/mxnet/numpy/__init__.py | 4 +-
python/mxnet/numpy/_op.py | 2 +-
python/mxnet/numpy/_register.py | 5 +-
python/mxnet/numpy/linalg.py | 2 +-
python/mxnet/numpy/multiarray.py | 185 ++++++++++-----
python/mxnet/numpy/random.py | 2 +-
.../mxnet/{numpy => numpy_extension}/__init__.py | 7 +-
python/mxnet/{numpy => numpy_extension}/_op.py | 2 +-
.../mxnet/{numpy => numpy_extension}/_register.py | 5 +-
python/mxnet/symbol/__init__.py | 4 +-
python/mxnet/symbol/numpy/__init__.py | 7 +-
.../{numpy/_op.py => symbol/numpy/_internal.py} | 2 +-
python/mxnet/symbol/numpy/_op.py | 2 +-
python/mxnet/symbol/numpy/_register.py | 9 +-
python/mxnet/symbol/numpy/_symbol.py | 258 ++++++++++-----------
python/mxnet/symbol/numpy/ext.py | 20 --
python/mxnet/symbol/numpy/linalg.py | 2 +-
python/mxnet/symbol/numpy/random.py | 2 +-
.../numpy => symbol/numpy_extension}/__init__.py | 5 +-
.../mxnet/symbol/{numpy => numpy_extension}/_op.py | 3 +-
.../symbol/{numpy => numpy_extension}/_register.py | 5 +-
python/mxnet/symbol/register.py | 74 +++++-
python/mxnet/symbol/symbol.py | 57 ++---
python/mxnet/test_utils.py | 6 +
src/c_api/c_api_common.h | 17 --
src/c_api/c_api_ndarray.cc | 16 --
src/c_api/c_api_symbolic.cc | 7 -
src/operator/numpy/np_broadcast_reduce_op.h | 1 +
src/operator/numpy/np_broadcast_reduce_op_value.cc | 14 +-
src/operator/numpy/np_broadcast_reduce_op_value.cu | 8 +-
src/operator/numpy/np_dot-inl.h | 11 +-
src/operator/numpy/np_dot.cc | 2 +-
src/operator/numpy/np_dot.cu | 2 +-
src/operator/numpy/np_elemwise_broadcast_op.cc | 56 ++---
src/operator/numpy/np_elemwise_broadcast_op.cu | 34 +--
src/operator/numpy/np_elemwise_unary_op_basic.cc | 28 ++-
src/operator/numpy/np_elemwise_unary_op_basic.cu | 4 +-
src/operator/numpy/np_init_op.cc | 64 ++++-
src/operator/numpy/np_init_op.cu | 10 +-
src/operator/numpy/np_matrix_op.cc | 6 +-
src/operator/numpy/np_matrix_op.cu | 4 +-
src/operator/numpy/np_true_divide.cc | 9 +-
src/operator/numpy/np_true_divide.cu | 6 +-
tests/python/unittest/test_numpy_ndarray.py | 95 ++++----
tests/python/unittest/test_numpy_op.py | 78 ++++---
66 files changed, 876 insertions(+), 720 deletions(-)
diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb
index d8e6e06..7ba184d 100644
--- a/example/numpy/demo.ipynb
+++ b/example/numpy/demo.ipynb
@@ -6,21 +6,21 @@
"source": [
"# Fundamentals of MXNet Numpy Module\n",
"\n",
- "## Operator Namespaces for Imperative Programming\n",
+ "## Namespaces for Imperative Programming\n",
"- `mxnet.numpy`: Regular NumPy operators\n",
"- `mxnet.numpy.random`: NumPy random operators\n",
"- `mxnet.numpy.linalg`: NumPy linear algebra operators\n",
- "- `mxnet.numpy.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
+ "- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy\n",
"\n",
"## Operator Namespaces for Gluon\n",
- "`F` can be either `mxnet.ndarray` or `mxnet.symbol`.\n",
+ "`F` can be either `mxnet.ndarray` or `mxnet.symbol`. Note that `np` and `npe` are aliases of `numpy` and `numpy_extension`, respectively.\n",
"- `F.np`: Regular NumPy operators\n",
"- `F.np.random`: NumPy random operators\n",
"- `F.np.linalg`: NumPy linear algebra operators\n",
- "- `F.np.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
+ "- `F.npe`: Operators implemented in MXNet that do not exist in official NumPy\n",
"\n",
"## New `ndarray` and `symbol`\n",
- "`mxnet.numpy.ndarray` and `mxnet.symbol.numpy._NumpySymbol` (not visible to users)\n",
+ "`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not visible to users)\n",
"- Same name as in the official NumPy package\n",
"- Dispatch convience fluent method calls to MXNet Numpy operators\n",
"- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n",
@@ -46,7 +46,7 @@
"\n",
"# create a scalar tensor\n",
"x = np.array(3.14)\n",
- "print(x)"
+ "print(x) # x is actually an ndarray, but a scalar value will be printed"
]
},
{
@@ -170,13 +170,15 @@
"from mxnet import gluon\n",
"class TestBinaryBroadcast(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x1, x2):\n",
- " print(\"x1 type:\", str(type(x1)))\n",
- " print(\"x2 type:\", str(type(x2)))\n",
+ " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
+ " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
" return x1 + x2\n",
"\n",
"net = TestBinaryBroadcast()\n",
"x1 = mx.nd.ones((2, 1))\n",
"x2 = mx.nd.ones((1, 3))\n",
+ "print('x1 input tensor type: ', str(type(x1)))\n",
+ "print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: imperative execution supports broadcasting\n",
"print(out)"
]
@@ -203,13 +205,15 @@
"source": [
"class TestBinaryBroadcast2(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x1, x2):\n",
- " print(\"x1 type:\", str(type(x1)))\n",
- " print(\"x2 type:\", str(type(x2)))\n",
+ " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
+ " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
" return x1.as_np_ndarray() + x2 # convert x1 to new numpy ndarray/symbol\n",
"\n",
"net2 = TestBinaryBroadcast2()\n",
"net2.hybridize()\n",
"\n",
+ "print('x1 input tensor type: ', str(type(x1)))\n",
+ "print('x2 input tensor type: ', str(type(x2)))\n",
"out =net2(x1, x2)\n",
"print(out)"
]
@@ -224,7 +228,9 @@
"net.hybridize() # mark the block for execution using a computational graph\n",
"\n",
"x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
+ "print('x1 input tensor type: ', str(type(x1)))\n",
"x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
+ "print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: `+` operation supports broadcasting for _NumpySymbol\n",
"print(out) # mxnet.numpy.ndarray type, because it's from a np operator"
]
@@ -245,7 +251,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## MXNet Numpy Operators in Imperative Programming"
+ "### MXNet Numpy Operators in Imperative Programming"
]
},
{
@@ -255,15 +261,9 @@
"outputs": [],
"source": [
"import mxnet as mx\n",
- "from mxnet import numpy as np\n",
+ "from mxnet import numpy as np, numpy_extension as npe\n",
"from mxnet import autograd\n",
- "try:\n",
- " from mxboard import SummaryWriter\n",
- "except ImportError:\n",
- " SummaryWriter = None\n",
"\n",
- "# create a summary writer for visualization\n",
- "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
"\n",
"# Use numpy-compatible semantics to support scalar tensors\n",
"mx.set_np_compat(True)\n",
@@ -285,11 +285,11 @@
"learning_rate = 1e-6\n",
"\n",
"\n",
- "for t in range(1000):\n",
+ "for t in range(50):\n",
" with autograd.record():\n",
" # Forward pass: compute predicted y\n",
" h = x.dot(w1) # equivalent to np.dot(x, w1)\n",
- " h_relu = np.ext.relu(h) # equivalent to mx.nd.relu(h)\n",
+ " h_relu = npe.relu(h) # equivalent to mx.nd.relu(h)\n",
" y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n",
"\n",
" # Compute loss\n",
@@ -302,23 +302,14 @@
"\n",
" # Update weights\n",
" w1 -= learning_rate * w1.grad\n",
- " w2 -= learning_rate * w2.grad\n",
- "\n",
- " if sw is not None:\n",
- " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n",
- " if t % 50 == 0:\n",
- " sw.add_histogram(tag='w1', values=w1, global_step=t)\n",
- " sw.add_histogram(tag='w2', values=w2, global_step=t)\n",
- "\n",
- "if sw is not None:\n",
- " sw.close()"
+ " w2 -= learning_rate * w2.grad"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## MXNet Numpy Operators in Gluon `HybridBlock`"
+ "### MXNet Numpy Operators in Gluon `HybridBlock`"
]
},
{
@@ -329,13 +320,7 @@
"source": [
"import mxnet as mx\n",
"from mxnet import gluon, autograd\n",
- "try:\n",
- " from mxboard import SummaryWriter\n",
- "except ImportError:\n",
- " SummaryWriter = None\n",
"\n",
- "# create a summary writer for visualization\n",
- "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
"\n",
"# Use numpy-compatible semantics to support scalar tensors\n",
"mx.set_np_compat(True)\n",
@@ -352,7 +337,7 @@
"\n",
" def hybrid_forward(self, F, x, w1, w2):\n",
" h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n",
- " h_relu = F.np.ext.relu(h) # equivalent to F.relu(h)\n",
+ " h_relu = F.npe.relu(h) # equivalent to F.relu(h)\n",
" y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n",
" return y_pred\n",
"\n",
@@ -373,21 +358,13 @@
"total_loss = TotalLoss()\n",
"trainer = gluon.Trainer(regressor.collect_params(), 'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n",
"\n",
- "for t in range(1000):\n",
+ "for t in range(50):\n",
" with autograd.record():\n",
" output = regressor(x) # output is a type of np.ndarray because np.dot is the last op in the network\n",
" loss = total_loss(output, y) # loss is a scalar np.ndarray\n",
" loss.backward()\n",
" print(t, loss) # note that loss.asnumpy() is called\n",
- " trainer.step(1)\n",
- " if sw is not None:\n",
- " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n",
- " if t % 50 == 0:\n",
- " for k, v in regressor.collect_params().items():\n",
- " sw.add_histogram(tag=k, values=v.data(), global_step=t)\n",
- "\n",
- "if sw is not None:\n",
- " sw.close()"
+ " trainer.step(1)"
]
}
],
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 7a5dc13..f99c223 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2789,14 +2789,6 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
/*!
- * \brief Determines if an op is a Numpy op by its name prefix.
- * Every Numpy op starts with a prefix string "_numpy_".
- * \param creator Operator handle
- * \param is_np_op Indicator of whether creator is a numpy op handle
- */
-MXNET_DLL int MXIsNumpyCompatOp(AtomicSymbolCreator creator,
- int* is_np_op);
-/*!
* \brief Create an NDArray from source sharing the same data chunk.
* \param src source NDArray
* \param out new NDArray sharing the same data chunck with src
@@ -2808,15 +2800,6 @@ MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out);
* \param out new Symbol sharing the same graph structure with src
*/
MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
-/*!
- * \brief Checks if an output of CachedOp is from a numpy op.
- * \param handle CachedOp shared ptr
- * \param output_idx index of the output of the CachedOp
- * \param is_from_np_op indicator of whether the output is from a numpy op
- */
-MXNET_DLL int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
- int output_idx,
- int* is_from_np_op);
#ifdef __cplusplus
}
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 0e4e322..889b502 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -319,15 +319,6 @@ using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
size_t index)>;
-/*!
- * \brief Indicates whether this operator is NumPy compatible.
- * It is for distinguishing the operator from classic MXNet operators
- * which do not support zero-dim and zero-size tensors.
- * In Python, it is used to determine whether to output numpy ndarrays
- * or symbols that are NumPy compatible.
- */
-using TIsNumpyCompatible = bool;
-
} // namespace mxnet
#endif // MXNET_OP_ATTR_TYPES_H_
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index c4d4ff8..8d570d5 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -29,6 +29,9 @@ from . import contrib
from . import ndarray
from . import ndarray as nd
from . import numpy
+from . import numpy_extension
+from . import numpy as np
+from . import numpy_extension as npe
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index 60ec248..6404d89 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -26,7 +26,7 @@ import ctypes
from ..base import _LIB
from ..base import c_str_array, c_handle_array
from ..base import NDArrayHandle, CachedOpHandle
-from ..base import check_call, _is_np_compat_op
+from ..base import check_call
class NDArrayBase(object):
@@ -70,7 +70,7 @@ def _set_np_ndarray_class(cls):
_np_ndarray_cls = cls
-def _imperative_invoke(handle, ndargs, keys, vals, out):
+def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
"""ctypes implementation of imperative invoke wrapper"""
if out is not None:
original_output = out
@@ -99,9 +99,9 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
c_str_array([str(s) for s in vals]),
ctypes.byref(out_stypes)))
+ create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls
if original_output is not None:
return original_output
- create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls
if num_output.value == 1:
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
@@ -112,11 +112,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
- __slots__ = ["handle"]
+ __slots__ = ["handle", "is_np_sym"]
def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
+ from ..symbol.numpy._symbol import _Symbol
+ self.is_np_sym = True if isinstance(sym, _Symbol) else False
+
check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
@@ -167,12 +170,10 @@ class CachedOp(object):
if original_output is not None:
return original_output
+ create_ndarray_fn = _np_ndarray_cls if self.is_np_sym else _ndarray_cls
if num_output.value == 1:
- create_ndarray_fn = _np_ndarray_cls if self._is_from_np_compat_op(0) else _ndarray_cls
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
else:
- return [_np_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
- if self._is_from_np_compat_op(i) else
- _ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
- for i in range(num_output.value)]
+ return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
+ stype=out_stypes[i]) for i in range(num_output.value)]
diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py
index 7aea0a2..fc159f8 100644
--- a/python/mxnet/_ctypes/symbol.py
+++ b/python/mxnet/_ctypes/symbol.py
@@ -22,7 +22,7 @@ from __future__ import absolute_import as _abs
import ctypes
from ..base import _LIB
-from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op
+from ..base import c_str_array, c_handle_array, c_str, mx_uint
from ..base import SymbolHandle
from ..base import check_call
@@ -122,7 +122,7 @@ def _set_np_symbol_class(cls):
_np_symbol_cls = cls
-def _symbol_creator(handle, args, kwargs, keys, vals, name):
+def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op):
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
ctypes.c_void_p(handle),
@@ -135,10 +135,8 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name):
raise TypeError(
'Operators with variable length input can only accept input'
'Symbols either as positional or keyword arguments, not both')
- if _is_np_compat_op(handle):
- s = _np_symbol_cls(sym_handle)
- else:
- s = _symbol_cls(sym_handle)
+ create_symbol_fn = _np_symbol_cls if is_np_op else _symbol_cls
+ s = create_symbol_fn(sym_handle)
if args:
s._compose(*args, name=name)
elif kwargs:
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 12d1042..af4b2c5 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -16,7 +16,7 @@
# under the License.
# coding: utf-8
-# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument, unnecessary-pass
+# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument, unnecessary-pass, too-many-lines
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import
@@ -599,7 +599,9 @@ def _init_op_module(root_namespace, module_name, make_op_func):
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
- op_names.append(py_str(plist[i]))
+ op_name = py_str(plist[i])
+ if not _is_np_op(op_name):
+ op_names.append(op_name)
module_op = sys.modules["%s.%s.op" % (root_namespace, module_name)]
module_internal = sys.modules["%s.%s._internal" % (root_namespace, module_name)]
@@ -693,7 +695,9 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
- op_names.append(py_str(plist[i]))
+ op_name = py_str(plist[i])
+ if not _is_np_op(op_name):
+ op_names.append(op_name)
module_op_file = get_module_file("%s.%s.op" % (root_namespace, module_name))
module_op_all = []
@@ -874,12 +878,6 @@ def use_np_compat(func):
return _with_np_compat
-def _is_np_compat_op(op_handle):
- is_np_op = ctypes.c_int(0)
- check_call(_LIB.MXIsNumpyCompatOp(ctypes.c_void_p(op_handle), ctypes.byref(is_np_op)))
- return is_np_op.value != 0
-
-
def _sanity_check_params(func_name, unsupported_params, param_dict):
for param_name in unsupported_params:
if param_name in param_dict:
@@ -887,19 +885,28 @@ def _sanity_check_params(func_name, unsupported_params, param_dict):
.format(func_name, param_name))
-_NP_OP_SUBMODULE_LIST = ['_ext_', '_random_', '_linalg_']
-_NP_OP_PREFIX = '_numpy_'
+_NP_OP_PREFIX = '_np_'
+_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
+
+_NP_EXT_OP_PREFIX = '_npe_'
+
+_NP_INTERNAL_OP_PREFIX = '_npi_'
+
+def _is_np_op(op_name):
+ return op_name.startswith(_NP_OP_PREFIX) or op_name.startswith(_NP_EXT_OP_PREFIX)\
+ or op_name.startswith(_NP_INTERNAL_OP_PREFIX)
-def _get_np_op_submodule_name(op_name):
- assert op_name.startswith(_NP_OP_PREFIX)
- for name in _NP_OP_SUBMODULE_LIST:
- if op_name[len(_NP_OP_PREFIX):].startswith(name):
- return name
+
+def _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list):
+ assert op_name.startswith(op_name_prefix)
+ for submodule_name in submodule_name_list:
+ if op_name[len(op_name_prefix):].startswith(submodule_name):
+ return submodule_name
return ""
-def _init_np_op_module(root_namespace, module_name, make_op_func):
+def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op_func):
"""
Register numpy operators in namespaces `mxnet.numpy`, `mxnet.ndarray.numpy`
and `mxnet.symbol.numpy`. They are used in imperative mode, Gluon APIs w/o hybridization,
@@ -909,51 +916,89 @@ def _init_np_op_module(root_namespace, module_name, make_op_func):
Parameters
----------
- root_namespace : str
+ root_module_name : str
Top level module name, `mxnet` in the current cases.
- module_name : str
- Second level module name, `ndarray` or `symbol` in the current case.
+ np_module_name : str
+ Second level module name, `numpy` or `numpy_extension` in the current case.
make_op_func : function
Function for creating op functions.
"""
+ if np_module_name == 'numpy':
+ op_name_prefix = _NP_OP_PREFIX
+ submodule_name_list = _NP_OP_SUBMODULE_LIST
+ elif np_module_name == 'numpy_extension':
+ op_name_prefix = _NP_EXT_OP_PREFIX
+ submodule_name_list = []
+ elif np_module_name == 'numpy._internal':
+ op_name_prefix = _NP_INTERNAL_OP_PREFIX
+ submodule_name_list = []
+ else:
+ raise ValueError('unsupported np module name {}'.format(np_module_name))
+
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
-
check_call(_LIB.MXListAllOpNames(ctypes.byref(size), ctypes.byref(plist)))
op_names = []
for i in range(size.value):
name = py_str(plist[i])
- if name.startswith(_NP_OP_PREFIX):
+ if name.startswith(op_name_prefix):
op_names.append(name)
- if module_name == 'numpy':
- # register ops for mxnet.numpy
- module_pattern = "%s.%s._op"
- submodule_pattern = "%s.%s.%s"
+ if mx_module_name is None:
+ # register np/npe ops for imperative programming
+ op_module_name = "%s.%s._op" % (root_module_name, np_module_name) # e.g. mxnet.numpy._op
+ op_submodule_name = "%s.%s" % (root_module_name, np_module_name) # e.g. mxnet.numpy.random
+ elif mx_module_name == 'ndarray' or mx_module_name == 'symbol':
+ # register numpy internal ops and np/npe ops for use in Gluon
+ # np internal ops are registered in mxnet.ndarray/symbol.numpy._internal
+ # np ops are registered in mxnet.ndarray/symbol.numpy._op
+ # npe ops are registered in mxnet.ndarray/symbol.numpy_extension._op
+ op_module_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name)
+ if op_name_prefix != _NP_INTERNAL_OP_PREFIX:
+ op_module_name += '._op'
+ # e.g. mxnet.symbol.numpy.random
+ op_submodule_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name)
else:
- # register ops for mxnet.ndarray.numpy or mxnet.symbol.numpy
- module_pattern = "%s.%s.numpy._op"
- submodule_pattern = "%s.%s.numpy.%s"
- module_np_op = sys.modules[module_pattern % (root_namespace, module_name)]
+ raise ValueError('unsupported mxnet module {}'.format(mx_module_name))
+ op_submodule_name += '.%s'
+
+ op_module = sys.modules[op_module_name]
submodule_dict = {}
- for submodule_name in _NP_OP_SUBMODULE_LIST:
- submodule_dict[submodule_name] = \
- sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
+ for submodule_name in submodule_name_list:
+ submodule_dict[submodule_name] = sys.modules[op_submodule_name % submodule_name[1:-1]]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
- submodule_name = _get_np_op_submodule_name(name)
- module_name_local = module_name
+ submodule_name = _get_op_submodule_name(name, op_name_prefix, submodule_name_list)
if len(submodule_name) > 0:
- func_name = name[(len(_NP_OP_PREFIX) + len(submodule_name)):]
+ func_name = name[(len(op_name_prefix) + len(submodule_name)):]
cur_module = submodule_dict[submodule_name]
- module_name_local = submodule_pattern % (root_namespace,
- module_name, submodule_name[1:-1])
+ module_name_local = op_submodule_name % submodule_name[1:-1]
else:
- func_name = name[len(_NP_OP_PREFIX):]
- cur_module = module_np_op
+ func_name = name[len(op_name_prefix):]
+ cur_module = op_module
+ module_name_local =\
+ op_module_name[:-len('._op')] if op_module_name.endswith('._op') else op_module_name
function = make_op_func(hdl, name, func_name)
function.__module__ = module_name_local
setattr(cur_module, function.__name__, function)
cur_module.__all__.append(function.__name__)
+
+
+def set_module(module):
+ """Decorator for overriding __module__ on a function or class.
+
+ Example usage::
+
+ @set_module('mxnet.numpy')
+ def example():
+ pass
+
+ assert example.__module__ == 'numpy'
+ """
+ def decorator(func):
+ if module is not None:
+ func.__module__ = module
+ return func
+ return decorator
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index a142bc4..4f5d696 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -32,7 +32,7 @@ from ..symbol import Symbol
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
-from .utils import _indent, _brief_print_list, HookHandle
+from .utils import _indent, _brief_print_list, HookHandle, _check_same_symbol_type
from .. import numpy as _mx_np
@@ -746,7 +746,7 @@ class HybridBlock(Block):
out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter
out, self._out_format = _flatten(out, "output")
- self._cached_graph = inputs, symbol.Group(out)
+ self._cached_graph = inputs, symbol.Group(out, _check_same_symbol_type(out))
return self._cached_graph
@@ -1049,7 +1049,7 @@ class SymbolBlock(HybridBlock):
syms, self._in_format = _flatten(inputs, "input")
out, self._out_format = _flatten(outputs, "output")
- out = symbol.Group(out)
+ out = symbol.Group(out, _check_same_symbol_type(out))
input_names = set()
for i in syms:
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 8615422..f953774 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -412,3 +412,26 @@ class HookHandle(object):
def __exit__(self, ptype, value, trace):
self.detach()
+
+
+def _check_same_symbol_type(symbols):
+ """Check whether all the symbols in the list are of the same type.
+ Raise type error if the types are different. Return the class of
+ the symbols."""
+ from ..symbol.numpy import _Symbol as np_symbol
+ from ..symbol import Symbol as classic_symbol
+ is_np_sym = True if isinstance(symbols[0], np_symbol) else False
+ for s in symbols[1:]:
+ if is_np_sym != isinstance(s, np_symbol):
+ raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
+ '(mx.sym.np._Symbol) in outputs. This will prevent you from building '
+ 'a computation graph by grouping them since different types of symbols '
+ 'are not allowed to be grouped in Gluon to form a computation graph. '
+ 'You will need to convert them to the same type of symbols, either '
+ 'classic or numpy following this rule: if you want numpy ndarray '
+ 'output(s) from the computation graph, please convert all the classic '
+ 'symbols in the list to numpy symbols by calling `as_np_ndarray()` '
+ 'on each of them; if you want classic ndarray output(s) from the '
+ 'computation graph, please convert all the numpy symbols in the list '
+ 'to classic symbols by calling `as_classic_ndarray()` on each of them.')
+ return np_symbol if is_np_sym else classic_symbol
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index f0e6edb..c326850 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -31,6 +31,7 @@ from .utils import load, load_frombuffer, save, zeros, empty, array
from .sparse import _ndarray_cls
from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle
from . import numpy as np
+from . import numpy_extension as npe
__all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \
- ['contrib', 'linalg', 'random', 'sparse', 'image']
+ ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension']
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 9489d18..c461d22 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -186,15 +186,15 @@ fixed-size items.
def as_np_ndarray(self):
"""Convert mxnet.ndarray.NDArray to mxnet.numpy.ndarray."""
+ storage_type = self.stype
+ if storage_type != 'default':
+ raise ValueError('cannot convert ndarray of stype {} to numpy ndarray'
+ .format(str(type(storage_type))))
from ..numpy import ndarray
hdl = NDArrayHandle()
check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
return ndarray(handle=hdl, writable=self.writable)
- def _is_np_compat(self):
- """Always returns False except for mxnet.numpy.ndarray."""
- return False
-
@property
def _tvm_handle(self):
return self.handle.value
@@ -219,8 +219,6 @@ fixed-size items.
def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
# other may be the type of mxnet.numpy.ndarray
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__add__(self)
return add(self, other)
def __iadd__(self, other):
@@ -235,15 +233,11 @@ fixed-size items.
raise TypeError('type %s not supported' % str(type(other)))
def __radd__(self, other):
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__add__(self)
return self.__add__(other)
def __sub__(self, other):
"""x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y) """
# other may be the type of mxnet.numpy.ndarray
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__rsub__(self)
return subtract(self, other)
def __isub__(self, other):
@@ -259,14 +253,10 @@ fixed-size items.
def __rsub__(self, other):
"""x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__sub__(self)
return subtract(other, self)
def __mul__(self, other):
"""x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__mul__(self)
return multiply(self, other)
def __neg__(self):
@@ -285,20 +275,14 @@ fixed-size items.
raise TypeError('type %s not supported' % str(type(other)))
def __rmul__(self, other):
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__mul__(self)
return self.__mul__(other)
def __div__(self, other):
"""x.__div__(y) <=> x/y <=> mx.nd.divide(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__rtruediv__(self)
return divide(self, other)
def __rdiv__(self, other):
"""x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__truediv__(self)
return divide(other, self)
def __idiv__(self, other):
@@ -313,13 +297,9 @@ fixed-size items.
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__rtruediv__(self)
return divide(self, other)
def __rtruediv__(self, other):
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__truediv__(self)
return divide(other, self)
def __itruediv__(self, other):
@@ -327,14 +307,10 @@ fixed-size items.
def __mod__(self, other):
"""x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__rmod__(self)
return modulo(self, other)
def __rmod__(self, other):
"""x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__mod__(self)
return modulo(other, self)
def __imod__(self, other):
@@ -350,20 +326,14 @@ fixed-size items.
def __pow__(self, other):
"""x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__rpow__(self)
return power(self, other)
def __rpow__(self, other):
"""x.__pow__(y) <=> y**x <=> mx.nd.power(y,x) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__pow__(self)
return power(other, self)
def __eq__(self, other):
"""x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__eq__(self)
return equal(self, other)
def __hash__(self):
@@ -372,32 +342,22 @@ fixed-size items.
def __ne__(self, other):
"""x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__ne__(self)
return not_equal(self, other)
def __gt__(self, other):
"""x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__lt__(self)
return greater(self, other)
def __ge__(self, other):
"""x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__le__(self)
return greater_equal(self, other)
def __lt__(self, other):
"""x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__gt__(self)
return lesser(self, other)
def __le__(self, other):
"""x.__le__(y) <=> x<=y <=> mx.nd.less_equal(x, y) """
- if isinstance(other, NDArray) and other._is_np_compat():
- return other.__ge__(self)
return lesser_equal(self, other)
def __bool__(self):
diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy/__init__.py
index d97e808..7eb478f 100644
--- a/python/mxnet/ndarray/numpy/__init__.py
+++ b/python/mxnet/ndarray/numpy/__init__.py
@@ -15,12 +15,11 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy module for numpy ops under mxnet.ndarray."""
+"""Module for numpy ops under mxnet.ndarray."""
-from . import ext
from . import random
from . import linalg
-from . import _op
+from . import _op, _internal
from . import _register
from ._op import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/numpy/ext.py b/python/mxnet/ndarray/numpy/_internal.py
similarity index 91%
rename from python/mxnet/numpy/ext.py
rename to python/mxnet/ndarray/numpy/_internal.py
index e4c8251..c5f2928 100644
--- a/python/mxnet/numpy/ext.py
+++ b/python/mxnet/ndarray/numpy/_internal.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy.ext ops for imperative programming."""
+"""Namespace for numpy internal ops."""
__all__ = []
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 9b32c31..e905fdf 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -15,18 +15,19 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
+"""Namespace for numpy operators used in Gluon dispatched by F=ndarray."""
from __future__ import absolute_import
import numpy as _np
-from ...base import _sanity_check_params, use_np_compat, numeric_types
+from ...base import _sanity_check_params, use_np_compat, numeric_types, set_module
from ...context import current_context
-from .. import _internal
+from . import _internal as _npi
from ..ndarray import NDArray
__all__ = ['zeros', 'ones', 'maximum', 'minimum']
+@set_module('mxnet.ndarray.numpy')
@use_np_compat
def zeros(shape, dtype=_np.float32, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
@@ -55,9 +56,10 @@ def zeros(shape, dtype=_np.float32, **kwargs):
if ctx is None:
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
- return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+ return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+@set_module('mxnet.ndarray.numpy')
@use_np_compat
def ones(shape, dtype=None, **kwargs):
"""Return a new array of given shape and type, filled with ones.
@@ -86,7 +88,7 @@ def ones(shape, dtype=None, **kwargs):
if ctx is None:
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
- return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+ return _npi.ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
#pylint: disable= too-many-arguments, no-member, protected-access
@@ -138,6 +140,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
#pylint: enable= too-many-arguments, no-member, protected-access
+@set_module('mxnet.ndarray.numpy')
@use_np_compat
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
@@ -152,10 +155,10 @@ def maximum(x1, x2, out=None):
-------
out : mxnet.numpy.ndarray or scalar
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
- return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum,
- _internal._np_maximum_scalar, None, out)
+ return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
+@set_module('mxnet.ndarray.numpy')
@use_np_compat
def minimum(x1, x2, out=None):
"""Returns element-wise minimum of the input arrays with broadcasting.
@@ -170,5 +173,4 @@ def minimum(x1, x2, out=None):
-------
out : mxnet.numpy.ndarray or scalar
The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
- return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum,
- _internal._np_minimum_scalar, None, out)
+ return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
diff --git a/python/mxnet/ndarray/numpy/_register.py b/python/mxnet/ndarray/numpy/_register.py
index 840797f..3ac464e 100644
--- a/python/mxnet/ndarray/numpy/_register.py
+++ b/python/mxnet/ndarray/numpy/_register.py
@@ -15,10 +15,14 @@
# specific language governing permissions and limitations
# under the License.
-"""module for registering numpy ops under mxnet.ndarray.numpy."""
+"""Registering numpy ops."""
from ...base import _init_np_op_module
from ..register import _make_ndarray_function
-_init_np_op_module('mxnet', 'ndarray', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy',
+ mx_module_name='ndarray', make_op_func=_make_ndarray_function)
+
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy._internal',
+ mx_module_name='ndarray', make_op_func=_make_ndarray_function)
diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py
index b8f10b3..8f521fd 100644
--- a/python/mxnet/ndarray/numpy/linalg.py
+++ b/python/mxnet/ndarray/numpy/linalg.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=ndarray."""
__all__ = []
diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index 60908b5..8f521fd 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy.random namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
+"""Namespace for operators used in Gluon dispatched by F=ndarray."""
__all__ = []
diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy_extension/__init__.py
similarity index 88%
copy from python/mxnet/ndarray/numpy/__init__.py
copy to python/mxnet/ndarray/numpy_extension/__init__.py
index d97e808..a718274 100644
--- a/python/mxnet/ndarray/numpy/__init__.py
+++ b/python/mxnet/ndarray/numpy_extension/__init__.py
@@ -15,11 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy module for numpy ops under mxnet.ndarray."""
+"""Module for the ops not belonging to the official numpy package."""
-from . import ext
-from . import random
-from . import linalg
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/ndarray/numpy/ext.py b/python/mxnet/ndarray/numpy_extension/_op.py
similarity index 86%
rename from python/mxnet/ndarray/numpy/ext.py
rename to python/mxnet/ndarray/numpy_extension/_op.py
index e13423f..22738a0 100644
--- a/python/mxnet/ndarray/numpy/ext.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
+"""Namespace for the operators not belonging to the official numpy package
+used in Gluon dispatched by F=ndarray module."""
__all__ = []
diff --git a/python/mxnet/ndarray/numpy/_register.py b/python/mxnet/ndarray/numpy_extension/_register.py
similarity index 81%
copy from python/mxnet/ndarray/numpy/_register.py
copy to python/mxnet/ndarray/numpy_extension/_register.py
index 840797f..32cd068 100644
--- a/python/mxnet/ndarray/numpy/_register.py
+++ b/python/mxnet/ndarray/numpy_extension/_register.py
@@ -15,10 +15,11 @@
# specific language governing permissions and limitations
# under the License.
-"""module for registering numpy ops under mxnet.ndarray.numpy."""
+"""Registering numpy_extension ops."""
from ...base import _init_np_op_module
from ..register import _make_ndarray_function
-_init_np_op_module('mxnet', 'ndarray', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension',
+ mx_module_name='ndarray', make_op_func=_make_ndarray_function)
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 1ccf228..a285e50 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -24,12 +24,60 @@ import numpy as _np # pylint: disable=unused-import
from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import
from ..ndarray_doc import _build_doc
-from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null # pylint: disable=unused-import
+from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op # pylint: disable=unused-import
+
+
+def _verify_all_np_ndarrays(op_name, func_name, *array_list):
+ """Verify if all the arrays are numpy ndarrays.
+
+ Parameters
+ ----------
+ op_name : str
+ Operator full name registered in backend.
+ func_name : str
+ Operator name exposed to users. This is usually the name obtained by stripping
+ the prefix off the full operator name registered in the backend.
+ array_list : list of arrays
+ """
+ from ..numpy import ndarray as np_ndarray
+ for array in array_list:
+ if (array is not None) and (not isinstance(array, np_ndarray)):
+ raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+ 'This is a numpy operator which can only accept '
+ 'MXNet numpy ndarrays, while received a classic ndarray. '
+ 'Please call `as_np_ndarray()` upon the classic ndarray to '
+ 'convert it to an MXNet numpy ndarray, and then feed the converted '
+ 'array to this operator.'
+ .format(op_name, func_name))
+
+
+def _verify_all_classic_ndarrays(op_name, func_name, *array_list):
+ """Verify if all the arrays are classic ndarrays.
+
+ Parameters
+ ----------
+ op_name : str
+ Operator full name registered in backend.
+ func_name : str
+ Operator name exposed to users. This is usually the name obtained by stripping
+ the prefix off the full operator name registered in the backend.
+ array_list : list of arrays
+ """
+ from ..numpy import ndarray as np_ndarray
+ for array in array_list:
+ if (array is not None) and (isinstance(array, np_ndarray)):
+ raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+ 'This is a classic operator which can only accept '
+ 'classic ndarrays, while received an MXNet numpy ndarray. '
+ 'Please call `as_classic_ndarray()` upon the numpy ndarray to '
+ 'convert it to a classic ndarray, and then feed the converted '
+ 'array to this operator.'
+ .format(op_name, func_name))
# pylint: disable=too-many-locals
-def _generate_ndarray_function_code(handle, name, func_name, signature_only=False):
- """Generate function for ndarray op by handle and function name."""
+def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=False):
+ """Generate function code for ndarray op by handle and operator name."""
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = mx_uint()
@@ -52,7 +100,7 @@ def _generate_ndarray_function_code(handle, name, func_name, signature_only=Fals
arg_types = [py_str(arg_types[i]) for i in range(narg)]
key_var_num_args = py_str(key_var_num_args.value)
ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
- doc_str = _build_doc(name,
+ doc_str = _build_doc(op_name,
py_str(desc.value),
arg_names,
arg_types,
@@ -139,10 +187,16 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
keys.append('%s')
vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
+ is_np_op = _is_np_op(op_name)
+ verify_ndarrays_fn =\
+ _verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_classic_ndarrays.__name__
if not signature_only:
code.append("""
- return _imperative_invoke(%d, ndargs, keys, vals, out)"""%(
- handle.value))
+ {}("{}", "{}", out, *ndargs)
+ """.format(verify_ndarrays_fn, op_name, func_name))
+ code.append("""
+ return _imperative_invoke(%d, ndargs, keys, vals, out, %s)"""%(
+ handle.value, str(is_np_op)))
else:
code.append("""
return (0,)""")
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index 2a58f27..0f3c3c7 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -17,15 +17,15 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy module for imperative programming."""
+"""Module for numpy ops used in imperative programming."""
from __future__ import absolute_import
from . import random
from . import linalg
-from . import ext
from .multiarray import * # pylint: disable=wildcard-import
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
+from ..base import use_np_compat, set_np_compat, np_compat
__all__ = []
diff --git a/python/mxnet/numpy/_op.py b/python/mxnet/numpy/_op.py
index e6a918c..8f6f9cc 100644
--- a/python/mxnet/numpy/_op.py
+++ b/python/mxnet/numpy/_op.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops for imperative programming."""
+"""Namespace for registering numpy ops for imperative programming."""
__all__ = []
diff --git a/python/mxnet/numpy/_register.py b/python/mxnet/numpy/_register.py
index 53ceecd..8a2d2ea 100644
--- a/python/mxnet/numpy/_register.py
+++ b/python/mxnet/numpy/_register.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-"""Register backend ops in mxnet.ndarray namespace."""
+"""Registering ops in mxnet.numpy for imperative programming."""
from __future__ import absolute_import
@@ -23,4 +23,5 @@ from ..base import _init_np_op_module
from ..ndarray.register import _make_ndarray_function
-_init_np_op_module('mxnet', 'numpy', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy',
+ mx_module_name=None, make_op_func=_make_ndarray_function)
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index 96c7ddc..e49bfcf 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy.linalg ops for imperative programming."""
+"""Namespace for ops used in imperative programming."""
__all__ = []
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 6c414b4..dfcce0b 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -25,14 +25,14 @@ from __future__ import division
from array import array as native_array
import ctypes
import numpy as _np
-from ..ndarray import NDArray, _DTYPE_NP_TO_MX
+from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP
from ..ndarray._internal import _set_np_ndarray_class
from . import _op as _mx_np_op
from ..base import use_np_compat, check_call, _LIB, NDArrayHandle, _sanity_check_params
-from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types
+from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, set_module
from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
-from ..ndarray import _internal as _nd_internal
+from ..ndarray.numpy import _internal as _npi
__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum']
@@ -73,16 +73,14 @@ def _np_ndarray_cls(handle, writable=True, stype=0):
_set_np_ndarray_class(_np_ndarray_cls)
-class ndarray(NDArray): # pylint: disable=invalid-name
+@set_module('mxnet.numpy') # pylint: disable=invalid-name
+class ndarray(NDArray):
"""An array object represents a multidimensional, homogeneous array of fixed-size items.
An associated data-type object describes the format of each element in the array
(its byte-order, how many bytes it occupies in memory, whether it is an integer, a
floating point number, or something else, etc.). Arrays should be constructed using
`array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported."""
- def _is_np_compat(self):
- return True
-
@use_np_compat
def __getitem__(self, item):
# TODO(junwu): make output shape of integer indexing correct
@@ -90,15 +88,15 @@ class ndarray(NDArray): # pylint: disable=invalid-name
@use_np_compat
def __setitem__(self, key, value):
- super(ndarray, self).__setitem__(key, value)
+ self.as_classic_ndarray().__setitem__(key, value)
@use_np_compat
def __add__(self, other):
"""x.__add__(y) <=> x + y"""
- if isinstance(other, NDArray):
- return _nd_internal._np_add(self, other)
+ if isinstance(other, ndarray):
+ return _npi.add(self, other)
elif isinstance(other, numeric_types):
- return _nd_internal._np_add_scalar(self, float(other))
+ return _npi.add_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@@ -107,20 +105,20 @@ class ndarray(NDArray): # pylint: disable=invalid-name
"""x.__iadd__(y) <=> x += y"""
if not self.writable:
raise ValueError('trying to add to a readonly ndarray')
- if isinstance(other, NDArray):
- return _nd_internal._np_add(self, other, out=self)
+ if isinstance(other, ndarray):
+ return _npi.add(self, other, out=self)
elif isinstance(other, numeric_types):
- return _nd_internal._np_add_scalar(self, float(other), out=self)
+ return _npi.add_scalar(self, float(other), out=self)
else:
raise TypeError('type {} is not supported'.format(str(type(other))))
@use_np_compat
def __sub__(self, other):
"""x.__sub__(y) <=> x - y"""
- if isinstance(other, NDArray):
- return _nd_internal._np_subtract(self, other)
+ if isinstance(other, ndarray):
+ return _npi.subtract(self, other)
elif isinstance(other, numeric_types):
- return _nd_internal._np_subtract_scalar(self, float(other))
+ return _npi.subtract_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@@ -129,30 +127,30 @@ class ndarray(NDArray): # pylint: disable=invalid-name
"""x.__isub__(y) <=> x -= y"""
if not self.writable:
raise ValueError('trying to subtract from a readonly ndarray')
- if isinstance(other, NDArray):
- return _nd_internal._np_subtract(self, other, out=self)
+ if isinstance(other, ndarray):
+ return _npi.subtract(self, other, out=self)
elif isinstance(other, numeric_types):
- return _nd_internal._np_subtract_scalar(self, float(other), out=self)
+ return _npi.subtract_scalar(self, float(other), out=self)
else:
raise TypeError('type {} is not supported'.format(str(type(other))))
@use_np_compat
def __rsub__(self, other):
"""x.__rsub__(y) <=> y - x"""
- if isinstance(other, NDArray):
- return _nd_internal._np_subtract(other, self)
+ if isinstance(other, ndarray):
+ return _npi.subtract(other, self)
elif isinstance(other, numeric_types):
- return _nd_internal._np_rsubtract_scalar(self, float(other))
+ return _npi.rsubtract_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@use_np_compat
def __mul__(self, other):
"""x.__mul__(y) <=> x * y"""
- if isinstance(other, NDArray):
- return _nd_internal._np_multiply(self, other)
+ if isinstance(other, ndarray):
+ return _npi.multiply(self, other)
elif isinstance(other, numeric_types):
- return _nd_internal._np_multiply_scalar(self, float(other))
+ return _npi.multiply_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@@ -190,20 +188,20 @@ class ndarray(NDArray): # pylint: disable=invalid-name
@use_np_compat
def __truediv__(self, other):
"""x.__truediv__(y) <=> x / y"""
- if isinstance(other, NDArray):
- return _nd_internal._true_divide(self, other)
+ if isinstance(other, ndarray):
+ return _npi.true_divide(self, other)
elif isinstance(other, numeric_types):
- return _nd_internal._true_divide_scalar(self, float(other))
+ return _npi.true_divide_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as divisor".format(str(type(other))))
@use_np_compat
def __rtruediv__(self, other):
"""x.__rtruediv__(y) <=> y / x"""
- if isinstance(other, NDArray):
- return _nd_internal._true_divide(other, self)
+ if isinstance(other, ndarray):
+ return _npi.true_divide(other, self)
elif isinstance(other, numeric_types):
- return _nd_internal._rtrue_divide_scalar(self, float(other))
+ return _npi.rtrue_divide_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as dividend".format(str(type(other))))
@@ -214,20 +212,20 @@ class ndarray(NDArray): # pylint: disable=invalid-name
@use_np_compat
def __mod__(self, other):
"""x.__mod__(y) <=> x % y"""
- if isinstance(other, NDArray):
- return _nd_internal._np_mod(self, other)
+ if isinstance(other, ndarray):
+ return _npi.mod(self, other)
elif isinstance(other, numeric_types):
- return _nd_internal._np_mod_scalar(self, float(other))
+ return _npi.mod_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@use_np_compat
def __rmod__(self, other):
"""x.__rmod__(y) <=> y % x"""
- if isinstance(other, NDArray):
- return _nd_internal._np_mod(other, self)
+ if isinstance(other, ndarray):
+ return _npi.mod(other, self)
elif isinstance(other, numeric_types):
- return _nd_internal._np_rmod_scalar(self, float(other))
+ return _npi.rmod_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@@ -238,20 +236,20 @@ class ndarray(NDArray): # pylint: disable=invalid-name
@use_np_compat
def __pow__(self, other):
"""x.__pow__(y) <=> x ** y"""
- if isinstance(other, NDArray):
- return _nd_internal._np_power(self, other)
+ if isinstance(other, ndarray):
+ return _npi.power(self, other)
elif isinstance(other, numeric_types):
- return _nd_internal._np_power_scalar(self, float(other))
+ return _npi.power_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@use_np_compat
def __rpow__(self, other):
"""x.__rpow__(y) <=> y ** x"""
- if isinstance(other, NDArray):
- return _nd_internal._np_power(other, self)
+ if isinstance(other, ndarray):
+ return _npi.power(other, self)
elif isinstance(other, numeric_types):
- return _nd_internal._np_rpower_scalar(self, float(other))
+ return _npi.rpower_scalar(self, float(other))
else:
raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
@@ -355,15 +353,41 @@ class ndarray(NDArray): # pylint: disable=invalid-name
@use_np_compat
def __repr__(self):
- """Returns a string representation of the array."""
- return '%s\n<%s shape=%s ctx=%s>' % (str(self.asnumpy()), self.__class__.__name__,
- self.shape, self.context)
+ """Returns a string representation of the array using the following rules:
+ 1. If the `ndarray` is a scalar tensor, only the string of the scalar is returned.
+ 2. Else if the `ndarray` is allocated on cpu, the string of its numpy form, class name,
+ and shape is returned.
+ 3. Else (the `ndarray` is allocated on gpu), the string of its numpy form, class name,
+ shape, and context is returned."""
+ array_str = str(self.asnumpy())
+ if self.ndim == 0: # scalar tensor
+ return array_str
+ context = self.context
+ if context.device_type == 'gpu':
+ return '%s\n<%s shape=%s ctx=%s>' % (array_str, self.__class__.__name__, self.shape,
+ context)
+ else:
+ return '%s\n<%s shape=%s>' % (array_str, self.__class__.__name__, self.shape)
@use_np_compat
- def attach_grad(self, grad_req='write', stype=None):
- if stype is not None:
- raise NotImplementedError('mxnet.numpy.ndarray currently does not support stype')
- super(ndarray, self).attach_grad(grad_req, stype)
+ def attach_grad(self, grad_req='write'): # pylint: disable=arguments-differ
+ """Attach a gradient buffer to this ndarray, so that `backward`
+ can compute gradient with respect to it.
+
+ Parameters
+ ----------
+ grad_req : {'write', 'add', 'null'}
+ How gradient will be accumulated.
+ - 'write': gradient will be overwritten on every backward.
+ - 'add': gradient will be added to existing value on every backward.
+ - 'null': do not compute gradient for this NDArray.
+ """
+ grad = _mx_np_op.zeros_like(self) # pylint: disable=undefined-variable
+ grad_req = _GRAD_REQ_MAP[grad_req]
+ check_call(_LIB.MXAutogradMarkVariables(
+ 1, ctypes.pointer(self.handle),
+ ctypes.pointer(mx_uint(grad_req)),
+ ctypes.pointer(grad.handle)))
@property
def grad(self):
@@ -412,6 +436,43 @@ class ndarray(NDArray): # pylint: disable=invalid-name
self.copyto(res)
return res
+ @use_np_compat
+ def copyto(self, other):
+ """Copies the value of this array to another array.
+
+ If ``other`` is an ``ndarray`` object, then ``other.shape`` and
+ ``self.shape`` should be the same. This function copies the value from
+ ``self`` to ``other``.
+
+ If ``other`` is a context, a new ``ndarray`` is first created on
+ the target context, and the value of ``self`` is copied into it.
+
+ Parameters
+ ----------
+ other : ndarray or Context
+ The destination array or context.
+
+ Returns
+ -------
+ ndarray
+ The copied array. If ``other`` is an ``ndarray``, then the return value
+ and ``other`` will point to the same ``ndarray``.
+
+ Examples
+ --------
+ >>> x = np.ones((2,3))
+ >>> y = np.zeros((2,3), mx.gpu(0))
+ >>> z = x.copyto(y)
+ >>> z is y
+ True
+ >>> y.asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ """
+ if isinstance(other, ndarray):
+ other = other.as_classic_ndarray()
+ return self.as_classic_ndarray().copyto(other).as_np_ndarray()
+
def asscalar(self):
raise AttributeError('mxnet.numpy.ndarray object has no attribute as_scalar')
@@ -435,7 +496,7 @@ class ndarray(NDArray): # pylint: disable=invalid-name
if order != 'C':
raise NotImplementedError('reshape only supports C-order,'
' while received {}'.format(order))
- return _mx_np_op.reshape(self, shape=shape, order=order)
+ return _mx_np_op.reshape(self, newshape=shape, order=order)
def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
@@ -1117,15 +1178,11 @@ class ndarray(NDArray): # pylint: disable=invalid-name
"""Number of elements in the array."""
return super(ndarray, self).size
- @property
- @use_np_compat
- def stype(self):
- raise AttributeError('mxnet.numpy.ndarray object has no attribute stype')
-
def tostype(self, stype):
raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype')
+@set_module('mxnet.numpy')
@use_np_compat
def empty(shape, dtype=None, **kwargs):
"""Return a new array of given shape and type, without initializing entries.
@@ -1158,6 +1215,7 @@ def empty(shape, dtype=None, **kwargs):
return ndarray(handle=_new_alloc_handle(shape, ctx, False, dtype))
+@set_module('mxnet.numpy')
@use_np_compat
def array(object, dtype=None, **kwargs):
"""
@@ -1169,10 +1227,7 @@ def array(object, dtype=None, **kwargs):
An array, any object exposing the array interface, an object whose
__array__ method returns an array, or any (nested) sequence.
dtype : data-type, optional
- The desired data-type for the array. If not given, then the type will
- be determined as the minimum type required to hold the objects in the
- sequence. This argument can only be used to 'upcast' the array. For
- downcasting, use the .astype(t) method.
+ The desired data-type for the array. Default is `float32`.
ctx : device context, optional
Device context on which the memory is allocated. Default is
`mxnet.context.current_context()`.
@@ -1186,18 +1241,19 @@ def array(object, dtype=None, **kwargs):
ctx = kwargs.get('ctx', current_context())
if ctx is None:
ctx = current_context()
+ if dtype is None:
+ dtype = _np.float32
if not isinstance(object, (ndarray, NDArray, _np.ndarray)):
try:
object = _np.array(object, dtype=dtype)
except:
raise TypeError('source array must be an array like object')
- if dtype is None:
- dtype = object.dtype
ret = empty(object.shape, dtype=dtype, ctx=ctx)
ret[:] = object
return ret
+@set_module('mxnet.numpy')
def zeros(shape, dtype=_np.float32, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
@@ -1223,6 +1279,7 @@ def zeros(shape, dtype=_np.float32, **kwargs):
return _mx_nd_np.zeros(shape, dtype, **kwargs)
+@set_module('mxnet.numpy')
def ones(shape, dtype=None, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
@@ -1248,6 +1305,7 @@ def ones(shape, dtype=None, **kwargs):
return _mx_nd_np.ones(shape, dtype, **kwargs)
+@set_module('mxnet.numpy')
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
@@ -1264,6 +1322,7 @@ def maximum(x1, x2, out=None):
return _mx_nd_np.maximum(x1, x2, out=out)
+@set_module('mxnet.numpy')
def minimum(x1, x2, out=None):
"""Returns element-wise minimum of the input arrays with broadcasting.
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index b1f4b02..e49bfcf 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy.random ops for imperative programming."""
+"""Namespace for ops used in imperative programming."""
__all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy_extension/__init__.py
similarity index 85%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/numpy_extension/__init__.py
index 2a58f27..bd51175 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy_extension/__init__.py
@@ -17,15 +17,12 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy module for imperative programming."""
+"""Module for ops not belonging to the official numpy package for imperative programming."""
from __future__ import absolute_import
-from . import random
-from . import linalg
-from . import ext
-from .multiarray import * # pylint: disable=wildcard-import
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
+from ..context import * # pylint: disable=wildcard-import
__all__ = []
diff --git a/python/mxnet/numpy/_op.py b/python/mxnet/numpy_extension/_op.py
similarity index 90%
copy from python/mxnet/numpy/_op.py
copy to python/mxnet/numpy_extension/_op.py
index e6a918c..a995e48 100644
--- a/python/mxnet/numpy/_op.py
+++ b/python/mxnet/numpy_extension/_op.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops for imperative programming."""
+"""Namespace for registering numpy_extension ops for imperative programming."""
__all__ = []
diff --git a/python/mxnet/numpy/_register.py b/python/mxnet/numpy_extension/_register.py
similarity index 79%
copy from python/mxnet/numpy/_register.py
copy to python/mxnet/numpy_extension/_register.py
index 53ceecd..8abb725 100644
--- a/python/mxnet/numpy/_register.py
+++ b/python/mxnet/numpy_extension/_register.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-"""Register backend ops in mxnet.ndarray namespace."""
+"""Registering ops in mxnet.numpy_extension for imperative programming."""
from __future__ import absolute_import
@@ -23,4 +23,5 @@ from ..base import _init_np_op_module
from ..ndarray.register import _make_ndarray_function
-_init_np_op_module('mxnet', 'numpy', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension',
+ mx_module_name=None, make_op_func=_make_ndarray_function)
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index ae9477a..1cd8057 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -28,5 +28,7 @@ from .op import *
from .symbol import *
# pylint: enable=wildcard-import
from . import numpy as np
+from . import numpy_extension as npe
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
+__all__ = op.__all__ + symbol.__all__\
+ + ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension']
diff --git a/python/mxnet/symbol/numpy/__init__.py b/python/mxnet/symbol/numpy/__init__.py
index 1f20c03..857849c 100644
--- a/python/mxnet/symbol/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy/__init__.py
@@ -15,13 +15,12 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy module for numpy ops under mxnet.symbol."""
+"""Module for numpy ops under mxnet.symbol."""
from . import random
from . import linalg
-from . import ext
-from . import _op, _symbol
-from ._symbol import _NumpySymbol
+from . import _op, _symbol, _internal
+from ._symbol import _Symbol
from . import _register
from ._op import * # pylint: disable=wildcard-import
from ._symbol import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/numpy/_op.py b/python/mxnet/symbol/numpy/_internal.py
similarity index 91%
copy from python/mxnet/numpy/_op.py
copy to python/mxnet/symbol/numpy/_internal.py
index e6a918c..c5f2928 100644
--- a/python/mxnet/numpy/_op.py
+++ b/python/mxnet/symbol/numpy/_internal.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""namespace for registering numpy ops for imperative programming."""
+"""Namespace for numpy internal ops."""
__all__ = []
diff --git a/python/mxnet/symbol/numpy/_op.py b/python/mxnet/symbol/numpy/_op.py
index 96da828..a4a979f 100644
--- a/python/mxnet/symbol/numpy/_op.py
+++ b/python/mxnet/symbol/numpy/_op.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=symbol module."""
__all__ = []
diff --git a/python/mxnet/symbol/numpy/_register.py b/python/mxnet/symbol/numpy/_register.py
index 36dfd78..3245c8d 100644
--- a/python/mxnet/symbol/numpy/_register.py
+++ b/python/mxnet/symbol/numpy/_register.py
@@ -15,9 +15,14 @@
# specific language governing permissions and limitations
# under the License.
-"""module for registering numpy ops under mxnet.symbol.numpy."""
+"""Registering numpy ops."""
from ...base import _init_np_op_module
from ..register import _make_symbol_function
-_init_np_op_module('mxnet', 'symbol', _make_symbol_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy',
+ mx_module_name='symbol', make_op_func=_make_symbol_function)
+
+
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy._internal',
+ mx_module_name='symbol', make_op_func=_make_symbol_function)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 8cf6e30..0bbd96b 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -23,21 +23,17 @@ import ctypes
import numpy as _np
from . import _op as _mx_np_op
from ...base import _sanity_check_params, use_np_compat, check_call, _LIB, SymbolHandle
-from ...base import numeric_types
+from ...base import numeric_types, set_module
from ...context import current_context
-from .. import _internal
from ..symbol import Symbol
from .._internal import _set_np_symbol_class
-from .. import _internal as _sym_internal
+from . import _internal as _npi
__all__ = ['zeros', 'ones', 'maximum', 'minimum']
-class _NumpySymbol(Symbol):
-
- def _is_np_compat(self):
- return True
-
+@set_module('mxnet.symbol.numpy')
+class _Symbol(Symbol):
def __getitem__(self, item):
raise NotImplementedError
@@ -45,72 +41,72 @@ class _NumpySymbol(Symbol):
raise NotImplementedError
def __iter__(self):
- raise AttributeError('_NumpySymbol object has no attribute __iter__')
+ raise AttributeError('_Symbol object has no attribute __iter__')
@use_np_compat
def __add__(self, other):
"""x.__add__(y) <=> x + y"""
- if isinstance(other, Symbol):
- return _sym_internal._np_add(self, other)
+ if isinstance(other, _Symbol):
+ return _npi.add(self, other)
elif isinstance(other, numeric_types):
- return _sym_internal._np_add_scalar(self, float(other))
+ return _npi.add_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
def __sub__(self, other):
"""x.__sub__(y) <=> x - y"""
- if isinstance(other, Symbol):
- return _sym_internal._np_subtract(self, other)
+ if isinstance(other, _Symbol):
+ return _npi.subtract(self, other)
elif isinstance(other, numeric_types):
- return _sym_internal._np_subtract_scalar(self, float(other))
+ return _npi.subtract_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
def __rsub__(self, other):
"""x.__rsub__(y) <=> y - x"""
- if isinstance(other, Symbol):
- return _sym_internal._np_subtract(other, self)
+ if isinstance(other, _Symbol):
+ return _npi.subtract(other, self)
elif isinstance(other, numeric_types):
- return _sym_internal._np_rsubtract_scalar(self, float(other))
+ return _npi.rsubtract_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
def __mul__(self, other):
"""x.__mul__(y) <=> x * y"""
- if isinstance(other, Symbol):
- return _sym_internal._np_multiply(self, other)
+ if isinstance(other, _Symbol):
+ return _npi.multiply(self, other)
elif isinstance(other, numeric_types):
- return _sym_internal._np_multiply_scalar(self, float(other))
+ return _npi.multiply_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
def __rmul__(self, other):
"""x.__rmul__(y) <=> y * x"""
- if isinstance(other, Symbol):
- return _sym_internal._np_multiply(self, other)
+ if isinstance(other, _Symbol):
+ return _npi.multiply(self, other)
elif isinstance(other, numeric_types):
- return _sym_internal._np_multiply_scalar(self, float(other))
+ return _npi.multiply_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
def __div__(self, other):
- raise AttributeError('_NumpySymbol.__div__ is replaced by __truediv__. If you are using'
+ raise AttributeError('_Symbol.__div__ is replaced by __truediv__. If you are using'
' Python2, please use the statement from __future__ import division'
' to change the / operator to mean true division throughout the'
' module. If you are using Python3, this error should not have'
' been encountered.')
def __rdiv__(self, other):
- raise AttributeError('_NumpySymbol.__rdiv__ is replaced by __rtruediv__. If you are using'
+ raise AttributeError('_Symbol.__rdiv__ is replaced by __rtruediv__. If you are using'
' Python2, please use the statement from __future__ import division'
' to change the / operator to mean true division throughout the'
' module. If you are using Python3, this error should not have'
@@ -119,23 +115,23 @@ class _NumpySymbol(Symbol):
@use_np_compat
def __mod__(self, other):
"""x.__mod__(y) <=> x % y"""
- if isinstance(other, Symbol):
- return _sym_internal._np_mod(self, other)
+ if isinstance(other, _Symbol):
+ return _npi.mod(self, other)
elif isinstance(other, numeric_types):
- return _sym_internal._np_mod_scalar(self, float(other))
+ return _npi.mod_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
def __rmod__(self, other):
"""x.__rmod__(y) <=> y % x"""
- if isinstance(other, Symbol):
- return _sym_internal._np_mod(other, self)
+ if isinstance(other, _Symbol):
+ return _npi.mod(other, self)
elif isinstance(other, numeric_types):
- return _sym_internal._np_rmod_scalar(self, float(other))
+ return _npi.rmod_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
@@ -145,23 +141,23 @@ class _NumpySymbol(Symbol):
@use_np_compat
def __truediv__(self, other):
"""x.__truediv__(y) <=> x / y"""
- if isinstance(other, Symbol):
- return _sym_internal._true_divide(self, other)
+ if isinstance(other, _Symbol):
+ return _npi.true_divide(self, other)
elif isinstance(other, numeric_types):
- return _sym_internal._true_divide_scalar(self, float(other))
+ return _npi.true_divide_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as divisor"
+ raise TypeError("_Symbol does not support type {} as divisor"
.format(str(type(other))))
@use_np_compat
def __rtruediv__(self, other):
"""x.__rtruediv__(y) <=> y / x"""
- if isinstance(other, Symbol):
- return _sym_internal._true_divide(other, self)
+ if isinstance(other, _Symbol):
+ return _npi.true_divide(other, self)
elif isinstance(other, numeric_types):
- return _sym_internal._rtrue_divide_scalar(self, float(other)).as_np_ndarray()
+ return _npi.rtrue_divide_scalar(self, float(other)).as_np_ndarray()
else:
- raise TypeError("_NumpySymbol does not support type {} as dividend"
+ raise TypeError("_Symbol does not support type {} as dividend"
.format(str(type(other))))
@use_np_compat
@@ -171,23 +167,23 @@ class _NumpySymbol(Symbol):
@use_np_compat
def __pow__(self, other):
"""x.__pow__(y) <=> x ** y"""
- if isinstance(other, Symbol):
- return _sym_internal._np_power(self, other)
+ if isinstance(other, _Symbol):
+ return _npi.power(self, other)
elif isinstance(other, numeric_types):
- return _sym_internal._np_power_scalar(self, float(other))
+ return _npi.power_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
def __rpow__(self, other):
"""x.__rpow__(y) <=> y ** x"""
- if isinstance(other, Symbol):
- return _sym_internal._np_power(other, self)
+ if isinstance(other, _Symbol):
+ return _npi.power(other, self)
elif isinstance(other, numeric_types):
- return _sym_internal._np_rpower_scalar(self, float(other))
+ return _npi.rpower_scalar(self, float(other))
else:
- raise TypeError("_NumpySymbol does not support type {} as operand"
+ raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
@use_np_compat
@@ -197,7 +193,7 @@ class _NumpySymbol(Symbol):
@use_np_compat
def __deepcopy__(self, _):
- return super(_NumpySymbol, self).as_np_ndarray()
+ return super(_Symbol, self).as_np_ndarray()
@use_np_compat
def __eq__(self, other):
@@ -233,7 +229,7 @@ class _NumpySymbol(Symbol):
raise NotImplementedError
def as_classic_ndarray(self):
- """Convert _NumpySymbol to mxnet.symbol.Symbol to use its convenience fluent methods."""
+ """Convert _Symbol to mxnet.symbol.Symbol to use its convenience fluent methods."""
hdl = SymbolHandle()
check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
return Symbol(handle=hdl)
@@ -258,7 +254,7 @@ class _NumpySymbol(Symbol):
if order != 'C':
raise NotImplementedError('ndarray.copy only supports order=\'C\', while '
'received {}'.format(str(order)))
- return _mx_np_op.reshape(self, shape=shape, order=order)
+ return _mx_np_op.reshape(self, newshape=shape, order=order)
def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
@@ -266,7 +262,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`reshape_like`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute reshape_like')
+ raise AttributeError('_Symbol object has no attribute reshape_like')
def zeros_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`zeros_like`.
@@ -274,7 +270,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`zeros_like`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute zeros_like')
+ raise AttributeError('_Symbol object has no attribute zeros_like')
def ones_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`ones_like`.
@@ -282,7 +278,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`ones_like`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute ones_like')
+ raise AttributeError('_Symbol object has no attribute ones_like')
def broadcast_axes(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`broadcast_axes`.
@@ -290,7 +286,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`broadcast_axes`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute broadcast_like')
+ raise AttributeError('_Symbol object has no attribute broadcast_like')
@use_np_compat
def repeat(self, *args, **kwargs):
@@ -307,7 +303,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`pad`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute pad')
+ raise AttributeError('_Symbol object has no attribute pad')
@use_np_compat
def swapaxes(self, *args, **kwargs):
@@ -324,7 +320,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`split`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute split')
+ raise AttributeError('_Symbol object has no attribute split')
def split_v2(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`split_v2`.
@@ -332,7 +328,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`split_v2`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute split_v2')
+ raise AttributeError('_Symbol object has no attribute split_v2')
def slice(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`slice`.
@@ -340,7 +336,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`slice`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute slice')
+ raise AttributeError('_Symbol object has no attribute slice')
def slice_axis(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`slice_axis`.
@@ -348,7 +344,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`slice_axis`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute slice_axis')
+ raise AttributeError('_Symbol object has no attribute slice_axis')
def slice_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`slice_like`.
@@ -356,7 +352,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`slice_like`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute slice_like')
+ raise AttributeError('_Symbol object has no attribute slice_like')
@use_np_compat
def take(self, *args, **kwargs):
@@ -373,7 +369,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`one_hot`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute one_hot')
+ raise AttributeError('_Symbol object has no attribute one_hot')
def pick(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`pick`.
@@ -381,7 +377,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`pick`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute pick')
+ raise AttributeError('_Symbol object has no attribute pick')
@use_np_compat
def sort(self, *args, **kwargs):
@@ -398,7 +394,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`topk`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute topk')
+ raise AttributeError('_Symbol object has no attribute topk')
@use_np_compat
def argsort(self, *args, **kwargs):
@@ -424,7 +420,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`argmax_channel`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute argmax_channel')
+ raise AttributeError('_Symbol object has no attribute argmax_channel')
@use_np_compat
def argmin(self, *args, **kwargs):
@@ -450,7 +446,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`abs`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute abs')
+ raise AttributeError('_Symbol object has no attribute abs')
def sign(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sign`.
@@ -458,7 +454,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`sign`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute abs')
+ raise AttributeError('_Symbol object has no attribute abs')
@use_np_compat
def flatten(self, *args, **kwargs):
@@ -475,7 +471,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`shape_array`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute shape_array')
+ raise AttributeError('_Symbol object has no attribute shape_array')
def size_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`size_array`.
@@ -483,7 +479,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`size_array`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute size_array')
+ raise AttributeError('_Symbol object has no attribute size_array')
def expand_dims(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`expand_dims`.
@@ -491,7 +487,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`expand_dims`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute expand_dims')
+ raise AttributeError('_Symbol object has no attribute expand_dims')
def tile(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tile`.
@@ -499,7 +495,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`tile`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute tile')
+ raise AttributeError('_Symbol object has no attribute tile')
@use_np_compat
def transpose(self, *axes): # pylint: disable=arguments-differ
@@ -516,7 +512,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`flip`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute flip')
+ raise AttributeError('_Symbol object has no attribute flip')
def depth_to_space(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`depth_to_space`.
@@ -524,7 +520,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`depth_to_space`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute depth_to_space')
+ raise AttributeError('_Symbol object has no attribute depth_to_space')
def space_to_depth(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`space_to_depth`.
@@ -532,7 +528,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`space_to_depth`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute space_to_depth')
+ raise AttributeError('_Symbol object has no attribute space_to_depth')
def diag(self, k=0, **kwargs):
"""Convenience fluent method for :py:func:`diag`.
@@ -540,7 +536,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`diag`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute diag')
+ raise AttributeError('_Symbol object has no attribute diag')
@use_np_compat
def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ
@@ -557,7 +553,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`nansum`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute nansum')
+ raise AttributeError('_Symbol object has no attribute nansum')
@use_np_compat
def prod(self, *args, **kwargs):
@@ -574,7 +570,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`nanprod`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute nanprod')
+ raise AttributeError('_Symbol object has no attribute nanprod')
@use_np_compat
def mean(self, *args, **kwargs):
@@ -609,7 +605,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`norm`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute norm')
+ raise AttributeError('_Symbol object has no attribute norm')
@use_np_compat
def round(self, *args, **kwargs):
@@ -626,7 +622,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`rint`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute rint')
+ raise AttributeError('_Symbol object has no attribute rint')
def fix(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`fix`.
@@ -634,7 +630,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`fix`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute fix')
+ raise AttributeError('_Symbol object has no attribute fix')
def floor(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`floor`.
@@ -642,7 +638,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`floor`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute floor')
+ raise AttributeError('_Symbol object has no attribute floor')
def ceil(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`ceil`.
@@ -650,7 +646,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`ceil`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute ceil')
+ raise AttributeError('_Symbol object has no attribute ceil')
def trunc(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`trunc`.
@@ -658,7 +654,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`trunc`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute trunc')
+ raise AttributeError('_Symbol object has no attribute trunc')
def sin(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sin`.
@@ -666,7 +662,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`sin`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute sin')
+ raise AttributeError('_Symbol object has no attribute sin')
def cos(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`cos`.
@@ -674,7 +670,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`cos`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute cos')
+ raise AttributeError('_Symbol object has no attribute cos')
def tan(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tan`.
@@ -682,7 +678,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`tan`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute tan')
+ raise AttributeError('_Symbol object has no attribute tan')
def arcsin(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`arcsin`.
@@ -690,7 +686,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`arcsin`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute arcsin')
+ raise AttributeError('_Symbol object has no attribute arcsin')
def arccos(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`arccos`.
@@ -698,7 +694,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`arccos`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute arccos')
+ raise AttributeError('_Symbol object has no attribute arccos')
def arctan(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`arctan`.
@@ -706,7 +702,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`arctan`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute arctan')
+ raise AttributeError('_Symbol object has no attribute arctan')
def degrees(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`degrees`.
@@ -714,7 +710,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`degrees`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute degrees')
+ raise AttributeError('_Symbol object has no attribute degrees')
def radians(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`radians`.
@@ -722,7 +718,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`radians`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute radians')
+ raise AttributeError('_Symbol object has no attribute radians')
def sinh(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sinh`.
@@ -730,7 +726,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`sinh`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute sinh')
+ raise AttributeError('_Symbol object has no attribute sinh')
def cosh(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`cosh`.
@@ -738,7 +734,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`cosh`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute cosh')
+ raise AttributeError('_Symbol object has no attribute cosh')
def tanh(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tanh`.
@@ -746,7 +742,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`tanh`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute tanh')
+ raise AttributeError('_Symbol object has no attribute tanh')
def arcsinh(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`arcsinh`.
@@ -754,7 +750,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`arcsinh`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute arcsinh')
+ raise AttributeError('_Symbol object has no attribute arcsinh')
def arccosh(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`arccosh`.
@@ -762,7 +758,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`arccosh`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute arccosh')
+ raise AttributeError('_Symbol object has no attribute arccosh')
def arctanh(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`arctanh`.
@@ -770,7 +766,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`arctanh`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute arctanh')
+ raise AttributeError('_Symbol object has no attribute arctanh')
def exp(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`exp`.
@@ -778,7 +774,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`exp`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute exp')
+ raise AttributeError('_Symbol object has no attribute exp')
def expm1(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`expm1`.
@@ -786,7 +782,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`expm1`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute expm1')
+ raise AttributeError('_Symbol object has no attribute expm1')
def log(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log`.
@@ -794,7 +790,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`log`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute log')
+ raise AttributeError('_Symbol object has no attribute log')
def log10(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log10`.
@@ -802,7 +798,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`log10`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute log10')
+ raise AttributeError('_Symbol object has no attribute log10')
def log2(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log2`.
@@ -810,7 +806,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`log2`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute log2')
+ raise AttributeError('_Symbol object has no attribute log2')
def log1p(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log1p`.
@@ -818,7 +814,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`log1p`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute log1p')
+ raise AttributeError('_Symbol object has no attribute log1p')
def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.
@@ -826,7 +822,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`sqrt`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute sqrt')
+ raise AttributeError('_Symbol object has no attribute sqrt')
def rsqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rsqrt`.
@@ -834,7 +830,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`rsqrt`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute rsqrt')
+ raise AttributeError('_Symbol object has no attribute rsqrt')
def cbrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`cbrt`.
@@ -842,7 +838,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`cbrt`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute cqrt')
+ raise AttributeError('_Symbol object has no attribute cqrt')
def rcbrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rcbrt`.
@@ -850,7 +846,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`rcbrt`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute rcqrt')
+ raise AttributeError('_Symbol object has no attribute rcqrt')
def square(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`square`.
@@ -858,7 +854,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`square`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute square')
+ raise AttributeError('_Symbol object has no attribute square')
def reciprocal(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reciprocal`.
@@ -866,7 +862,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`reciprocal`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute reciprocal')
+ raise AttributeError('_Symbol object has no attribute reciprocal')
def relu(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`relu`.
@@ -874,7 +870,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`relu`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute relu')
+ raise AttributeError('_Symbol object has no attribute relu')
def sigmoid(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sigmoid`.
@@ -882,7 +878,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`sigmoid`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute sigmoid')
+ raise AttributeError('_Symbol object has no attribute sigmoid')
def softmax(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`softmax`.
@@ -890,7 +886,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`softmax`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute softmax')
+ raise AttributeError('_Symbol object has no attribute softmax')
def log_softmax(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log_softmax`.
@@ -898,7 +894,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`log_softmax`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute log_softmax')
+ raise AttributeError('_Symbol object has no attribute log_softmax')
def softmin(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`softmin`.
@@ -906,7 +902,7 @@ class _NumpySymbol(Symbol):
The arguments are the same as for :py:func:`softmin`, with
this array as data.
"""
- raise AttributeError('_NumpySymbol object has no attribute softmin')
+ raise AttributeError('_Symbol object has no attribute softmin')
@use_np_compat
def squeeze(self, *args, **kwargs):
@@ -918,12 +914,13 @@ class _NumpySymbol(Symbol):
raise NotImplementedError
def broadcast_to(self, *args, **kwargs):
- raise AttributeError('_NumpySymbol object has no attribute broadcast_to')
+ raise AttributeError('_Symbol object has no attribute broadcast_to')
def broadcast_like(self, *args, **kwargs):
- raise AttributeError('_NumpySymbol object has no attribute broadcast_like')
+ raise AttributeError('_Symbol object has no attribute broadcast_like')
+@set_module('mxnet.symbol.numpy')
@use_np_compat
def zeros(shape, dtype=_np.float32, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
@@ -952,9 +949,10 @@ def zeros(shape, dtype=_np.float32, **kwargs):
if ctx is None:
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
- return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+ return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+@set_module('mxnet.symbol.numpy')
@use_np_compat
def ones(shape, dtype=None, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
@@ -983,7 +981,7 @@ def ones(shape, dtype=None, **kwargs):
if ctx is None:
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
- return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+ return _npi.ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
#pylint: disable= too-many-arguments, no-member, protected-access
@@ -1035,16 +1033,16 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
#pylint: enable= too-many-arguments, no-member, protected-access
+@set_module('mxnet.symbol.numpy')
@use_np_compat
def maximum(x1, x2, out=None):
- return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum,
- _internal._np_maximum_scalar, None, out)
+ return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
+@set_module('mxnet.symbol.numpy')
@use_np_compat
def minimum(x1, x2, out=None):
- return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum,
- _internal._np_minimum_scalar, None, out)
+ return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
-_set_np_symbol_class(_NumpySymbol)
+_set_np_symbol_class(_Symbol)
diff --git a/python/mxnet/symbol/numpy/ext.py b/python/mxnet/symbol/numpy/ext.py
deleted file mode 100644
index 12c5f15..0000000
--- a/python/mxnet/symbol/numpy/ext.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=symbol module."""
-
-__all__ = []
diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py
index b8f10b3..869fdeb 100644
--- a/python/mxnet/symbol/numpy/linalg.py
+++ b/python/mxnet/symbol/numpy/linalg.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=symbol."""
__all__ = []
diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py
index 79c73d8..869fdeb 100644
--- a/python/mxnet/symbol/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy.random namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=symbol."""
__all__ = []
diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/symbol/numpy_extension/__init__.py
similarity index 88%
copy from python/mxnet/ndarray/numpy/__init__.py
copy to python/mxnet/symbol/numpy_extension/__init__.py
index d97e808..a718274 100644
--- a/python/mxnet/ndarray/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy_extension/__init__.py
@@ -15,11 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy module for numpy ops under mxnet.ndarray."""
+"""Module for the ops not belonging to the official numpy package."""
-from . import ext
-from . import random
-from . import linalg
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
diff --git a/python/mxnet/symbol/numpy/_op.py b/python/mxnet/symbol/numpy_extension/_op.py
similarity index 86%
copy from python/mxnet/symbol/numpy/_op.py
copy to python/mxnet/symbol/numpy_extension/_op.py
index 96da828..82eaa8e 100644
--- a/python/mxnet/symbol/numpy/_op.py
+++ b/python/mxnet/symbol/numpy_extension/_op.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators not belonging to the official numpy package
+used in Gluon APIs dispatched by F=symbol module."""
__all__ = []
diff --git a/python/mxnet/symbol/numpy/_register.py b/python/mxnet/symbol/numpy_extension/_register.py
similarity index 81%
copy from python/mxnet/symbol/numpy/_register.py
copy to python/mxnet/symbol/numpy_extension/_register.py
index 36dfd78..b118987 100644
--- a/python/mxnet/symbol/numpy/_register.py
+++ b/python/mxnet/symbol/numpy_extension/_register.py
@@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
-"""module for registering numpy ops under mxnet.symbol.numpy."""
+"""Registering numpy_extension ops."""
from ...base import _init_np_op_module
from ..register import _make_symbol_function
-_init_np_op_module('mxnet', 'symbol', _make_symbol_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension',
+ mx_module_name='symbol', make_op_func=_make_symbol_function)
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index ac59f8b..a835e2e 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -27,12 +27,58 @@ from ._internal import SymbolBase, _symbol_creator
from ..attribute import AttrScope
from ..base import mx_uint, check_call, _LIB, py_str
from ..symbol_doc import _build_doc
-from ..base import _Null, _init_op_module
+from ..base import _Null, _init_op_module, _is_np_op
from ..name import NameManager
# pylint: enable=unused-import
-def _generate_symbol_function_code(handle, name, func_name, signature_only=False):
+def _verify_np_symbol(op_name, func_name, sym):
+ """Verify if the sym is a numpy symbol.
+
+ Parameters
+ ----------
+ op_name : str
+ Operator full name registered in backend.
+ func_name : str
+ Operator name exposed to users. This is usually the name obtained by stripping
+ off the prefix of the full operator name registered in the backend.
+ sym : symbol to be verified
+ """
+ from .numpy._symbol import _Symbol as np_symbol
+ if not isinstance(sym, np_symbol):
+ raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+ 'This is a numpy operator which can only accept '
+ 'MXNet numpy ndarrays, while received a classic ndarray. '
+ 'Please call `as_np_ndarray()` upon the classic ndarray to '
+ 'convert it to an MXNet numpy ndarray, and then feed the converted '
+ 'array to this operator.'
+ .format(op_name, func_name))
+
+
+def _verify_classic_symbol(op_name, func_name, sym):
+ """Verify if the sym is a classic symbol.
+
+ Parameters
+ ----------
+ op_name : str
+ Operator full name registered in backend.
+ func_name : str
+ Operator name exposed to users. This is usually the name obtained by stripping
+ off the prefix of the full operator name registered in the backend.
+ sym : symbol to be verified
+ """
+ from .numpy._symbol import _Symbol as np_symbol
+ if isinstance(sym, np_symbol):
+ raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+ 'This is a classic operator which can only accept '
+ 'classic ndarrays, while received an MXNet numpy ndarray. '
+ 'Please call `as_classic_ndarray()` upon the numpy ndarray to '
+ 'convert it to a classic ndarray, and then feed the converted '
+ 'array to this operator.'
+ .format(op_name, func_name))
+
+
+def _generate_symbol_function_code(handle, op_name, func_name, signature_only=False):
"""Generate function for symbol op by handle and function name."""
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
@@ -56,7 +102,7 @@ def _generate_symbol_function_code(handle, name, func_name, signature_only=False
arg_types = [py_str(arg_types[i]) for i in range(narg)]
key_var_num_args = py_str(key_var_num_args.value)
ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
- doc_str = _build_doc(name,
+ doc_str = _build_doc(op_name,
py_str(desc.value),
arg_names,
arg_types,
@@ -95,6 +141,8 @@ def _generate_symbol_function_code(handle, name, func_name, signature_only=False
signature.append('**kwargs')
signature = ndsignature + signature
+ is_np_op = _is_np_op(op_name)
+ verify_symbol_fn = _verify_np_symbol.__name__ if is_np_op else _verify_classic_symbol.__name__
code = []
if arr_name:
code.append("""
@@ -106,7 +154,8 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
assert isinstance(i, SymbolBase), \\
"Positional arguments must be Symbol instances, " \\
"but got %s"%str(i)
- sym_args.append(i)""".format(arr_name))
+ {}('{}', '{}', i)
+ sym_args.append(i)""".format(arr_name, verify_symbol_fn, op_name, func_name))
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
@@ -128,9 +177,10 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
for k, v in kwargs.items():
if isinstance(v, SymbolBase):
sym_kwargs[k] = v
+ %s('%s', '%s', v)
else:
keys.append(k)
- vals.append(v)"""%(func_name.lower()))
+ vals.append(v)"""%(func_name.lower(), verify_symbol_fn, op_name, func_name))
if key_var_num_args: # pylint: disable=using-constant-test
code.append("""
if '%s' not in kwargs:
@@ -139,8 +189,8 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
key_var_num_args, key_var_num_args))
code.append("""
- return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%(
- handle.value))
+ return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s)"""%(
+ handle.value, str(is_np_op)))
else:
code.append("""
def %s(%s):"""%(func_name, ', '.join(signature)))
@@ -155,9 +205,10 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
for _k, _v in kwargs.items():
if isinstance(_v, SymbolBase):
sym_kwargs[_k] = _v
+ {}('{}', '{}', _v)
else:
_keys.append(_k)
- _vals.append(_v)""")
+ _vals.append(_v)""".format(verify_symbol_fn, op_name, func_name))
# NDArray args
for name in ndarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
@@ -165,6 +216,9 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
assert isinstance({name}, SymbolBase), \\
"Argument {name} must be Symbol instances, but got %s"%str({name})
sym_kwargs['{name}'] = {name}""".format(name=name))
+ code.append("""
+ {}('{}', '{}', {name})
+ """.format(verify_symbol_fn, op_name, func_name, name=name))
# kwargs
for name in kwarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
@@ -182,8 +236,8 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
if not hasattr(NameManager._current, "value"):
NameManager._current.value = NameManager()
name = NameManager._current.value.get(name, '%s')
- return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name)"""%(
- func_name.lower(), handle.value))
+ return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s)"""%(
+ func_name.lower(), handle.value, str(is_np_op)))
if signature_only:
code.append("""
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 3de5913..d84a1cb 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -61,15 +61,11 @@ class Symbol(SymbolBase):
__array_priority__ = 1000.0
def as_np_ndarray(self):
- """Convert mxnet.symbol.Symbol to _NumpySymbol."""
- from .numpy import _NumpySymbol
+ """Convert mx.sym.Symbol to mx.sym.np._Symbol."""
+ from .numpy import _Symbol
hdl = SymbolHandle()
check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
- return _NumpySymbol(hdl)
-
- def _is_np_compat(self):
- """Always returns False except for mxnet.symbol.numpy._NumpySymbol."""
- return False
+ return _Symbol(hdl)
def __repr__(self):
"""Gets a string representation of the symbol."""
@@ -109,8 +105,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_add` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__add__(self)
return _internal._Plus(self, other)
if isinstance(other, Number):
return _internal._PlusScalar(self, scalar=other)
@@ -126,8 +120,6 @@ class Symbol(SymbolBase):
raise NotImplementedForSymbol(self.__iadd__, '+=', other, 1)
def __radd__(self, other):
- if isinstance(other, Symbol) and other._is_np_compat():
- return other.__add__(self)
return self.__add__(other)
def __sub__(self, other):
@@ -136,8 +128,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_sub` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__rsub__(self)
return _internal._Minus(self, other)
if isinstance(other, Number):
return _internal._MinusScalar(self, scalar=other)
@@ -160,7 +150,7 @@ class Symbol(SymbolBase):
array([[-2., -2., -2.],
[-2., -2., -2.]], dtype=float32)
"""
- if isinstance(other, Symbol) and other._is_np_compat():
+ if isinstance(other, Symbol):
return other.__sub__(self)
if isinstance(other, Number):
return _internal._RMinusScalar(self, scalar=other)
@@ -173,8 +163,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_mul` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__mul__(self)
return _internal._Mul(self, other)
if isinstance(other, Number):
return _internal._MulScalar(self, scalar=other)
@@ -185,8 +173,6 @@ class Symbol(SymbolBase):
raise NotImplementedForSymbol(self.__imul__, '*=', other)
def __rmul__(self, other):
- if isinstance(other, Symbol) and other._is_np_compat():
- return other.__mul__(self)
return self.__mul__(other)
def __div__(self, other):
@@ -195,8 +181,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_div` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__rtruediv__(self)
return _internal._Div(self, other)
if isinstance(other, Number):
return _internal._DivScalar(self, scalar=other)
@@ -216,7 +200,7 @@ class Symbol(SymbolBase):
array([[ 0.33333334, 0.33333334, 0.33333334],
[ 0.33333334, 0.33333334, 0.33333334]], dtype=float32)
"""
- if isinstance(other, Symbol) and other._is_np_compat():
+ if isinstance(other, Symbol):
return other.__truediv__(self)
if isinstance(other, Number):
return _internal._RDivScalar(self, scalar=other)
@@ -229,8 +213,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_mod` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__rmod__(self)
return _internal._Mod(self, other)
if isinstance(other, Number):
return _internal._ModScalar(self, scalar=other)
@@ -250,7 +232,7 @@ class Symbol(SymbolBase):
array([[ 1., 1., 1.,
[ 1., 1., 1., dtype=float32)
"""
- if isinstance(other, Symbol) and other._is_np_compat():
+ if isinstance(other, Symbol):
return other.__mod__(self)
if isinstance(other, Number):
return _internal._RModScalar(self, scalar=other)
@@ -275,8 +257,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_pow` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__rpow__(self)
return _internal._Power(self, other)
if isinstance(other, Number):
return _internal._PowerScalar(self, scalar=other)
@@ -286,8 +266,6 @@ class Symbol(SymbolBase):
def __rpow__(self, other):
"""x.__rpow__(y) <=> y ** x"""
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__pow__(self)
return other.__pow__(self)
elif isinstance(other, Number):
return _internal._rpower_scalar(self, scalar=other)
@@ -347,8 +325,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_equal` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__eq__(self)
return _internal._equal(self, other)
if isinstance(other, numeric_types):
return _internal._equal_scalar(self, scalar=other)
@@ -361,8 +337,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_not_equal` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__ne__(self)
return _internal._not_equal(self, other)
if isinstance(other, numeric_types):
return _internal._not_equal_scalar(self, scalar=other)
@@ -375,8 +349,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_greater` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__lt__(self)
return _internal._greater(self, other)
if isinstance(other, numeric_types):
return _internal._greater_scalar(self, scalar=other)
@@ -389,8 +361,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_greater_equal` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__le__(self)
return _internal._greater_equal(self, other)
if isinstance(other, numeric_types):
return _internal._greater_equal_scalar(self, scalar=other)
@@ -403,8 +373,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_lesser` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__gt__(self)
return _internal._lesser(self, other)
if isinstance(other, numeric_types):
return _internal._lesser_scalar(self, scalar=other)
@@ -417,8 +385,6 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_lesser_equal` instead. """
if isinstance(other, Symbol):
- if other._is_np_compat():
- return other.__ge__(self)
return _internal._lesser_equal(self, other)
if isinstance(other, numeric_types):
return _internal._lesser_equal_scalar(self, scalar=other)
@@ -2706,8 +2672,12 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
Variable = var
-def Group(symbols):
+def Group(symbols, create_fn=Symbol):
"""Creates a symbol that contains a collection of other symbols, grouped together.
+ The type of the returned symbol is determined by `create_fn`, not by the input symbols:
+ a classic symbol (`mx.sym.Symbol`) is returned by default, and a numpy symbol
+ (`mx.sym.np._Symbol`) is returned when `create_fn=mx.sym.np._Symbol` is passed.
+ Grouping a list that mixes classic and numpy symbols is not supported.
Example
-------
@@ -2721,6 +2691,9 @@ def Group(symbols):
symbols : list
List of symbols to be grouped.
+ create_fn : mx.sym.Symbol or mx.sym.np._Symbol
+ Symbol class for creating the grouped symbol.
+
Returns
-------
sym : Symbol
@@ -2732,7 +2705,7 @@ def Group(symbols):
check_call(_LIB.MXSymbolCreateGroup(
mx_uint(len(symbols)),
c_handle_array(symbols), ctypes.byref(handle)))
- return Symbol(handle)
+ return create_fn(handle)
def load(fname):
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 6617c65..d91e2d7 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -47,6 +47,7 @@ from .context import Context, current_context
from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from .ndarray import array
from .symbol import Symbol
+from .symbol.numpy import _Symbol as np_symbol
def default_context():
@@ -887,7 +888,12 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
input_shape = {k: v.shape for k, v in location.items()}
_, out_shape, _ = sym.infer_shape(**input_shape)
proj = mx.sym.Variable("__random_proj")
+ is_np_sym = True if isinstance(sym, np_symbol) else False
+ if is_np_sym: # convert to np symbol for using element-wise multiplication
+ proj = proj.as_np_ndarray()
out = sym * proj
+ if is_np_sym: # convert to classic symbol so that make_loss can be used
+ out = out.as_classic_ndarray()
out = mx.sym.make_loss(out)
location = dict(list(location.items()) +
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 82fe28b..233acc8 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -163,21 +163,4 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx,
extern const std::vector<std::string> kHiddenKeys;
} // namespace mxnet
-/*!
- * An operator is considered as numpy compatible if it satisfies either one
- * of the following conditions.
- * 1. The op has the attribute mxnet::TIsNumpyCompatible> registered as True.
- * 2. The op's name starts with the prefix _numpy_.
- * The first condition is usually for the ops registered as internal ops, such
- * as _np_add, _true_divide, etc. They are wrapped by some user-facing op
- * APIs in the Python end.
- * The second condition is for the ops registered in the backend while exposed
- * directly to users as is, such as _numpy_sum etc.
- */
-inline bool IsNumpyCompatOp(const nnvm::Op* op) {
- static const auto& is_np_compat =
- nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible");
- return is_np_compat.get(op, false);
-}
-
#endif // MXNET_C_API_C_API_COMMON_H_
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 4da573e..0e136b0 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -378,19 +378,3 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
*out = reinterpret_cast<SymbolHandle>(sym);
API_END();
}
-
-int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
- int output_idx,
- int* is_from_np_op) {
- API_BEGIN();
- CachedOpPtr op = *static_cast<CachedOpPtr*>(handle);
- const auto& output_entries = op->GetForwardSym().outputs;
- CHECK_LT(output_idx, static_cast<int>(output_entries.size()));
- const nnvm::NodePtr& node_ptr = output_entries[output_idx].node;
- if (node_ptr->is_variable()) {
- *is_from_np_op = 0;
- } else {
- *is_from_np_op = (IsNumpyCompatOp(node_ptr->op()) ? 1 : 0);
- }
- API_END();
-}
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 2a03f3e..af457a5 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -840,13 +840,6 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
API_END_HANDLE_ERROR(delete s);
}
-int MXIsNumpyCompatOp(AtomicSymbolCreator creator, int* is_np_op) {
- API_BEGIN();
- const nnvm::Op* op = static_cast<Op*>(creator);
- *is_np_op = IsNumpyCompatOp(op) ? 1 : 0;
- API_END();
-}
-
int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) {
nnvm::Symbol* out_sym = new nnvm::Symbol;
API_BEGIN();
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 2c4d579..0f3d71d 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -169,6 +169,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
}
+ if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
if (param.axis.has_value() && param.axis.value().ndim() == 0) {
UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
}
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index c1c1132..a72efd9 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -47,7 +47,7 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}
-NNVM_REGISTER_OP(_numpy_sum)
+NNVM_REGISTER_OP(_np_sum)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
@@ -61,14 +61,13 @@ NNVM_REGISTER_OP(_numpy_sum)
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, mshadow_op::sum, true>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_sum"});
-NNVM_REGISTER_OP(_backward_numpy_sum)
+NNVM_REGISTER_OP(_backward_np_sum)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
@@ -102,7 +101,7 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}
-NNVM_REGISTER_OP(_numpy_mean)
+NNVM_REGISTER_OP(_np_mean)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
@@ -116,14 +115,13 @@ NNVM_REGISTER_OP(_numpy_mean)
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, mshadow_op::sum, true, true>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_mean"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_mean"});
-NNVM_REGISTER_OP(_backward_numpy_mean)
+NNVM_REGISTER_OP(_backward_np_mean)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index f16745d..2f50738 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -27,16 +27,16 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_numpy_sum)
+NNVM_REGISTER_OP(_np_sum)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true>);
-NNVM_REGISTER_OP(_backward_numpy_sum)
+NNVM_REGISTER_OP(_backward_np_sum)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);
-NNVM_REGISTER_OP(_numpy_mean)
+NNVM_REGISTER_OP(_np_mean)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true, true>);
-NNVM_REGISTER_OP(_backward_numpy_mean)
+NNVM_REGISTER_OP(_backward_np_mean)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);
diff --git a/src/operator/numpy/np_dot-inl.h b/src/operator/numpy/np_dot-inl.h
index 8fc7d5d..2f7c589 100644
--- a/src/operator/numpy/np_dot-inl.h
+++ b/src/operator/numpy/np_dot-inl.h
@@ -95,6 +95,7 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
const TBlob& a = inputs[0];
const TBlob& b = inputs[1];
const TBlob& out = outputs[0];
+ if (out.shape_.Size() == 0U) return; // zero-size tensor, no need to launch kernel
const mxnet::TShape a_shape = a.shape_;
const mxnet::TShape b_shape = b.shape_;
@@ -107,7 +108,13 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
(out.type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
<< "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
- if (a_shape.ndim() == 1 && b_shape.ndim() == 1) {
+ if (a_shape.Size() == 0U || b_shape.Size() == 0U) {
+ if (req[0] != kAddTo) {
+ Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(
+ Shape1(out.shape_.Size()), s);
+ out_data = static_cast<DType>(0);
+ }
+ } else if (a_shape.ndim() == 1 && b_shape.ndim() == 1) {
// Case 1: both 1-D arrays, inner product of vectors
if (out.type_flag_ == kFloat16) {
MMImpl<xpu>(ctx, a, b, out, req[0]);
@@ -158,12 +165,14 @@ inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs,
CHECK_EQ(outputs.size(), 2U);
const TBlob& ograd = inputs[0];
+ if (ograd.shape_.Size() == 0U) return;
const TBlob& a = inputs[1];
const TBlob& b = inputs[2];
const TBlob& grad_a = outputs[0];
const TBlob& grad_b = outputs[1];
const mxnet::TShape a_shape = a.shape_;
const mxnet::TShape b_shape = b.shape_;
+ if (a_shape.Size() == 0U || b_shape.Size() == 0U) return;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, {
diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc
index c25953f..bcb310f 100644
--- a/src/operator/numpy/np_dot.cc
+++ b/src/operator/numpy/np_dot.cc
@@ -71,7 +71,7 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs,
return true;
}
-NNVM_REGISTER_OP(_numpy_dot)
+NNVM_REGISTER_OP(_np_dot)
.describe(R"doc(Dot product of two arrays. Specifically,
- If both a and b are 1-D arrays, it is inner product of vectors.
diff --git a/src/operator/numpy/np_dot.cu b/src/operator/numpy/np_dot.cu
index 2accd9d..9a9c69a 100644
--- a/src/operator/numpy/np_dot.cu
+++ b/src/operator/numpy/np_dot.cu
@@ -27,7 +27,7 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_numpy_dot)
+NNVM_REGISTER_OP(_np_dot)
.set_attr<FCompute>("FCompute<gpu>", NumpyDotForward<gpu>);
NNVM_REGISTER_OP(_backward_np_dot)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 5d36c29..2ffa3b8 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -57,12 +57,11 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
- .set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_add)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add)
.describe(R"code(Add arguments element-wise with broadcasting if necessary.
Example::
@@ -78,10 +77,9 @@ Example::
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_subtract)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
.describe(R"code(Subtract arguments element-wise with broadcasting if necessary.
Example::
@@ -97,10 +95,9 @@ Example::
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_multiply)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
.describe(R"code(Multiply arguments with broadcasting if necessary.
Example::
@@ -116,10 +113,9 @@ Example::
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_mod)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
.describe(R"code(Return element-wise remainder of division.
It is equivalent to the Python modulus operator``x1 % x2`` and has the same sign as the divisor x2.
@@ -136,10 +132,9 @@ Example::
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"});
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_power)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
.describe(R"code(First array elements raised to powers from second array, element-wise.
Raise each base in x1 to the positionally-corresponding power in x2. x1 and x2 must be
@@ -158,56 +153,53 @@ Example::
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_maximum)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_maximum)
.describe(R"code()code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::maximum>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::maximum>);
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_minimum)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_minimum)
.describe(R"code()code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::minimum>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::minimum>);
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_add_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_subtract_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_subtract_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::minus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rsubtract_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rsubtract_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rminus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"negative"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_multiply_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_multiply_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::mul>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_mul_scalar"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_mod_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_mod_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::mod>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mod_scalar"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rmod_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rmod_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rmod>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rmod_scalar"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_power_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_power_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::power>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_power_scalar"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rpower_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_maximum_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_maximum_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::maximum>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_maximum_scalar"});
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_minimum_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_minimum_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::minimum>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_minimum_scalar"});
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index 26e2fce..c858b3a 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -27,55 +27,55 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_np_add)
+NNVM_REGISTER_OP(_npi_add)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);
-NNVM_REGISTER_OP(_np_subtract)
+NNVM_REGISTER_OP(_npi_subtract)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
-NNVM_REGISTER_OP(_np_multiply)
+NNVM_REGISTER_OP(_npi_multiply)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
-NNVM_REGISTER_OP(_np_mod)
+NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
-NNVM_REGISTER_OP(_np_power)
+NNVM_REGISTER_OP(_npi_power)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);
-NNVM_REGISTER_OP(_np_maximum)
+NNVM_REGISTER_OP(_npi_maximum)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::maximum>);
-NNVM_REGISTER_OP(_np_minimum)
+NNVM_REGISTER_OP(_npi_minimum)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::minimum>);
-NNVM_REGISTER_OP(_np_add_scalar)
+NNVM_REGISTER_OP(_npi_add_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
-NNVM_REGISTER_OP(_np_subtract_scalar)
+NNVM_REGISTER_OP(_npi_subtract_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::minus>);
-NNVM_REGISTER_OP(_np_rsubtract_scalar)
+NNVM_REGISTER_OP(_npi_rsubtract_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rminus>);
-NNVM_REGISTER_OP(_np_multiply_scalar)
+NNVM_REGISTER_OP(_npi_multiply_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::mul>);
-NNVM_REGISTER_OP(_np_mod_scalar)
+NNVM_REGISTER_OP(_npi_mod_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mod>);
-NNVM_REGISTER_OP(_np_rmod_scalar)
+NNVM_REGISTER_OP(_npi_rmod_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rmod>);
-NNVM_REGISTER_OP(_np_power_scalar)
+NNVM_REGISTER_OP(_npi_power_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::power>);
-NNVM_REGISTER_OP(_np_rpower_scalar)
+NNVM_REGISTER_OP(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
-NNVM_REGISTER_OP(_np_maximum_scalar)
+NNVM_REGISTER_OP(_npi_maximum_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::maximum>);
-NNVM_REGISTER_OP(_np_minimum_scalar)
+NNVM_REGISTER_OP(_npi_minimum_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::minimum>);
} // namespace op
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc
index f31ed5e..a64356e 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cc
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc
@@ -27,7 +27,7 @@
namespace mxnet {
namespace op {
-MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu)
+MXNET_OPERATOR_REGISTER_UNARY(_npe_relu)
.describe(R"code(Computes rectified linear activation.
.. math::
@@ -35,10 +35,9 @@ MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu)
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"});
-MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid)
+MXNET_OPERATOR_REGISTER_UNARY(_npe_sigmoid)
.describe(R"code(Computes sigmoid of x element-wise.
.. math::
@@ -46,18 +45,29 @@ MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid)
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});
-MXNET_OPERATOR_REGISTER_UNARY(_np_copy)
-.MXNET_DESCRIBE("Returns a copy of the input.")
+NNVM_REGISTER_OP(_np_copy)
+.describe(R"code(Return an array copy of the given object.)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::pair<int, int> >{{0, 0}};
+ })
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"a"};
+ })
+.add_argument("a", "NDArray-or-Symbol", "The input");
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu
index 9f108f7..600f198 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cu
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu
@@ -26,10 +26,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_numpy__ext_relu)
+NNVM_REGISTER_OP(_npe_relu)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::relu>);
-NNVM_REGISTER_OP(_numpy__ext_sigmoid)
+NNVM_REGISTER_OP(_npe_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sigmoid>);
NNVM_REGISTER_OP(_np_copy)
diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc
index 0abd010..83a44c8 100644
--- a/src/operator/numpy/np_init_op.cc
+++ b/src/operator/numpy/np_init_op.cc
@@ -28,7 +28,7 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_np_zeros)
+NNVM_REGISTER_OP(_npi_zeros)
.describe("Return a new array of given shape, type, and context, filled with zeros.")
.set_num_inputs(0)
.set_num_outputs(1)
@@ -37,10 +37,9 @@ NNVM_REGISTER_OP(_np_zeros)
.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>)
.set_attr<FInferStorageType>("FInferStorageType", InitStorageType<InitOpParam, true, true>)
.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.add_arguments(InitOpParam::__FIELDS__());
-NNVM_REGISTER_OP(_np_ones)
+NNVM_REGISTER_OP(_npi_ones)
.describe("Return a new array of given shape, type, and context, filled with ones.")
.set_num_inputs(0)
.set_num_outputs(1)
@@ -48,8 +47,65 @@ NNVM_REGISTER_OP(_np_ones)
.set_attr<mxnet::FInferShape>("FInferShape", InitShape<InitOpParam>)
.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>)
.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 1>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.add_arguments(InitOpParam::__FIELDS__());
+NNVM_REGISTER_OP(_np_zeros_like)
+.describe(R"code(Return an array of zeros with the same shape and type as a given array.
+
+Examples::
+
+ x = [[ 1., 1., 1.],
+ [ 1., 1., 1.]]
+
+ zeros_like(x) = [[ 0., 0., 0.],
+ [ 0., 0., 0.]]
+
+)code")
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
+ [](const NodeAttrs& attrs) {
+ return std::vector<uint32_t>(1, 0);
+ })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"a"};
+ })
+.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("a", "NDArray-or-Symbol",
+ "The shape and data-type of a define these same attributes of the returned array.");
+
+NNVM_REGISTER_OP(_np_ones_like)
+.describe(R"code(Return an array of ones with the same shape and type as a given array.
+
+Examples::
+
+ x = [[ 0., 0., 0.],
+ [ 0., 0., 0.]]
+
+ ones_like(x) = [[ 1., 1., 1.],
+ [ 1., 1., 1.]]
+
+)code")
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
+ [](const NodeAttrs& attrs) {
+ return std::vector<uint32_t>(1, 0);
+ })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"a"};
+ })
+.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 1>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("a", "NDArray-or-Symbol",
+ "The shape and data-type of a define these same attributes of the returned array.");
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu
index 4e6f81d..2eb8ed6 100644
--- a/src/operator/numpy/np_init_op.cu
+++ b/src/operator/numpy/np_init_op.cu
@@ -28,10 +28,16 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_np_zeros)
+NNVM_REGISTER_OP(_npi_zeros)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);
-NNVM_REGISTER_OP(_np_ones)
+NNVM_REGISTER_OP(_npi_ones)
+.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);
+
+NNVM_REGISTER_OP(_np_zeros_like)
+.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);
+
+NNVM_REGISTER_OP(_np_ones_like)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);
} // namespace op
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 215b1c5..6e93442 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -54,7 +54,7 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
return shape_is_known(ret);
}
-NNVM_REGISTER_OP(_numpy_transpose)
+NNVM_REGISTER_OP(_np_transpose)
.describe(R"code(Permute the dimensions of an array.
Examples::
@@ -105,7 +105,6 @@ Examples::
}
})
.set_attr<FCompute>("FCompute<cpu>", NumpyTranspose<cpu>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
@@ -189,7 +188,7 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,
return success;
}
-NNVM_REGISTER_OP(_numpy_reshape)
+NNVM_REGISTER_OP(_np_reshape)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
@@ -210,7 +209,6 @@ NNVM_REGISTER_OP(_numpy_reshape)
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
.add_arguments(NumpyReshapeParam::__FIELDS__());
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 9753566..5bf36e5 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -27,10 +27,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_numpy_transpose)
+NNVM_REGISTER_OP(_np_transpose)
.set_attr<FCompute>("FCompute<gpu>", NumpyTranspose<gpu>);
-NNVM_REGISTER_OP(_numpy_reshape)
+NNVM_REGISTER_OP(_np_reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
} // namespace op
diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc
index 3bafa26..4297627 100644
--- a/src/operator/numpy/np_true_divide.cc
+++ b/src/operator/numpy/np_true_divide.cc
@@ -54,7 +54,7 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs,
return true;
}
-NNVM_REGISTER_OP(_true_divide)
+NNVM_REGISTER_OP(_npi_true_divide)
.describe(R"code(
Returns a true division of the inputs, element-wise.
@@ -86,11 +86,10 @@ Example::
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
.add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
-NNVM_REGISTER_OP(_true_divide_scalar)
+NNVM_REGISTER_OP(_npi_true_divide_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
@@ -104,11 +103,10 @@ NNVM_REGISTER_OP(_true_divide_scalar)
})
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::div>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "float", "scalar input");
-NNVM_REGISTER_OP(_rtrue_divide_scalar)
+NNVM_REGISTER_OP(_npi_rtrue_divide_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
@@ -122,7 +120,6 @@ NNVM_REGISTER_OP(_rtrue_divide_scalar)
})
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rdiv>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_rdiv_scalar"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "float", "scalar input");
diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu
index cbc7cf9..be10c44 100644
--- a/src/operator/numpy/np_true_divide.cu
+++ b/src/operator/numpy/np_true_divide.cu
@@ -28,13 +28,13 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_true_divide)
+NNVM_REGISTER_OP(_npi_true_divide)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::div>);
-NNVM_REGISTER_OP(_true_divide_scalar)
+NNVM_REGISTER_OP(_npi_true_divide_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::div>);
-NNVM_REGISTER_OP(_rtrue_divide_scalar)
+NNVM_REGISTER_OP(_npi_rtrue_divide_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rdiv>);
} // namespace op
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 141d153..eb45234 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -20,7 +20,7 @@ from __future__ import absolute_import
from __future__ import division
import numpy as _np
import mxnet as mx
-from mxnet import numpy as np
+from mxnet import np
from mxnet.gluon import HybridBlock
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
from common import with_seed
@@ -37,15 +37,15 @@ def test_array_creation():
mx_arr = np.array(src, dtype=dtype)
assert mx_arr.context == mx.current_context()
if isinstance(src, mx.nd.NDArray):
- np_arr = _np.array(src.asnumpy(), dtype=dtype)
+ np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32)
else:
- np_arr = _np.array(src, dtype=dtype)
- assert same(mx_arr.asnumpy(), np_arr)
+ np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32)
assert mx_arr.dtype == np_arr.dtype
+ assert same(mx_arr.asnumpy(), np_arr)
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_zeros():
# test np.zeros in Gluon
class TestZeros(HybridBlock):
@@ -76,7 +76,7 @@ def test_zeros():
for shape in shapes:
for dtype in dtypes:
check_zero_array_creation(shape, dtype)
- x = mx.nd.array(_np.random.uniform(size=shape), dtype=dtype)
+ x = np.array(_np.random.uniform(size=shape), dtype=dtype)
if dtype is None:
x = x.astype('float32')
for hybridize in [True, False]:
@@ -93,7 +93,7 @@ def test_zeros():
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_ones():
# test np.ones in Gluon
class TestOnes(HybridBlock):
@@ -141,7 +141,7 @@ def test_ones():
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_ndarray_binary_element_wise_ops():
# Cannot test operators like >, because boolean arrays are not supported yet.
np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/': _np.divide,
@@ -241,23 +241,22 @@ def test_ndarray_binary_element_wise_ops():
np_out = get_np_ret(np_input1, np_input2, op)
for hybridize in [True, False]:
if scalar is None:
- get_mx_ret = TestBinaryElementWiseOp(op)
+ get_mx_ret_np = TestBinaryElementWiseOp(op)
+ get_mx_ret_classic = TestBinaryElementWiseOp(op)
if hybridize:
- get_mx_ret.hybridize()
- mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray())
+ get_mx_ret_np.hybridize()
+ get_mx_ret_classic.hybridize()
+ mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray())
assert type(mx_out) == np.ndarray
assert np_out.shape == mx_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
- mx_out = get_mx_ret(mx_input1, mx_input2.as_np_ndarray())
- assert type(mx_out) == np.ndarray
- assert np_out.shape == mx_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
-
- mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2)
- assert type(mx_out) == np.ndarray
- assert np_out.shape == mx_out.shape
- assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+ if mx_input1.shape == mx_input2.shape:
+ # classic symbol does not support element-wise binary broadcast.
+ mx_out = get_mx_ret_classic(mx_input1, mx_input2)
+ assert type(mx_out) == mx.nd.NDArray
+ assert np_out.shape == mx_out.shape
+ assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
else:
get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse)
if hybridize:
@@ -291,29 +290,42 @@ def test_ndarray_binary_element_wise_ops():
@with_seed()
-def test_np_op_output_type():
- # test imperative invoke
- data = np.array([1., 3.], dtype='float32')
- ret = np.sum(data)
- assert type(ret) == np.ndarray
- ret = mx.nd.sin(data)
- assert type(ret) == mx.nd.NDArray
-
- # test cached op
- class TestCachedOpOutputType(HybridBlock):
- @mx.use_np_compat
+def test_hybrid_block_multiple_outputs():
+ class TestAllNumpyOutputs(HybridBlock):
+ @np.use_np_compat
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ return F.npe.relu(x), F.np.sum(x)
+
+ class TestAllClassicOutputs(HybridBlock):
+ @np.use_np_compat
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ return F.relu(x.as_classic_ndarray()), F.sum(x.as_classic_ndarray())
+
+ class TestMixedTypeOutputsSuccess(HybridBlock):
+ @np.use_np_compat
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ return F.relu(x.as_classic_ndarray()).as_np_ndarray(), F.np.sum(x)
+
+ data_np = np.ones((2, 3))
+ for block, expected_out_type in [(TestAllClassicOutputs, mx.nd.NDArray),
+ (TestAllNumpyOutputs, np.ndarray),
+ (TestMixedTypeOutputsSuccess, np.ndarray)]:
+ net = block()
+ for hybridize in [True, False]:
+ if hybridize:
+ net.hybridize()
+ out1, out2 = net(data_np)
+ assert type(out1) is expected_out_type
+ assert type(out2) is expected_out_type
+
+ class TestMixedTypeOutputsFailure(HybridBlock):
+ @np.use_np_compat
def hybrid_forward(self, F, x, *args, **kwargs):
- ret1 = F.sin(x)
- ret2 = F.np.sum(x)
- return ret1, ret2
+ return F.relu(x.as_classic_ndarray()), F.np.sum(x)
- net = TestCachedOpOutputType()
- for hybridize in [True, False]:
- if hybridize:
- net.hybridize()
- ret1, ret2 = net(data)
- assert type(ret1) == mx.nd.NDArray
- assert type(ret2) == np.ndarray
+ net = TestMixedTypeOutputsFailure()
+ net.hybridize()
+ assert_exception(net, TypeError, data_np)
@with_seed()
@@ -331,6 +343,7 @@ def test_np_ndarray_astype():
def check_astype_equal(dtype, copy, expect_zero_copy=False):
mx_ret = mx_data.astype(dtype=dtype, copy=copy)
+ assert type(mx_ret) is np.ndarray
np_ret = np_data.astype(dtype=dtype, copy=copy)
assert mx_ret.dtype == np_ret.dtype
assert same(mx_ret.asnumpy(), np_ret)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 8c13227..34b2cbe 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -19,7 +19,7 @@
from __future__ import absolute_import
import numpy as _np
import mxnet as mx
-from mxnet import numpy as np
+from mxnet import np, npe
from mxnet.gluon import HybridBlock
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
from mxnet.test_utils import check_numeric_gradient
@@ -27,7 +27,7 @@ from common import with_seed
import random
-@mx.use_np_compat
+@np.use_np_compat
@with_seed()
def test_np_sum():
class TestSum(HybridBlock):
@@ -38,7 +38,7 @@ def test_np_sum():
self._keepdims = keepdims
def hybrid_forward(self, F, a, *args, **kwargs):
- return F.numpy.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
+ return F.np.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
def is_int(dtype):
return 'int' in dtype
@@ -63,6 +63,7 @@ def test_np_sum():
x = mx.nd.array(x)
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+ x = x.as_np_ndarray()
x.attach_grad()
expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
expected_ret = expected_ret.astype(dtype)
@@ -77,8 +78,8 @@ def test_np_sum():
# test numeric
if itype == 'float32' and dtype == 'float32':
- x_sym = mx.sym.Variable("x")
- mx_sym = mx.sym.numpy.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims)
+ x_sym = mx.sym.Variable("x").as_np_ndarray()
+ mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
# test imperative
@@ -87,10 +88,11 @@ def test_np_sum():
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
-@mx.use_np_compat
+@np.use_np_compat
@with_seed()
def test_np_dot():
shapes = [
+ ((3, 0), (0, 4)),
((3,), (3,)), # Case 1
((3, 4), (4, 5)), # Case 2
((), ()), # Case 3
@@ -102,7 +104,6 @@ def test_np_dot():
eps = 1e-3
for shape_a, shape_b in shapes:
- print(shape_a, shape_b)
np_a = _np.random.uniform(-1.0, 1.0, shape_a)
np_a[abs(np_a) < eps] = 2 * eps;
np_b = _np.random.uniform(-1.0, 1.0, shape_b)
@@ -110,12 +111,12 @@ def test_np_dot():
a = mx.nd.array(np_a)
b = mx.nd.array(np_b)
np_res = _np.dot(np_a, np_b)
- mx_res = np.dot(a, b)
+ mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray())
assert mx_res.shape == np_res.shape
assert_almost_equal(np_res, mx_res.asnumpy(), rtol=1e-5, atol=1e-5)
mx_a = mx.sym.Variable("a")
mx_b = mx.sym.Variable("b")
- mx_sym = mx.sym.numpy.dot(mx_a, mx_b)
+ mx_sym = mx.sym.np.dot(mx_a.as_np_ndarray(), mx_b.as_np_ndarray()).as_classic_ndarray()
check_numeric_gradient(mx_sym, {"a": a, "b": b}, numeric_eps=eps, rtol=1e-2, atol=1e-3)
bad_shapes = [((4, 5), (2, 3)), ((3, 4, 5), (6, ))]
@@ -124,13 +125,13 @@ def test_np_dot():
a = mx.nd.array(random.random()) if len(shape_a) == 0 else rand_ndarray(shape_a)
b = mx.nd.array(random.random()) if len(shape_b) == 0 else rand_ndarray(shape_b)
try:
- mx_res = np.dot(a, b)
+ mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray())
except mx.base.MXNetError:
continue
assert False
-@mx.use_np_compat
+@np.use_np_compat
@with_seed()
def test_np_mean():
class TestMean(HybridBlock):
@@ -141,7 +142,7 @@ def test_np_mean():
self._keepdims = keepdims
def hybrid_forward(self, F, a, *args, **kwargs):
- return F.numpy.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
+ return F.np.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
def is_int(dtype):
return 'int' in dtype
@@ -167,6 +168,7 @@ def test_np_mean():
x = mx.nd.array(x, dtype=itype)
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+ x = x.as_np_ndarray()
x.attach_grad()
expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
expected_ret = expected_ret.astype(dtype)
@@ -182,8 +184,8 @@ def test_np_mean():
# test numeric
if itype == 'float32' and dtype == 'float32':
- x_sym = mx.sym.Variable("x")
- mx_sym = mx.sym.numpy.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims)
+ x_sym = mx.sym.Variable("x").as_np_ndarray()
+ mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
# test imperative
@@ -193,12 +195,12 @@ def test_np_mean():
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_np_transpose():
# TODO(junwu): Add more test cases
- data = mx.sym.var('a')
- ret = mx.sym.np.transpose(data)
- assert type(ret) == mx.sym.np._NumpySymbol
+ data = mx.sym.var('a').as_np_ndarray()
+ ret = data.transpose()
+ assert type(ret) == mx.sym.np._Symbol
dtypes = ['float32', 'int32']
for dtype in dtypes:
@@ -223,44 +225,44 @@ def test_np_transpose():
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_relu():
# TODO(junwu): Add more test cases
- data = mx.sym.var('data')
- ret = mx.sym.np.ext.relu(data)
- assert type(ret) == mx.sym.np._NumpySymbol
+ data = mx.sym.var('data').as_np_ndarray()
+ ret = mx.sym.npe.relu(data)
+ assert type(ret) == mx.sym.np._Symbol
shapes = [(), (0, 2, 0)]
shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
for shape in shapes:
data = np.array(_np.random.uniform(size=shape).astype('float32'))
- ret = np.ext.relu(data)
+ ret = npe.relu(data)
assert type(ret) == np.ndarray
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_sigmoid():
# TODO(junwu): Add more test cases
- data = mx.sym.var('data')
- ret = mx.sym.np.ext.sigmoid(data)
- assert type(ret) == mx.sym.np._NumpySymbol
+ data = mx.sym.var('data').as_np_ndarray()
+ ret = mx.sym.npe.sigmoid(data)
+ assert type(ret) == mx.sym.np._Symbol
shapes = [(), (0, 2, 0)]
shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
for shape in shapes:
data = np.array(_np.random.uniform(size=shape).astype('float32'))
- ret = np.ext.sigmoid(data)
+ ret = npe.sigmoid(data)
assert type(ret) == np.ndarray
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_np_reshape():
# TODO(junwu): Add more test cases
- data = mx.sym.var('a')
- ret = mx.sym.np.reshape(data, newshape=())
- assert type(ret) == mx.sym.np._NumpySymbol
+ data = mx.sym.var('a').as_np_ndarray()
+ ret = data.reshape(shape=())
+ assert type(ret) == mx.sym.np._Symbol
data = np.ones((1, 1, 1))
ret = np.reshape(data, ())
@@ -271,12 +273,12 @@ def test_np_reshape():
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_np_maximum():
# TODO(junwu): Add more test cases
- x1, x2 = mx.sym.var('x1'), mx.sym.var('x2')
+ x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
ret = mx.sym.np.maximum(x1, x2)
- assert type(ret) == mx.sym.np._NumpySymbol
+ assert type(ret) == mx.sym.np._Symbol
def check_maximum(x1, x2):
mx_out = np.maximum(x1, x2)
@@ -292,12 +294,12 @@ def test_np_maximum():
@with_seed()
-@mx.use_np_compat
+@np.use_np_compat
def test_np_minimum():
# TODO(junwu): Add more test cases
- x1, x2 = mx.sym.var('x1'), mx.sym.var('x2')
+ x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
ret = mx.sym.np.minimum(x1, x2)
- assert type(ret) == mx.sym.np._NumpySymbol
+ assert type(ret) == mx.sym.np._Symbol
def check_minimum(x1, x2):
mx_out = np.minimum(x1, x2)