You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:38 UTC
[incubator-mxnet] 02/42: [numpy] Infra for supporting numpy ops in
imperative mode and Gluon APIs (#14758)
This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 94f6a2aa14eaf0324cb325c2bf085bd459b736d5
Author: reminisce <wu...@gmail.com>
AuthorDate: Fri May 3 16:09:44 2019 -0700
[numpy] Infra for supporting numpy ops in imperative mode and Gluon APIs (#14758)
* Infra of new ndarray and symbol types for numpy operators
* Rename
* Fix import problem
* Refactor
* Remove redundant code
* Add docstring
* More on numpy ndarray and symbol
* Override unimplemented methods for ndarray and _NumpySymbol
* Fix built-in methods of ndarray and _NumpySymbol
* Fix test and sanity check
* Fix pylint
* Address cr comments
* Add unit tests for ndarray and _NumpySymbol
* Add _true_divide
* Fix gpu build
* Add future import division
* More correct way of checking if an output is from a np compat op
* Fix gpu build
* Fix output ndarray/symbol types with at least one new ndarray/symbol
* Modify true_divide doc
* Fix flaky copying of zero-size arrays via GPUs
* Fix zero size in gluon hybridize and zeros/ones symbol not creating new symbol type
* Fix doc
---
include/mxnet/c_api.h | 29 +
include/mxnet/op_attr_types.h | 9 +
python/mxnet/__init__.py | 2 +-
python/mxnet/_ctypes/ndarray.py | 38 +-
python/mxnet/_ctypes/symbol.py | 14 +-
python/mxnet/base.py | 101 +-
python/mxnet/gluon/block.py | 9 +-
python/mxnet/ndarray/__init__.py | 1 +
python/mxnet/ndarray/_internal.py | 11 +-
python/mxnet/ndarray/ndarray.py | 54 +
python/mxnet/{ => ndarray}/numpy/__init__.py | 10 +-
python/mxnet/ndarray/numpy/_op.py | 88 ++
.../__init__.py => ndarray/numpy/_register.py} | 10 +-
python/mxnet/numpy/__init__.py | 10 +
python/mxnet/numpy/{__init__.py => _op.py} | 4 +-
python/mxnet/numpy/{__init__.py => _register.py} | 12 +-
python/mxnet/{ndarray/numpy.py => numpy/linalg.py} | 2 +
python/mxnet/numpy/multiarray.py | 1200 ++++++++++++++++++++
python/mxnet/{symbol/numpy.py => numpy/random.py} | 2 +
python/mxnet/symbol/__init__.py | 1 +
python/mxnet/symbol/_internal.py | 10 +-
python/mxnet/{ => symbol}/numpy/__init__.py | 12 +-
.../{numpy/__init__.py => symbol/numpy/_op.py} | 4 +-
.../__init__.py => symbol/numpy/_register.py} | 9 +-
python/mxnet/symbol/numpy/_symbol.py | 974 ++++++++++++++++
python/mxnet/symbol/symbol.py | 57 +-
python/mxnet/test_utils.py | 19 +-
src/c_api/c_api.cc | 9 +
src/c_api/c_api_common.h | 7 +
src/c_api/c_api_ndarray.cc | 16 +
src/c_api/c_api_symbolic.cc | 13 +-
src/imperative/imperative_utils.h | 1 -
src/ndarray/ndarray.cc | 13 +-
src/operator/numpy/np_broadcast_reduce_op_value.cc | 3 +-
src/operator/numpy/np_elemwise_broadcast_op.cc | 197 ++++
src/operator/numpy/np_elemwise_broadcast_op.cu | 71 ++
src/operator/numpy/np_init_op.cc | 55 +
src/operator/numpy/np_init_op.cu | 38 +
src/operator/numpy/np_true_divide.cc | 130 +++
src/operator/numpy/np_true_divide.cu | 41 +
tests/python/gpu/test_operator_gpu.py | 1 +
tests/python/unittest/test_numpy_ndarray.py | 358 ++++++
42 files changed, 3567 insertions(+), 78 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bd30e44..51d4b46 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2893,6 +2893,35 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
+/*!
+ * \brief Determines if an op is a Numpy op by its name prefix.
+ * Every Numpy op starts with a prefix string "_numpy_".
+ * \param creator Operator handle
+ * \param is_np_op Output: indicator of whether creator is a numpy op handle
+ */
+MXNET_DLL int MXIsNumpyCompatOp(AtomicSymbolCreator creator,
+ int* is_np_op);
+/*!
+ * \brief Create an NDArray from source sharing the same data chunk.
+ * \param src source NDArray
+ * \param out Output: new NDArray sharing the same data chunk with src
+ */
+MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out);
+/*!
+ * \brief Create a Symbol from source sharing the same graph structure.
+ * \param src source Symbol
+ * \param out Output: new Symbol sharing the same graph structure with src
+ */
+MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
+/*!
+ * \brief Checks if an output of CachedOp is from a numpy op.
+ * \param handle CachedOp shared ptr
+ * \param output_idx index of the output of the CachedOp
+ * \param is_from_np_op Output: indicator of whether the output is from a numpy op
+ */
+MXNET_DLL int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
+ int output_idx,
+ int* is_from_np_op);
/*!
* \brief Push an asynchronous operation to the engine.
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 889b502..0e4e322 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -319,6 +319,15 @@ using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
size_t index)>;
+/*!
+ * \brief Indicates whether this operator is NumPy compatible.
+ * It distinguishes NumPy-compatible operators from classic MXNet operators,
+ * which do not support zero-dim and zero-size tensors.
+ * In Python, it is used to determine whether to output numpy ndarrays
+ * or symbols that are NumPy compatible.
+ */
+using TIsNumpyCompatible = bool;
+
} // namespace mxnet
#endif // MXNET_OP_ATTR_TYPES_H_
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index a850b38..7c8150b 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -26,10 +26,10 @@ from . import engine
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from . import base
-from . import numpy
from . import contrib
from . import ndarray
from . import ndarray as nd
+from . import numpy
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index f324545..60ec248 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -26,7 +26,7 @@ import ctypes
from ..base import _LIB
from ..base import c_str_array, c_handle_array
from ..base import NDArrayHandle, CachedOpHandle
-from ..base import check_call
+from ..base import check_call, _is_np_compat_op
class NDArrayBase(object):
@@ -55,6 +55,8 @@ class NDArrayBase(object):
_ndarray_cls = None
+_np_ndarray_cls = None
+
def _set_ndarray_class(cls):
"""Set the symbolic class to be cls"""
@@ -62,6 +64,12 @@ def _set_ndarray_class(cls):
_ndarray_cls = cls
+def _set_np_ndarray_class(cls):
+ """Set the class used to wrap outputs of numpy-compatible ops to be cls"""
+ global _np_ndarray_cls
+ _np_ndarray_cls = cls
+
+
def _imperative_invoke(handle, ndargs, keys, vals, out):
"""ctypes implementation of imperative invoke wrapper"""
if out is not None:
@@ -93,18 +101,19 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
if original_output is not None:
return original_output
+ create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls
if num_output.value == 1:
- return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
- stype=out_stypes[0])
+ return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
+ stype=out_stypes[0])
else:
- return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
- stype=out_stypes[i])
- for i in range(num_output.value)]
+ return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
+ stype=out_stypes[i]) for i in range(num_output.value)]
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
+
def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
@@ -118,6 +127,13 @@ class CachedOp(object):
def __del__(self):
check_call(_LIB.MXFreeCachedOp(self.handle))
+ def _is_from_np_compat_op(self, idx):
+ """Return True if the CachedOp's idx-th output is directly from a numpy op."""
+ is_from_np_op = ctypes.c_int(0)
+ check_call(_LIB.MXIsCachedOpOutputFromNumpyCompatOp(self.handle, ctypes.c_int(idx),
+ ctypes.byref(is_from_np_op)))
+ return is_from_np_op.value != 0
+
def __call__(self, *args, **kwargs):
"""ctypes implementation of imperative invoke wrapper"""
out = kwargs.pop('out', None)
@@ -152,9 +168,11 @@ class CachedOp(object):
if original_output is not None:
return original_output
if num_output.value == 1:
- return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
- stype=out_stypes[0])
+ create_ndarray_fn = _np_ndarray_cls if self._is_from_np_compat_op(0) else _ndarray_cls
+ return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
+ stype=out_stypes[0])
else:
- return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
- stype=out_stypes[i])
+ return [_np_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
+ if self._is_from_np_compat_op(i) else
+ _ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
for i in range(num_output.value)]
diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py
index fe4cb95..7aea0a2 100644
--- a/python/mxnet/_ctypes/symbol.py
+++ b/python/mxnet/_ctypes/symbol.py
@@ -22,11 +22,12 @@ from __future__ import absolute_import as _abs
import ctypes
from ..base import _LIB
-from ..base import c_str_array, c_handle_array, c_str, mx_uint
+from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op
from ..base import SymbolHandle
from ..base import check_call
_symbol_cls = None
+_np_symbol_cls = None
class SymbolBase(object):
"""Symbol is symbolic graph."""
@@ -115,6 +116,12 @@ def _set_symbol_class(cls):
_symbol_cls = cls
+def _set_np_symbol_class(cls):
+ """Set the class used to create symbols of numpy-compatible ops to be cls"""
+ global _np_symbol_cls
+ _np_symbol_cls = cls
+
+
def _symbol_creator(handle, args, kwargs, keys, vals, name):
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
@@ -128,7 +135,10 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name):
raise TypeError(
'Operators with variable length input can only accept input'
'Symbols either as positional or keyword arguments, not both')
- s = _symbol_cls(sym_handle)
+ if _is_np_compat_op(handle):
+ s = _np_symbol_cls(sym_handle)
+ else:
+ s = _symbol_cls(sym_handle)
if args:
s._compose(*args, name=name)
elif kwargs:
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index c435317..0d4bf53 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -561,7 +561,7 @@ def _as_list(obj):
return [obj]
-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']
def _get_op_name_prefix(op_name):
@@ -607,15 +607,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
# use mx.nd.contrib or mx.sym.contrib from now on
contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
contrib_module_old = sys.modules[contrib_module_name_old]
- # special handling of registering numpy ops
- # only expose mxnet.numpy.op_name to users for imperative mode.
- # Symbolic mode should be used in Gluon.
- if module_name == 'ndarray':
- numpy_module_name = "%s.numpy" % root_namespace
- numpy_module = sys.modules[numpy_module_name]
- else:
- numpy_module_name = None
- numpy_module = None
submodule_dict = {}
for op_name_prefix in _OP_NAME_PREFIX_LIST:
submodule_dict[op_name_prefix] =\
@@ -654,16 +645,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
function.__module__ = contrib_module_name_old
setattr(contrib_module_old, function.__name__, function)
contrib_module_old.__all__.append(function.__name__)
- elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
- # only register numpy ops under mxnet.numpy in imperative mode
- hdl = OpHandle()
- check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
- # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
- func_name = name[len(op_name_prefix):]
- function = make_op_func(hdl, name, func_name)
- function.__module__ = numpy_module_name
- setattr(numpy_module, function.__name__, function)
- numpy_module.__all__.append(function.__name__)
def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
@@ -753,3 +734,83 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
+
+
+def _sanity_check_params(func_name, unsupported_params, param_dict):
+ for param_name in unsupported_params:  # reject any unsupported param present in param_dict
+ if param_name in param_dict:
+ raise NotImplementedError("function {} does not support parameter {}"
+ .format(func_name, param_name))
+
+
+_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
+_NP_OP_PREFIX = '_numpy_'
+
+
+def _get_np_op_submodule_name(op_name):  # e.g. '_numpy__random_foo' -> '_random_'
+ assert op_name.startswith(_NP_OP_PREFIX)
+ for name in _NP_OP_SUBMODULE_LIST:
+ if op_name[len(_NP_OP_PREFIX):].startswith(name):
+ return name
+ return ""  # no submodule: op belongs to the top-level numpy namespace
+
+
+def _init_np_op_module(root_namespace, module_name, make_op_func):
+ """
+ Register numpy operators in namespaces `mxnet.numpy`, `mxnet.ndarray.numpy`
+ and `mxnet.symbol.numpy`. They are used in imperative mode, Gluon APIs w/o hybridization,
+ and Gluon APIs w/ hybridization, respectively. Essentially, operators with the same name
+ registered in the three namespaces share the same functionality in the C++ backend.
+ Different namespaces are needed for dispatching operator calls in Gluon's `HybridBlock` by `F`.
+
+ Parameters
+ ----------
+ root_namespace : str
+ Top level module name, `mxnet` in the current case.
+ module_name : str
+ Second level module name: `numpy`, `ndarray`, or `symbol` in the current case.
+ make_op_func : function
+ Function for creating op functions.
+ """
+ plist = ctypes.POINTER(ctypes.c_char_p)()
+ size = ctypes.c_uint()
+
+ check_call(_LIB.MXListAllOpNames(ctypes.byref(size), ctypes.byref(plist)))
+ op_names = []
+ for i in range(size.value):
+ name = py_str(plist[i])
+ if name.startswith(_NP_OP_PREFIX):
+ op_names.append(name)
+
+ if module_name == 'numpy':
+ # register ops for mxnet.numpy
+ module_pattern = "%s.%s._op"
+ submodule_pattern = "%s.%s.%s"
+ else:
+ # register ops for mxnet.ndarray.numpy or mxnet.symbol.numpy
+ module_pattern = "%s.%s.numpy._op"
+ submodule_pattern = "%s.%s.numpy.%s"
+ module_np_op = sys.modules[module_pattern % (root_namespace, module_name)]
+ submodule_dict = {}
+ # TODO(junwu): uncomment the following lines when adding numpy ops in submodules, e.g. np.random
+ # for submodule_name in _NP_OP_SUBMODULE_LIST:
+ # submodule_dict[submodule_name] = \
+ # sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
+ for name in op_names:
+ hdl = OpHandle()
+ check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+ submodule_name = _get_np_op_submodule_name(name)
+ module_name_local = module_name  # NOTE(review): bare name, unlike the qualified path set for submodules
+ if len(submodule_name) > 0:
+ func_name = name[(len(_NP_OP_PREFIX) + len(submodule_name)):]
+ cur_module = submodule_dict[submodule_name]  # NOTE(review): KeyError while submodule_dict is empty (TODO above)
+ module_name_local = submodule_pattern % (root_namespace,
+ module_name, submodule_name[1:-1])
+ else:
+ func_name = name[len(_NP_OP_PREFIX):]
+ cur_module = module_np_op
+
+ function = make_op_func(hdl, name, func_name)
+ function.__module__ = module_name_local
+ setattr(cur_module, function.__name__, function)
+ cur_module.__all__.append(function.__name__)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 3bac3c0..b9e8754 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -34,6 +34,7 @@ from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
from .utils import _indent, _brief_print_list, HookHandle
+from .. import numpy as _mx_np
class _BlockScope(object):
@@ -739,9 +740,13 @@ class HybridBlock(Block):
if not self._cached_graph:
args, self._in_format = _flatten(args, "input")
if len(args) > 1:
- inputs = [symbol.var('data%d'%i) for i in range(len(args))]
+ inputs = [symbol.var('data%d' % i).as_np_ndarray()
+ if isinstance(args[i], _mx_np.ndarray)
+ else symbol.var('data%d' % i) for i in range(len(args))]
else:
- inputs = [symbol.var('data')]
+ inputs = [symbol.var('data').as_np_ndarray()
+ if isinstance(args[0], _mx_np.ndarray)
+ else symbol.var('data')]
grouped_inputs = _regroup(inputs, self._in_format)[0]
params = {i: j.var() for i, j in self._reg_params.items()}
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index a102399..f0e6edb 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -30,6 +30,7 @@ from .ndarray import *
from .utils import load, load_frombuffer, save, zeros, empty, array
from .sparse import _ndarray_cls
from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle
+from . import numpy as np
__all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \
['contrib', 'linalg', 'random', 'sparse', 'image']
diff --git a/python/mxnet/ndarray/_internal.py b/python/mxnet/ndarray/_internal.py
index 8045d9b..d482556 100644
--- a/python/mxnet/ndarray/_internal.py
+++ b/python/mxnet/ndarray/_internal.py
@@ -23,18 +23,18 @@ import sys as _sys
try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.ndarray import NDArrayBase, CachedOp
- from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke
+ from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
elif _sys.version_info >= (3, 0):
from .._cy3.ndarray import NDArrayBase, CachedOp
- from .._cy3.ndarray import _set_ndarray_class, _imperative_invoke
+ from .._cy3.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
else:
from .._cy2.ndarray import NDArrayBase, CachedOp
- from .._cy2.ndarray import _set_ndarray_class, _imperative_invoke
+ from .._cy2.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
except ImportError:
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
from .._ctypes.ndarray import NDArrayBase, CachedOp
- from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke
+ from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
from ..base import _Null
try:
@@ -42,4 +42,5 @@ try:
except ImportError:
pass
-__all__ = ['NDArrayBase', 'CachedOp', '_imperative_invoke', '_set_ndarray_class']
+__all__ = ['NDArrayBase', 'CachedOp', '_imperative_invoke', '_set_ndarray_class',
+ '_set_np_ndarray_class']
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 3fb1af6..23a239c 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -184,6 +184,18 @@ fixed-size items.
# See C++ side of definition(kTVMNDArrayTypeCode) at include/mxmet/tensor_blob.h
_tvm_tcode = 19
# pylint: disable= no-member, undefined-variable
+
+ def as_np_ndarray(self):
+ """Convert to mxnet.numpy.ndarray via a shallow copy that shares data and keeps writability."""
+ from ..numpy import ndarray
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
+ return ndarray(handle=hdl, writable=self.writable)
+
+ def _is_np_compat(self):
+ """Always returns False except for mxnet.numpy.ndarray (presumably overridden there)."""
+ return False
+
@property
def _tvm_handle(self):
return self.handle.value
@@ -207,6 +219,9 @@ fixed-size items.
def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
+ # other may be the type of mxnet.numpy.ndarray
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__add__(self)
return add(self, other)
def __iadd__(self, other):
@@ -221,10 +236,15 @@ fixed-size items.
raise TypeError('type %s not supported' % str(type(other)))
def __radd__(self, other):
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__add__(self)
return self.__add__(other)
def __sub__(self, other):
"""x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y) """
+ # other may be the type of mxnet.numpy.ndarray
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__rsub__(self)
return subtract(self, other)
def __isub__(self, other):
@@ -240,10 +260,14 @@ fixed-size items.
def __rsub__(self, other):
"""x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__sub__(self)
return subtract(other, self)
def __mul__(self, other):
"""x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__mul__(self)
return multiply(self, other)
def __neg__(self):
@@ -262,14 +286,20 @@ fixed-size items.
raise TypeError('type %s not supported' % str(type(other)))
def __rmul__(self, other):
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__mul__(self)
return self.__mul__(other)
def __div__(self, other):
"""x.__div__(y) <=> x/y <=> mx.nd.divide(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__rtruediv__(self)
return divide(self, other)
def __rdiv__(self, other):
"""x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__truediv__(self)
return divide(other, self)
def __idiv__(self, other):
@@ -284,9 +314,13 @@ fixed-size items.
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__rtruediv__(self)
return divide(self, other)
def __rtruediv__(self, other):
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__truediv__(self)
return divide(other, self)
def __itruediv__(self, other):
@@ -294,10 +328,14 @@ fixed-size items.
def __mod__(self, other):
"""x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__rmod__(self)
return modulo(self, other)
def __rmod__(self, other):
"""x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__mod__(self)
return modulo(other, self)
def __imod__(self, other):
@@ -313,14 +351,20 @@ fixed-size items.
def __pow__(self, other):
"""x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__rpow__(self)
return power(self, other)
def __rpow__(self, other):
"""x.__pow__(y) <=> y**x <=> mx.nd.power(y,x) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__pow__(self)
return power(other, self)
def __eq__(self, other):
"""x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__eq__(self)
return equal(self, other)
def __hash__(self):
@@ -329,22 +373,32 @@ fixed-size items.
def __ne__(self, other):
"""x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__ne__(self)
return not_equal(self, other)
def __gt__(self, other):
"""x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__lt__(self)
return greater(self, other)
def __ge__(self, other):
"""x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__le__(self)
return greater_equal(self, other)
def __lt__(self, other):
"""x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__gt__(self)
return lesser(self, other)
def __le__(self, other):
"""x.__le__(y) <=> x<=y <=> mx.nd.less_equal(x, y) """
+ if isinstance(other, NDArray) and other._is_np_compat():
+ return other.__ge__(self)
return lesser_equal(self, other)
def __bool__(self):
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/ndarray/numpy/__init__.py
similarity index 81%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/ndarray/numpy/__init__.py
index b1139a0..a714a4b 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/ndarray/numpy/__init__.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,4 +15,10 @@
# specific language governing permissions and limitations
# under the License.
-__all__ = []
+"""numpy module for numpy ops under mxnet.ndarray."""
+
+from . import _op
+from . import _register
+from ._op import * # pylint: disable=wildcard-import
+
+__all__ = _op.__all__
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
new file mode 100644
index 0000000..383bf2f
--- /dev/null
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""numpy namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
+
+from __future__ import absolute_import
+import numpy as _np
+from ...base import _sanity_check_params, use_np_compat
+from ...context import current_context
+from .. import _internal
+
+__all__ = ['zeros', 'ones']
+
+
+@use_np_compat
+def zeros(shape, dtype=_np.float32, **kwargs):
+ """Return a new array of given shape and type, filled with zeros.
+ This function currently only supports storing multi-dimensional data
+ in row-major (C-style).
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the new array.
+ dtype : str or numpy.dtype, optional
+ An optional value type. Default is `numpy.float32`. Note that this
+ behavior is different from NumPy's `zeros` function where `float64`
+ is the default value, because `float32` is considered as the default
+ data type in deep learning.
+ ctx : Context, optional
+ An optional device context (default is the current default context).
+
+ Returns
+ -------
+ out : ndarray
+ Array of zeros with the given shape, dtype, and ctx.
+ """
+ _sanity_check_params('zeros', ['order'], kwargs)
+ ctx = kwargs.pop('ctx', current_context())  # pop so ctx is not passed twice below
+ if ctx is None:
+ ctx = current_context()
+ dtype = _np.float32 if dtype is None else dtype
+ return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+
+
+@use_np_compat
+def ones(shape, dtype=None, **kwargs):
+ """Return a new array of given shape and type, filled with ones.
+ This function currently only supports storing multi-dimensional data
+ in row-major (C-style).
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the new array.
+ dtype : str or numpy.dtype, optional
+ An optional value type. Default is `numpy.float32`. Note that this
+ behavior is different from NumPy's `ones` function where `float64`
+ is the default value, because `float32` is considered as the default
+ data type in deep learning.
+ ctx : Context, optional
+ An optional device context (default is the current default context).
+
+ Returns
+ -------
+ out : ndarray
+ Array of ones with the given shape, dtype, and ctx.
+ """
+ _sanity_check_params('ones', ['order'], kwargs)
+ ctx = kwargs.pop('ctx', current_context())  # pop so ctx is not passed twice below
+ if ctx is None:
+ ctx = current_context()
+ dtype = _np.float32 if dtype is None else dtype
+ return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/ndarray/numpy/_register.py
similarity index 78%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/ndarray/numpy/_register.py
index b1139a0..840797f 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/ndarray/numpy/_register.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,4 +15,10 @@
# specific language governing permissions and limitations
# under the License.
-__all__ = []
+"""module for registering numpy ops under mxnet.ndarray.numpy."""
+
+from ...base import _init_np_op_module
+from ..register import _make_ndarray_function
+
+
+_init_np_op_module('mxnet', 'ndarray', _make_ndarray_function)
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index b1139a0..c4dea9e 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -17,4 +17,14 @@
# specific language governing permissions and limitations
# under the License.
+"""numpy module for imperative programming."""
+
+from __future__ import absolute_import
+from .multiarray import * # pylint: disable=wildcard-import
+from . import _op
+from . import random
+from . import linalg
+from . import _register
+from ._op import * # pylint: disable=wildcard-import
+
__all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/_op.py
similarity index 91%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/numpy/_op.py
index b1139a0..e6a918c 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/_op.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,4 +15,6 @@
# specific language governing permissions and limitations
# under the License.
+"""namespace for registering numpy ops for imperative programming."""
+
__all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/_register.py
similarity index 75%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/numpy/_register.py
index b1139a0..53ceecd 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/_register.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,4 +15,12 @@
# specific language governing permissions and limitations
# under the License.
-__all__ = []
+"""Register backend ops in mxnet.ndarray namespace."""
+
+from __future__ import absolute_import
+
+from ..base import _init_np_op_module
+from ..ndarray.register import _make_ndarray_function
+
+
+_init_np_op_module('mxnet', 'numpy', _make_ndarray_function)
diff --git a/python/mxnet/ndarray/numpy.py b/python/mxnet/numpy/linalg.py
similarity index 92%
rename from python/mxnet/ndarray/numpy.py
rename to python/mxnet/numpy/linalg.py
index 0826ac8..1527c61 100644
--- a/python/mxnet/ndarray/numpy.py
+++ b/python/mxnet/numpy/linalg.py
@@ -15,4 +15,6 @@
# specific language governing permissions and limitations
# under the License.
+"""namespace for registering numpy ops of linear algebra."""
+
__all__ = []
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
new file mode 100644
index 0000000..9f47ce1
--- /dev/null
+++ b/python/mxnet/numpy/multiarray.py
@@ -0,0 +1,1200 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=too-many-lines
+"""numpy ndarray and util functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from array import array as native_array
+import ctypes
+import numpy as _np
+from ..ndarray import NDArray, _DTYPE_NP_TO_MX
+from ..ndarray._internal import _set_np_ndarray_class
+from . import _op
+from ..base import use_np_compat, check_call, _LIB, NDArrayHandle, _sanity_check_params
+from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types
+from ..context import current_context
+from ..ndarray import numpy as _mx_nd_np
+from ..ndarray import _internal as _nd_internal
+
+__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones']
+
+
+# This function is copied from ndarray.py since pylint
+# keeps giving false alarm error of undefined-all-variable
+def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
+ """Return a new handle with specified shape and context.
+
+ Empty handle is only used to hold results.
+
+ Returns
+ -------
+ handle
+ A new empty `ndarray` handle.
+ """
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXNDArrayCreateEx(
+ c_array_buf(mx_uint, native_array('I', shape)),
+ mx_uint(len(shape)),
+ ctypes.c_int(ctx.device_typeid),
+ ctypes.c_int(ctx.device_id),
+ ctypes.c_int(int(delay_alloc)),
+ ctypes.c_int(int(_DTYPE_NP_TO_MX[_np.dtype(dtype).type])),
+ ctypes.byref(hdl)))
+ return hdl
+
+
+# Have to use 0 as default value for stype since plylint does not allow
+# importing _STORAGE_TYPE_DEFAULT from ndarray.py.
+def _np_ndarray_cls(handle, writable=True, stype=0):
+ if stype != 0:
+ raise ValueError('_np_ndarray_cls currently only supports default storage '
+ 'type, while received stype = {}'.format(stype))
+ return ndarray(handle, writable=writable)
+
+
+_set_np_ndarray_class(_np_ndarray_cls)
+
+
+class ndarray(NDArray):
+ """An array object represents a multidimensional, homogeneous array of fixed-size items.
+ An associated data-type object describes the format of each element in the array
+ (its byte-order, how many bytes it occupies in memory, whether it is an integer, a
+ floating point number, or something else, etc.). Arrays should be constructed using
+ `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported."""
+
+ def _is_np_compat(self):
+ return True
+
+ @use_np_compat
+ def __getitem__(self, item):
+ # TODO(junwu): make output shape of integer indexing correct
+ raise NotImplementedError
+
+ @use_np_compat
+ def __setitem__(self, key, value):
+ super(ndarray, self).__setitem__(key, value)
+
+ @use_np_compat
+ def __add__(self, other):
+ """x.__add__(y) <=> x + y"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_add(self, other)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_add_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __iadd__(self, other):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __sub__(self, other):
+ """x.__sub__(y) <=> x - y"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_subtract(self, other)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_subtract_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __isub__(self, other):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __rsub__(self, other):
+ """x.__rsub__(y) <=> y - x"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_subtract(other, self)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_rsubtract_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __mul__(self, other):
+ """x.__mul__(y) <=> x * y"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_multiply(self, other)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_multiply_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __neg__(self):
+ return self.__mul__(-1.0)
+
+ @use_np_compat
+ def __imul__(self, other):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __rmul__(self, other):
+ """x.__rmul__(y) <=> y * x"""
+ return self.__mul__(other)
+
+ def __div__(self, other):
+ raise AttributeError('ndarray.__div__ is replaced by __truediv__. If you are using'
+ ' Python2, please use the statement from __future__ import division'
+ ' to change the / operator to mean true division throughout the'
+ ' module. If you are using Python3, this error should not have'
+ ' been encountered.')
+
+ def __rdiv__(self, other):
+ raise AttributeError('ndarray.__rdiv__ is replaced by __rtruediv__. If you are using'
+ ' Python2, please use the statement from __future__ import division'
+ ' to change the / operator to mean true division throughout the'
+ ' module. If you are using Python3, this error should not have'
+ ' been encountered.')
+
+ @use_np_compat
+ def __idiv__(self, other):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __truediv__(self, other):
+ """x.__truediv__(y) <=> x / y"""
+ if isinstance(other, NDArray):
+ return _nd_internal._true_divide(self, other)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._true_divide_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as divisor".format(str(type(other))))
+
+ @use_np_compat
+ def __rtruediv__(self, other):
+ """x.__rtruediv__(y) <=> y / x"""
+ if isinstance(other, NDArray):
+ return _nd_internal._true_divide(other, self)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._rtrue_divide_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as dividend".format(str(type(other))))
+
+ @use_np_compat
+ def __itruediv__(self, other):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __mod__(self, other):
+ """x.__mod__(y) <=> x % y"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_mod(self, other)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_mod_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __rmod__(self, other):
+ """x.__rmod__(y) <=> y % x"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_mod(other, self)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_rmod_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __imod__(self, other):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __pow__(self, other):
+ """x.__pow__(y) <=> x ** y"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_power(self, other)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_power_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __rpow__(self, other):
+ """x.__rpow__(y) <=> y ** x"""
+ if isinstance(other, NDArray):
+ return _nd_internal._np_power(other, self)
+ elif isinstance(other, numeric_types):
+ return _nd_internal._np_rpower_scalar(self, float(other))
+ else:
+ raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
+
+ @use_np_compat
+ def __eq__(self, other):
+ """x.__eq__(y) <=> x == y"""
+ raise NotImplementedError
+
+ @use_np_compat
+ def __hash__(self):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __ne__(self, other):
+ """x.__ne__(y) <=> x != y"""
+ raise NotImplementedError
+
+ @use_np_compat
+ def __gt__(self, other):
+ """x.__gt__(y) <=> x > y"""
+ raise NotImplementedError
+
+ @use_np_compat
+ def __ge__(self, other):
+ """x.__ge__(y) <=> x >= y"""
+ raise NotImplementedError
+
+ @use_np_compat
+ def __lt__(self, other):
+ """x.__lt__(y) <=> x < y"""
+ raise NotImplementedError
+
+ @use_np_compat
+ def __le__(self, other):
+ """x.__le__(y) <=> x <= y"""
+ raise NotImplementedError
+
+ @use_np_compat
+ def __bool__(self):
+ raise NotImplementedError
+
+ @use_np_compat
+ def __len__(self):
+ """Number of elements along the first axis."""
+ return self.shape[0]
+
+ def __reduce__(self):
+ return ndarray, (None,), self.__getstate__()
+
+ @use_np_compat
+ def _slice(self, start, stop):
+ raise NotImplementedError
+
+ @use_np_compat
+ def _at(self, idx):
+ raise NotImplementedError
+
+ @use_np_compat
+ def all(self, axis=None, out=None, keepdims=False):
+ raise NotImplementedError
+
+ @use_np_compat
+ def any(self, axis=None, out=None, keepdims=False):
+ raise NotImplementedError
+
+ def as_classic_ndarray(self):
+ """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods."""
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
+ return NDArray(handle=hdl, writable=self.writable)
+
+ @use_np_compat
+ def __repr__(self):
+ """Returns a string representation of the array."""
+ return '%s\n<%s shape=%s ctx=%s>' % (str(self.asnumpy()), self.__class__.__name__,
+ self.shape, self.context)
+
+ @use_np_compat
+ def attach_grad(self, grad_req='write', stype=None):
+ if stype is not None:
+ raise NotImplementedError('mxnet.numpy.ndarray currently does not support stype')
+ super(ndarray, self).attach_grad(grad_req, stype)
+
+ @property
+ def grad(self):
+ """Returns gradient buffer attached to this ndarray."""
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXNDArrayGetGrad(self.handle, ctypes.byref(hdl)))
+ if hdl.value is None:
+ return None
+ return _np_ndarray_cls(hdl)
+
+ @use_np_compat
+ def detach(self):
+ """Returns a new ndarray, detached from the current graph."""
+ hdl = NDArrayHandle()
+ check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl)))
+ return _np_ndarray_cls(hdl)
+
+ @use_np_compat
+ def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,unused-argument
+ """
+ Copy of the array, cast to a specified type.
+
+ Parameters
+ ----------
+ dtype : str or dtype
+ Typecode or data-type to which the array is cast.
+ copy : bool, optional
+ Default `True`. By default, astype always returns a newly
+ allocated ndarray on the same context. If this is set to
+ `False`, and the dtype requested is the same as the ndarray's
+ dtype, the ndarray is returned instead of a copy.
+
+ Returns
+ -------
+ arr_t : ndarray
+ Unless `copy` is False and the other conditions for returning the input
+ array are satisfied (see description for `copy` input parameter), `arr_t`
+ is a new array of the same shape as the input array with `dtype`.
+ """
+ _sanity_check_params('astype', ['order', 'casting', 'subok'], kwargs)
+ copy = kwargs.get('copy', True)
+ if not copy and _np.dtype(dtype) == self.dtype:
+ return self
+
+ res = empty(self.shape, dtype=dtype, ctx=self.context)
+ self.copyto(res)
+ return res
+
+ def asscalar(self):
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute as_scalar')
+
+ def as_in_context(self, context):
+ return super(ndarray, self).as_in_context(context).as_np_ndarray()
+
+ @use_np_compat
+ def copy(self, order='C'): # pylint: disable=arguments-differ
+ if order != 'C':
+ raise NotImplementedError('ndarray.copy only supports order=\'C\', while '
+ 'received {}'.format(str(order)))
+ return super(ndarray, self).copy().as_np_ndarray()
+
+ @use_np_compat
+ def reshape(self, *shape, **kwargs):
+ """Returns an array containing the same data with a new shape."""
+ raise NotImplementedError
+
+ def reshape_like(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`reshape_like`.
+
+ The arguments are the same as for :py:func:`reshape_like`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute reshape_like')
+
+ def zeros_like(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`zeros_like`.
+
+ The arguments are the same as for :py:func:`zeros_like`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute zeros_like')
+
+ def ones_like(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`ones_like`.
+
+ The arguments are the same as for :py:func:`ones_like`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute ones_like')
+
+ def broadcast_axes(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`broadcast_axes`.
+
+ The arguments are the same as for :py:func:`broadcast_axes`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_like')
+
+ @use_np_compat
+ def repeat(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`repeat`.
+
+ The arguments are the same as for :py:func:`repeat`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def pad(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`pad`.
+
+ The arguments are the same as for :py:func:`pad`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute pad')
+
+ @use_np_compat
+ def swapaxes(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`swapaxes`.
+
+ The arguments are the same as for :py:func:`swapaxes`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def split(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`split`.
+
+ The arguments are the same as for :py:func:`split`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute split')
+
+ def split_v2(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`split_v2`.
+
+ The arguments are the same as for :py:func:`split_v2`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute split_v2')
+
+ def slice(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`slice`.
+
+ The arguments are the same as for :py:func:`slice`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute slice')
+
+ def slice_axis(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`slice_axis`.
+
+ The arguments are the same as for :py:func:`slice_axis`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute slice_axis')
+
+ def slice_like(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`slice_like`.
+
+ The arguments are the same as for :py:func:`slice_like`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute slice_like')
+
+ @use_np_compat
+ def take(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`take`.
+
+ The arguments are the same as for :py:func:`take`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def one_hot(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`one_hot`.
+
+ The arguments are the same as for :py:func:`one_hot`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute one_hot')
+
+ def pick(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`pick`.
+
+ The arguments are the same as for :py:func:`pick`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute pick')
+
+ @use_np_compat
+ def sort(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sort`.
+
+ The arguments are the same as for :py:func:`sort`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def topk(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`topk`.
+
+ The arguments are the same as for :py:func:`topk`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute topk')
+
+ @use_np_compat
+ def argsort(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`argsort`.
+
+ The arguments are the same as for :py:func:`argsort`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ @use_np_compat
+ def argmax(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`argmax`.
+
+ The arguments are the same as for :py:func:`argmax`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def argmax_channel(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`argmax_channel`.
+
+ The arguments are the same as for :py:func:`argmax_channel`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute argmax_channel')
+
+ @use_np_compat
+ def argmin(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`argmin`.
+
+ The arguments are the same as for :py:func:`argmin`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ @use_np_compat
+ def clip(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`clip`.
+
+ The arguments are the same as for :py:func:`clip`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def abs(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`abs`.
+
+ The arguments are the same as for :py:func:`abs`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute abs')
+
+ def sign(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sign`.
+
+ The arguments are the same as for :py:func:`sign`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute abs')
+
+ @use_np_compat
+ def flatten(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`flatten`.
+
+ The arguments are the same as for :py:func:`flatten`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def shape_array(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`shape_array`.
+
+ The arguments are the same as for :py:func:`shape_array`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute shape_array')
+
+ def size_array(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`size_array`.
+
+ The arguments are the same as for :py:func:`size_array`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute size_array')
+
+ def expand_dims(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`expand_dims`.
+
+ The arguments are the same as for :py:func:`expand_dims`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute expand_dims')
+
+ def tile(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`tile`.
+
+ The arguments are the same as for :py:func:`tile`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute tile')
+
+ @use_np_compat
+ def transpose(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`transpose`.
+
+ The arguments are the same as for :py:func:`transpose`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def flip(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`flip`.
+
+ The arguments are the same as for :py:func:`flip`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute flip')
+
+ def depth_to_space(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`depth_to_space`.
+
+ The arguments are the same as for :py:func:`depth_to_space`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute depth_to_space')
+
+ def space_to_depth(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`space_to_depth`.
+
+ The arguments are the same as for :py:func:`space_to_depth`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute space_to_depth')
+
+ def diag(self, k=0, **kwargs):
+ """Convenience fluent method for :py:func:`diag`.
+
+ The arguments are the same as for :py:func:`diag`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute diag')
+
+ @use_np_compat
+ def sum(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sum`.
+
+ The arguments are the same as for :py:func:`sum`, with
+ this array as data.
+ """
+ return _op.sum(self, *args, **kwargs)
+
+ def nansum(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`nansum`.
+
+ The arguments are the same as for :py:func:`nansum`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute nansum')
+
+ @use_np_compat
+ def prod(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`prod`.
+
+ The arguments are the same as for :py:func:`prod`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def nanprod(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`nanprod`.
+
+ The arguments are the same as for :py:func:`nanprod`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute nanprod')
+
+ @use_np_compat
+ def mean(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`mean`.
+
+ The arguments are the same as for :py:func:`mean`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ @use_np_compat
+ def max(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`max`.
+
+ The arguments are the same as for :py:func:`max`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ @use_np_compat
+ def min(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`min`.
+
+ The arguments are the same as for :py:func:`min`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def norm(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`norm`.
+
+ The arguments are the same as for :py:func:`norm`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute norm')
+
+ @use_np_compat
+ def round(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`round`.
+
+ The arguments are the same as for :py:func:`round`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def rint(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`rint`.
+
+ The arguments are the same as for :py:func:`rint`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute rint')
+
+ def fix(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`fix`.
+
+ The arguments are the same as for :py:func:`fix`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute fix')
+
+ def floor(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`floor`.
+
+ The arguments are the same as for :py:func:`floor`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute floor')
+
+ def ceil(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`ceil`.
+
+ The arguments are the same as for :py:func:`ceil`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute ceil')
+
+ def trunc(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`trunc`.
+
+ The arguments are the same as for :py:func:`trunc`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute trunc')
+
+ def sin(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sin`.
+
+ The arguments are the same as for :py:func:`sin`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute sin')
+
+ def cos(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`cos`.
+
+ The arguments are the same as for :py:func:`cos`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute cos')
+
+ def tan(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`tan`.
+
+ The arguments are the same as for :py:func:`tan`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute tan')
+
+ def arcsin(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`arcsin`.
+
+ The arguments are the same as for :py:func:`arcsin`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute arcsin')
+
+ def arccos(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`arccos`.
+
+ The arguments are the same as for :py:func:`arccos`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute arccos')
+
+ def arctan(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`arctan`.
+
+ The arguments are the same as for :py:func:`arctan`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute arctan')
+
+ def degrees(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`degrees`.
+
+ The arguments are the same as for :py:func:`degrees`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute degrees')
+
+ def radians(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`radians`.
+
+ The arguments are the same as for :py:func:`radians`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute radians')
+
+ def sinh(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sinh`.
+
+ The arguments are the same as for :py:func:`sinh`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute sinh')
+
+ def cosh(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`cosh`.
+
+ The arguments are the same as for :py:func:`cosh`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute cosh')
+
+ def tanh(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`tanh`.
+
+ The arguments are the same as for :py:func:`tanh`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute tanh')
+
+ def arcsinh(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`arcsinh`.
+
+ The arguments are the same as for :py:func:`arcsinh`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute arcsinh')
+
+ def arccosh(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`arccosh`.
+
+ The arguments are the same as for :py:func:`arccosh`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute arccosh')
+
+ def arctanh(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`arctanh`.
+
+ The arguments are the same as for :py:func:`arctanh`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute arctanh')
+
+ def exp(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`exp`.
+
+ The arguments are the same as for :py:func:`exp`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute exp')
+
+ def expm1(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`expm1`.
+
+ The arguments are the same as for :py:func:`expm1`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute expm1')
+
+ def log(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log`.
+
+ The arguments are the same as for :py:func:`log`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute log')
+
+ def log10(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log10`.
+
+ The arguments are the same as for :py:func:`log10`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute log10')
+
+ def log2(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log2`.
+
+ The arguments are the same as for :py:func:`log2`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute log2')
+
+ def log1p(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log1p`.
+
+ The arguments are the same as for :py:func:`log1p`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute log1p')
+
+ def sqrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sqrt`.
+
+ The arguments are the same as for :py:func:`sqrt`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute sqrt')
+
+ def rsqrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`rsqrt`.
+
+ The arguments are the same as for :py:func:`rsqrt`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute rsqrt')
+
+ def cbrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`cbrt`.
+
+ The arguments are the same as for :py:func:`cbrt`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute cqrt')
+
+ def rcbrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`rcbrt`.
+
+ The arguments are the same as for :py:func:`rcbrt`, with
+ this array as data.
+ """
+ raise AttributeError('mxnet.numpy.ndarray object has no attribute rcqrt')
+
def square(self, *args, **kwargs):
    """Unsupported fluent method; raises ``AttributeError`` for any arguments."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute square')

def reciprocal(self, *args, **kwargs):
    """Unsupported fluent method; raises ``AttributeError`` for any arguments."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute reciprocal')

def relu(self, *args, **kwargs):
    """Unsupported fluent method; raises ``AttributeError`` for any arguments."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute relu')

def sigmoid(self, *args, **kwargs):
    """Unsupported fluent method; raises ``AttributeError`` for any arguments."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute sigmoid')

def softmax(self, *args, **kwargs):
    """Unsupported fluent method; raises ``AttributeError`` for any arguments."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute softmax')

def log_softmax(self, *args, **kwargs):
    """Unsupported fluent method; raises ``AttributeError`` for any arguments."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute log_softmax')

def softmin(self, *args, **kwargs):
    """Unsupported fluent method; raises ``AttributeError`` for any arguments."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute softmin')
+
@use_np_compat
def squeeze(self, *args, **kwargs):
    """Fluent ``squeeze``; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def broadcast_to(self, shape):
    """Unsupported; raises ``AttributeError`` regardless of ``shape``."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_to')

def broadcast_like(self, other):
    """Unsupported; raises ``AttributeError`` regardless of ``other``."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_like')
+
@property
@use_np_compat
def shape(self):
    """Array dimensions, as reported by the parent class's ``shape``."""
    return super(ndarray, self).shape

@property
@use_np_compat
def ndim(self):
    """Number of array dimensions, i.e. the length of ``self.shape``."""
    dims = self.shape
    return len(dims)

@property
@use_np_compat
def size(self):
    """Number of elements, as reported by the parent class's ``size``."""
    return super(ndarray, self).size

@property
@use_np_compat
def stype(self):
    """Storage type is not exposed here; always raises ``AttributeError``."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute stype')

@property
@use_np_compat
def T(self):
    """Transposed array; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def tostype(self, stype):
    """Storage-type conversion is unsupported; raises ``AttributeError``."""
    raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype')
+
+
@use_np_compat
def empty(shape, dtype=None, **kwargs):
    """Return a new array of given shape and type, without initializing entries.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the empty array, e.g., ``(2, 3)`` or ``2``.
    dtype : data-type, optional
        Desired output data-type for the array, e.g, `numpy.int8`. Default is
        `numpy.float32`. Note that this behavior is different from NumPy's `empty`
        function where `float64` is the default value, because `float32` is
        considered as the default data type in deep learning.
    ctx : device context, optional
        Device context on which the memory is allocated. Default is
        `mxnet.context.current_context()`.

    Returns
    -------
    out : ndarray
        Array of uninitialized (arbitrary) data of the given shape, dtype, and ctx.
    """
    # The NumPy `order` keyword is listed as unsupported; _sanity_check_params
    # presumably raises if it is passed (defined elsewhere -- confirm).
    _sanity_check_params('empty', ['order'], kwargs)
    # Only query the current context when the caller did not supply one
    # (an explicit ctx=None also falls back to the current context).
    ctx = kwargs.get('ctx')
    if ctx is None:
        ctx = current_context()
    if dtype is None:
        dtype = _np.float32
    if isinstance(shape, int):
        shape = (shape,)
    return ndarray(handle=_new_alloc_handle(shape, ctx, False, dtype))
+
+
@use_np_compat
def array(object, dtype=None, **kwargs):
    """
    Create an array.

    Parameters
    ----------
    object : array_like or `mxnet.ndarray.NDArray` or `mxnet.numpy.ndarray`
        An array, any object exposing the array interface, an object whose
        __array__ method returns an array, or any (nested) sequence.
    dtype : data-type, optional
        The desired data-type for the array. If not given, then the type will
        be determined as the minimum type required to hold the objects in the
        sequence. This argument can only be used to 'upcast' the array. For
        downcasting, use the .astype(t) method.
    ctx : device context, optional
        Device context on which the memory is allocated. Default is
        `mxnet.context.current_context()`.

    Returns
    -------
    out : ndarray
        An array object satisfying the specified requirements.

    Raises
    ------
    TypeError
        If `object` cannot be converted by `numpy.array`.
    """
    # NumPy keywords listed here are not supported by this implementation.
    _sanity_check_params('array', ['copy', 'order', 'subok', 'ndim'], kwargs)
    ctx = kwargs.get('ctx')
    if ctx is None:
        ctx = current_context()
    if not isinstance(object, (ndarray, NDArray, _np.ndarray)):
        try:
            object = _np.array(object, dtype=dtype)
        except Exception:  # pylint: disable=broad-except
            # Catch conversion failures only; a bare `except:` would also
            # swallow KeyboardInterrupt/SystemExit and report them as TypeError.
            raise TypeError('source array must be an array like object')
    if dtype is None:
        dtype = object.dtype
    ret = empty(object.shape, dtype=dtype, ctx=ctx)
    ret[:] = object
    return ret
+
+
def zeros(shape, dtype=_np.float32, **kwargs):
    """Return a new array of given shape and type, filled with zeros.
    This function currently only supports storing multi-dimensional data
    in row-major (C-style).

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the array.
    dtype : str or numpy.dtype, optional
        An optional value type (default is `numpy.float32`). Note that this
        behavior is different from NumPy's `zeros` function where `float64`
        is the default value, because `float32` is considered as the default
        data type in deep learning.
    ctx : Context, optional
        An optional device context (default is the current default context).

    Returns
    -------
    out : ndarray
        Array of zeros with the given shape, dtype, and ctx.
    """
    # Thin wrapper: all arguments are forwarded unchanged to _mx_nd_np.zeros.
    return _mx_nd_np.zeros(shape, dtype, **kwargs)
+
+
def ones(shape, dtype=None, **kwargs):
    """Return a new array of given shape and type, filled with ones.
    This function currently only supports storing multi-dimensional data
    in row-major (C-style).

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the array.
    dtype : str or numpy.dtype, optional
        An optional value type. Default is `numpy.float32`. Note that this
        behavior is different from NumPy's `ones` function where `float64`
        is the default value, because `float32` is considered as the default
        data type in deep learning.
    ctx : Context, optional
        An optional device context (default is the current default context).

    Returns
    -------
    out : ndarray
        Array of ones with the given shape, dtype, and ctx.
    """
    # dtype=None is forwarded as-is; presumably _mx_nd_np.ones maps it to
    # float32 as documented above -- NOTE(review): confirm.
    return _mx_nd_np.ones(shape, dtype, **kwargs)
diff --git a/python/mxnet/symbol/numpy.py b/python/mxnet/numpy/random.py
similarity index 93%
rename from python/mxnet/symbol/numpy.py
rename to python/mxnet/numpy/random.py
index 0826ac8..461da66 100644
--- a/python/mxnet/symbol/numpy.py
+++ b/python/mxnet/numpy/random.py
@@ -15,4 +15,6 @@
# specific language governing permissions and limitations
# under the License.
+"""namespace for registering numpy random operators."""
+
__all__ = []
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index 326e4f5..ae9477a 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -27,5 +27,6 @@ from . import register
from .op import *
from .symbol import *
# pylint: enable=wildcard-import
+from . import numpy as np
__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
diff --git a/python/mxnet/symbol/_internal.py b/python/mxnet/symbol/_internal.py
index 7e9787e..d46c0e6 100644
--- a/python/mxnet/symbol/_internal.py
+++ b/python/mxnet/symbol/_internal.py
@@ -24,18 +24,18 @@ import os as _os
try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
- from .._ctypes.symbol import SymbolBase, _set_symbol_class
+ from .._ctypes.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class
from .._ctypes.symbol import _symbol_creator
elif _sys.version_info >= (3, 0):
- from .._cy3.symbol import SymbolBase, _set_symbol_class
+ from .._cy3.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class
from .._cy3.symbol import _symbol_creator
else:
- from .._cy2.symbol import SymbolBase, _set_symbol_class
+ from .._cy2.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class
from .._cy2.symbol import _symbol_creator
except ImportError:
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
- from .._ctypes.symbol import SymbolBase, _set_symbol_class
+ from .._ctypes.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class
from .._ctypes.symbol import _symbol_creator
from ..attribute import AttrScope
from ..base import _Null
@@ -45,4 +45,4 @@ try:
except ImportError:
pass
-__all__ = ['SymbolBase', '_set_symbol_class', '_symbol_creator']
+__all__ = ['SymbolBase', '_set_symbol_class', '_symbol_creator', '_set_np_symbol_class']
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/symbol/numpy/__init__.py
similarity index 73%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/symbol/numpy/__init__.py
index b1139a0..d63daa2 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy/__init__.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,4 +15,12 @@
# specific language governing permissions and limitations
# under the License.
-__all__ = []
+"""numpy module for numpy ops under mxnet.symbol."""
+
+from . import _op, _symbol
+from ._symbol import _NumpySymbol
+from . import _register
+from ._op import * # pylint: disable=wildcard-import
+from ._symbol import * # pylint: disable=wildcard-import
+
+__all__ = _op.__all__ + _symbol.__all__
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/symbol/numpy/_op.py
similarity index 90%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/symbol/numpy/_op.py
index b1139a0..96da828 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy/_op.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,4 +15,6 @@
# specific language governing permissions and limitations
# under the License.
+"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+
__all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/symbol/numpy/_register.py
similarity index 78%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/symbol/numpy/_register.py
index b1139a0..36dfd78 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy/_register.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,4 +15,9 @@
# specific language governing permissions and limitations
# under the License.
-__all__ = []
+"""module for registering numpy ops under mxnet.symbol.numpy."""
+
+from ...base import _init_np_op_module
+from ..register import _make_symbol_function
+
+_init_np_op_module('mxnet', 'symbol', _make_symbol_function)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
new file mode 100644
index 0000000..087f118
--- /dev/null
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -0,0 +1,974 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+
+from __future__ import absolute_import
+import ctypes
+import numpy as _np
+from . import _op as _np_op
+from ...base import _sanity_check_params, use_np_compat, check_call, _LIB, SymbolHandle
+from ...base import numeric_types
+from ...context import current_context
+from .. import _internal
+from ..symbol import Symbol
+from .._internal import _set_np_symbol_class
+from .. import _internal as _sym_internal
+
+__all__ = ['zeros', 'ones']
+
+
+class _NumpySymbol(Symbol):
+
+ def _is_np_compat(self):
+ return True
+
+ def __getitem__(self, item):
+ raise NotImplementedError
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError
+
+ def __iter__(self):
+ raise AttributeError('_NumpySymbol object has no attribute __iter__')
+
@use_np_compat
def __add__(self, other):
    """Return ``self + other``; ``other`` may be a Symbol or a numeric scalar."""
    if isinstance(other, Symbol):
        return _sym_internal._np_add(self, other)
    if isinstance(other, numeric_types):
        return _sym_internal._np_add_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))

@use_np_compat
def __sub__(self, other):
    """Return ``self - other``; ``other`` may be a Symbol or a numeric scalar."""
    if isinstance(other, Symbol):
        return _sym_internal._np_subtract(self, other)
    if isinstance(other, numeric_types):
        return _sym_internal._np_subtract_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))

@use_np_compat
def __rsub__(self, other):
    """Return ``other - self`` (reflected subtraction)."""
    if isinstance(other, Symbol):
        return _sym_internal._np_subtract(other, self)
    if isinstance(other, numeric_types):
        return _sym_internal._np_rsubtract_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))

@use_np_compat
def __mul__(self, other):
    """Return ``self * other``; ``other`` may be a Symbol or a numeric scalar."""
    if isinstance(other, Symbol):
        return _sym_internal._np_multiply(self, other)
    if isinstance(other, numeric_types):
        return _sym_internal._np_multiply_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))

@use_np_compat
def __rmul__(self, other):
    """Return ``other * self`` (reflected multiplication, computed as self * other)."""
    if isinstance(other, Symbol):
        return _sym_internal._np_multiply(self, other)
    if isinstance(other, numeric_types):
        return _sym_internal._np_multiply_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))
+
+ def __div__(self, other):
+ raise AttributeError('_NumpySymbol.__div__ is replaced by __truediv__. If you are using'
+ ' Python2, please use the statement from __future__ import division'
+ ' to change the / operator to mean true division throughout the'
+ ' module. If you are using Python3, this error should not have'
+ ' been encountered.')
+
+ def __rdiv__(self, other):
+ raise AttributeError('_NumpySymbol.__rdiv__ is replaced by __rtruediv__. If you are using'
+ ' Python2, please use the statement from __future__ import division'
+ ' to change the / operator to mean true division throughout the'
+ ' module. If you are using Python3, this error should not have'
+ ' been encountered.')
+
@use_np_compat
def __mod__(self, other):
    """Return ``self % other``; ``other`` may be a Symbol or a numeric scalar."""
    if isinstance(other, Symbol):
        return _sym_internal._np_mod(self, other)
    if isinstance(other, numeric_types):
        return _sym_internal._np_mod_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))

@use_np_compat
def __rmod__(self, other):
    """Return ``other % self`` (reflected modulo)."""
    if isinstance(other, Symbol):
        return _sym_internal._np_mod(other, self)
    if isinstance(other, numeric_types):
        return _sym_internal._np_rmod_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))
+
@use_np_compat
def __idiv__(self, other):
    """In-place classic division; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def __truediv__(self, other):
    """Return ``self / other`` (true division) for a Symbol or numeric divisor."""
    if isinstance(other, Symbol):
        return _sym_internal._true_divide(self, other)
    if isinstance(other, numeric_types):
        return _sym_internal._true_divide_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as divisor"
                    .format(str(type(other))))

@use_np_compat
def __rtruediv__(self, other):
    """Return ``other / self`` (reflected true division) for a Symbol or numeric dividend."""
    if isinstance(other, Symbol):
        return _sym_internal._true_divide(other, self)
    if isinstance(other, numeric_types):
        # NOTE(review): only this scalar branch converts its result with
        # as_np_ndarray(); confirm whether _rtrue_divide_scalar needs it.
        return _sym_internal._rtrue_divide_scalar(self, float(other)).as_np_ndarray()
    raise TypeError("_NumpySymbol does not support type {} as dividend"
                    .format(str(type(other))))

@use_np_compat
def __itruediv__(self, other):
    """In-place true division; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
@use_np_compat
def __pow__(self, other):
    """Return ``self ** other``; ``other`` may be a Symbol or a numeric scalar."""
    if isinstance(other, Symbol):
        return _sym_internal._np_power(self, other)
    if isinstance(other, numeric_types):
        return _sym_internal._np_power_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))

@use_np_compat
def __rpow__(self, other):
    """Return ``other ** self`` (reflected power)."""
    if isinstance(other, Symbol):
        return _sym_internal._np_power(other, self)
    if isinstance(other, numeric_types):
        return _sym_internal._np_rpower_scalar(self, float(other))
    raise TypeError("_NumpySymbol does not support type {} as operand"
                    .format(str(type(other))))

@use_np_compat
def __neg__(self):
    """Return ``-self``, computed as multiplication by -1.0."""
    minus_one = -1.0
    return self.__mul__(minus_one)

@use_np_compat
def __deepcopy__(self, _):
    """Copy hook: returns the base class's ``as_np_ndarray()`` result; the memo is ignored."""
    return super(_NumpySymbol, self).as_np_ndarray()
+
@use_np_compat
def __eq__(self, other):
    """``self == other``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def __ne__(self, other):
    """``self != other``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def __gt__(self, other):
    """``self > other``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def __ge__(self, other):
    """``self >= other``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def __lt__(self, other):
    """``self < other``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def __le__(self, other):
    """``self <= other``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

def __len__(self):
    """``len(self)``; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def as_classic_ndarray(self):
    """Return a plain ``mxnet.symbol.Symbol`` wrapping a shallow copy of this handle.

    The copy is made with ``MXShallowCopySymbol`` so that the classic
    Symbol fluent methods become available on the result.
    """
    classic_handle = SymbolHandle()
    check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(classic_handle)))
    return Symbol(handle=classic_handle)

@use_np_compat
def astype(self, dtype, **kwargs):  # pylint: disable=arguments-differ
    """Type casting; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def reshape(self, *shape, **kwargs):
    """Reshaping; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def reshape_like(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute reshape_like')

def zeros_like(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute zeros_like')

def ones_like(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute ones_like')
+
def broadcast_axes(self, *args, **kwargs):
    """Fluent ``broadcast_axes`` is not available on ``_NumpySymbol``.

    Always raises ``AttributeError`` whose message names ``broadcast_axes``
    (the attribute actually requested).
    """
    raise AttributeError('_NumpySymbol object has no attribute broadcast_axes')
+
@use_np_compat
def repeat(self, *args, **kwargs):
    """Fluent ``repeat``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

def pad(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute pad')

@use_np_compat
def swapaxes(self, *args, **kwargs):
    """Fluent ``swapaxes``; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def split(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute split')

def split_v2(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute split_v2')

def slice(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute slice')

def slice_axis(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute slice_axis')

def slice_like(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute slice_like')
+
@use_np_compat
def take(self, *args, **kwargs):
    """Fluent ``take``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

def one_hot(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute one_hot')

def pick(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute pick')

@use_np_compat
def sort(self, *args, **kwargs):
    """Fluent ``sort``; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def topk(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute topk')

@use_np_compat
def argsort(self, *args, **kwargs):
    """Fluent ``argsort``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def argmax(self, *args, **kwargs):
    """Fluent ``argmax``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

def argmax_channel(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute argmax_channel')

@use_np_compat
def argmin(self, *args, **kwargs):
    """Fluent ``argmin``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def clip(self, *args, **kwargs):
    """Fluent ``clip``; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def abs(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute abs')
+
def sign(self, *args, **kwargs):
    """Fluent ``sign`` is not available on ``_NumpySymbol``.

    Always raises ``AttributeError`` whose message names ``sign`` (the
    attribute actually requested, not ``abs``).
    """
    raise AttributeError('_NumpySymbol object has no attribute sign')
+
@use_np_compat
def flatten(self, *args, **kwargs):
    """Fluent ``flatten``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

def shape_array(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute shape_array')

def size_array(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute size_array')

def expand_dims(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute expand_dims')

def tile(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute tile')

@use_np_compat
def transpose(self, *args, **kwargs):
    """Fluent ``transpose``; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def flip(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute flip')

def depth_to_space(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute depth_to_space')

def space_to_depth(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute space_to_depth')

def diag(self, k=0, **kwargs):
    """Not available on ``_NumpySymbol``; raises ``AttributeError`` for any ``k``."""
    raise AttributeError('_NumpySymbol object has no attribute diag')
+
@use_np_compat
def sum(self, *args, **kwargs):
    """Sum of elements: forwards to ``_np_op.sum`` with this symbol first."""
    return _np_op.sum(self, *args, **kwargs)

def nansum(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute nansum')

@use_np_compat
def prod(self, *args, **kwargs):
    """Fluent ``prod``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

def nanprod(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute nanprod')
+
@use_np_compat
def mean(self, *args, **kwargs):
    """Fluent ``mean``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def max(self, *args, **kwargs):
    """Fluent ``max``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

@use_np_compat
def min(self, *args, **kwargs):
    """Fluent ``min``; currently raises ``NotImplementedError``."""
    raise NotImplementedError

def norm(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute norm')

@use_np_compat
def round(self, *args, **kwargs):
    """Fluent ``round``; currently raises ``NotImplementedError``."""
    raise NotImplementedError
+
def rint(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute rint')

def fix(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute fix')

def floor(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute floor')

def ceil(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute ceil')

def trunc(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute trunc')

def sin(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute sin')

def cos(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute cos')

def tan(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute tan')

def arcsin(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute arcsin')

def arccos(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute arccos')

def arctan(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute arctan')

def degrees(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute degrees')

def radians(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute radians')

def sinh(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute sinh')

def cosh(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute cosh')

def tanh(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute tanh')

def arcsinh(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute arcsinh')

def arccosh(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute arccosh')

def arctanh(self, *args, **kwargs):
    """Not available on ``_NumpySymbol``; always raises ``AttributeError``."""
    raise AttributeError('_NumpySymbol object has no attribute arctanh')
+
+ def exp(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`exp`.
+
+ The arguments are the same as for :py:func:`exp`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute exp')
+
+ def expm1(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`expm1`.
+
+ The arguments are the same as for :py:func:`expm1`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute expm1')
+
+ def log(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log`.
+
+ The arguments are the same as for :py:func:`log`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute log')
+
+ def log10(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log10`.
+
+ The arguments are the same as for :py:func:`log10`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute log10')
+
+ def log2(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log2`.
+
+ The arguments are the same as for :py:func:`log2`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute log2')
+
+ def log1p(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log1p`.
+
+ The arguments are the same as for :py:func:`log1p`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute log1p')
+
+ def sqrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sqrt`.
+
+ The arguments are the same as for :py:func:`sqrt`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute sqrt')
+
+ def rsqrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`rsqrt`.
+
+ The arguments are the same as for :py:func:`rsqrt`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute rsqrt')
+
+ def cbrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`cbrt`.
+
+ The arguments are the same as for :py:func:`cbrt`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute cqrt')
+
+ def rcbrt(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`rcbrt`.
+
+ The arguments are the same as for :py:func:`rcbrt`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute rcqrt')
+
+ def square(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`square`.
+
+ The arguments are the same as for :py:func:`square`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute square')
+
+ def reciprocal(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`reciprocal`.
+
+ The arguments are the same as for :py:func:`reciprocal`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute reciprocal')
+
+ def relu(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`relu`.
+
+ The arguments are the same as for :py:func:`relu`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute relu')
+
+ def sigmoid(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`sigmoid`.
+
+ The arguments are the same as for :py:func:`sigmoid`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute sigmoid')
+
+ def softmax(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`softmax`.
+
+ The arguments are the same as for :py:func:`softmax`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute softmax')
+
+ def log_softmax(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`log_softmax`.
+
+ The arguments are the same as for :py:func:`log_softmax`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute log_softmax')
+
+ def softmin(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`softmin`.
+
+ The arguments are the same as for :py:func:`softmin`, with
+ this array as data.
+ """
+ raise AttributeError('_NumpySymbol object has no attribute softmin')
+
+ @use_np_compat
+ def squeeze(self, *args, **kwargs):
+ """Convenience fluent method for :py:func:`squeeze`.
+
+ The arguments are the same as for :py:func:`squeeze`, with
+ this array as data.
+ """
+ raise NotImplementedError
+
+ def broadcast_to(self, *args, **kwargs):
+ raise AttributeError('_NumpySymbol object has no attribute broadcast_to')
+
+ def broadcast_like(self, *args, **kwargs):
+ raise AttributeError('_NumpySymbol object has no attribute broadcast_like')
+
+
+@use_np_compat
+def zeros(shape, dtype=_np.float32, **kwargs):
+ """Return a new array of given shape and type, filled with zeros.
+ This function currently only supports storing multi-dimensional data
+ in row-major (C-style).
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the empty array.
+ dtype : str or numpy.dtype, optional
+ An optional value type. Default is `numpy.float32`. Note that this
+ behavior is different from NumPy's `zeros` function where `float64`
+ is the default value, because `float32` is considered as the default
+ data type in deep learning.
+ ctx : Context, optional
+ An optional device context (default is the current default context).
+
+ Returns
+ -------
+ out : Symbol
+ Array of zeros with the given shape, dtype, and ctx.
+ """
+ _sanity_check_params('zeros', ['order'], kwargs)
+ ctx = kwargs.get('ctx', current_context())
+ if ctx is None:
+ ctx = current_context()
+ dtype = _np.float32 if dtype is None else dtype
+ return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+
+
+@use_np_compat
+def ones(shape, dtype=None, **kwargs):
+ """Return a new array of given shape and type, filled with zeros.
+ This function currently only supports storing multi-dimensional data
+ in row-major (C-style).
+
+ Parameters
+ ----------
+ shape : int or tuple of int
+ The shape of the empty array.
+ dtype : str or numpy.dtype, optional
+ An optional value type. Default is `numpy.float32`. Note that this
+ behavior is different from NumPy's `ones` function where `float64`
+ is the default value, because `float32` is considered as the default
+ data type in deep learning.
+ ctx : Context, optional
+ An optional device context (default is the current default context).
+
+ Returns
+ -------
+ out : ndarray
+ Array of zeros with the given shape, dtype, and ctx.
+ """
+ _sanity_check_params('zeros', ['order'], kwargs)
+ ctx = kwargs.get('ctx', current_context())
+ if ctx is None:
+ ctx = current_context()
+ dtype = _np.float32 if dtype is None else dtype
+ return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+
+
+_set_np_symbol_class(_NumpySymbol)  # NOTE(review): presumably registers _NumpySymbol as the np-compat symbol type; verify
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index d3cd519..7be042c 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -30,7 +30,7 @@ import ctypes
import warnings
from numbers import Number
-import numpy as _numpy
+import numpy as _numpy # pylint: disable=relative-import
from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
@@ -61,6 +61,17 @@ class Symbol(SymbolBase):
# Make numpy functions return Symbol instead of numpy object array
__array_priority__ = 1000.0
+ def as_np_ndarray(self):
+ """Convert mxnet.symbol.Symbol to _NumpySymbol."""
+ from .numpy import _NumpySymbol
+ hdl = SymbolHandle()
+ check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
+ return _NumpySymbol(hdl)
+
+ def _is_np_compat(self):
+ """Always returns False except for mxnet.symbol.numpy._NumpySymbol."""
+ return False
+
def __repr__(self):
"""Gets a string representation of the symbol."""
name = self.name
@@ -99,6 +110,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_add` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__add__(self)
return _internal._Plus(self, other)
if isinstance(other, Number):
return _internal._PlusScalar(self, scalar=other)
@@ -114,6 +127,8 @@ class Symbol(SymbolBase):
raise NotImplementedForSymbol(self.__iadd__, '+=', other, 1)
def __radd__(self, other):
+ if isinstance(other, Symbol) and other._is_np_compat():  # defer to the np-compat operand
+ return other.__add__(self)
return self.__add__(other)
def __sub__(self, other):
@@ -122,6 +137,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_sub` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__rsub__(self)
return _internal._Minus(self, other)
if isinstance(other, Number):
return _internal._MinusScalar(self, scalar=other)
@@ -144,6 +161,8 @@ class Symbol(SymbolBase):
array([[-2., -2., -2.],
[-2., -2., -2.]], dtype=float32)
"""
+ if isinstance(other, Symbol) and other._is_np_compat():
+ return other.__sub__(self)
if isinstance(other, Number):
return _internal._RMinusScalar(self, scalar=other)
else:
@@ -155,6 +174,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_mul` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__mul__(self)
return _internal._Mul(self, other)
if isinstance(other, Number):
return _internal._MulScalar(self, scalar=other)
@@ -165,6 +186,8 @@ class Symbol(SymbolBase):
raise NotImplementedForSymbol(self.__imul__, '*=', other)
def __rmul__(self, other):
+ if isinstance(other, Symbol) and other._is_np_compat():  # defer to the np-compat operand
+ return other.__mul__(self)
return self.__mul__(other)
def __div__(self, other):
@@ -173,6 +196,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_div` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__rtruediv__(self)
return _internal._Div(self, other)
if isinstance(other, Number):
return _internal._DivScalar(self, scalar=other)
@@ -192,6 +217,8 @@ class Symbol(SymbolBase):
array([[ 0.33333334, 0.33333334, 0.33333334],
[ 0.33333334, 0.33333334, 0.33333334]], dtype=float32)
"""
+ if isinstance(other, Symbol) and other._is_np_compat():
+ return other.__truediv__(self)
if isinstance(other, Number):
return _internal._RDivScalar(self, scalar=other)
else:
@@ -203,6 +230,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_mod` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__rmod__(self)
return _internal._Mod(self, other)
if isinstance(other, Number):
return _internal._ModScalar(self, scalar=other)
@@ -222,6 +251,8 @@ class Symbol(SymbolBase):
array([[ 1., 1., 1.,
[ 1., 1., 1., dtype=float32)
"""
+ if isinstance(other, Symbol) and other._is_np_compat():
+ return other.__mod__(self)
if isinstance(other, Number):
return _internal._RModScalar(self, scalar=other)
else:
@@ -245,6 +276,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_pow` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__rpow__(self)
return _internal._Power(self, other)
if isinstance(other, Number):
return _internal._PowerScalar(self, scalar=other)
@@ -252,7 +285,15 @@ class Symbol(SymbolBase):
raise TypeError('type %s not supported' % str(type(other)))
def __rpow__(self, other):
- raise NotImplementedForSymbol(self.__rpow__, 'y**x', other)
+ """x.__rpow__(y) <=> y ** x"""
+ if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__pow__(self)
+ return other.__pow__(self)
+ elif isinstance(other, Number):
+ return _internal._rpower_scalar(self, scalar=other)
+ else:
+ raise TypeError('type %s not supported' % str(type(other)))
def __neg__(self):
"""x.__neg__() <=> -x
@@ -307,6 +348,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_equal` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__eq__(self)
return _internal._equal(self, other)
if isinstance(other, numeric_types):
return _internal._equal_scalar(self, scalar=other)
@@ -319,6 +362,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_not_equal` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__ne__(self)
return _internal._not_equal(self, other)
if isinstance(other, numeric_types):
return _internal._not_equal_scalar(self, scalar=other)
@@ -331,6 +376,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_greater` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__lt__(self)
return _internal._greater(self, other)
if isinstance(other, numeric_types):
return _internal._greater_scalar(self, scalar=other)
@@ -343,6 +390,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_greater_equal` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__le__(self)
return _internal._greater_equal(self, other)
if isinstance(other, numeric_types):
return _internal._greater_equal_scalar(self, scalar=other)
@@ -355,6 +404,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_lesser` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__gt__(self)
return _internal._lesser(self, other)
if isinstance(other, numeric_types):
return _internal._lesser_scalar(self, scalar=other)
@@ -367,6 +418,8 @@ class Symbol(SymbolBase):
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_lesser_equal` instead. """
if isinstance(other, Symbol):
+ if other._is_np_compat():
+ return other.__ge__(self)
return _internal._lesser_equal(self, other)
if isinstance(other, numeric_types):
return _internal._lesser_equal_scalar(self, scalar=other)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index aa46a96..4612ab4 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -89,7 +89,8 @@ def get_etol(etol=None):
def random_arrays(*shapes):
"""Generate some random numpy arrays."""
- arrays = [np.random.randn(*s).astype(default_dtype())
+ arrays = [np.array(np.random.randn(), dtype=default_dtype())
+ if len(s) == 0 else np.random.randn(*s).astype(default_dtype())
for s in shapes]
if len(arrays) == 1:
return arrays[0]
@@ -408,16 +409,20 @@ def create_sparse_array_zd(shape, stype, density, data_init=None,
density=density,
shuffle_csr_indices=shuffle_csr_indices)
-def rand_shape_2d(dim0=10, dim1=10):
- return rnd.randint(1, dim0 + 1), rnd.randint(1, dim1 + 1)
+def rand_shape_2d(dim0=10, dim1=10, allow_zero_size=False):
+ low = 0 if allow_zero_size else 1
+ return rnd.randint(low, dim0 + 1), rnd.randint(low, dim1 + 1)
-def rand_shape_3d(dim0=10, dim1=10, dim2=10):
- return rnd.randint(1, dim0 + 1), rnd.randint(1, dim1 + 1), rnd.randint(1, dim2 + 1)
+def rand_shape_3d(dim0=10, dim1=10, dim2=10, allow_zero_size=False):
+ low = 0 if allow_zero_size else 1
+ return rnd.randint(low, dim0 + 1), rnd.randint(low, dim1 + 1), rnd.randint(low, dim2 + 1)
-def rand_shape_nd(num_dim, dim=10):
- return tuple(rnd.randint(1, dim+1, size=num_dim))
+
+def rand_shape_nd(num_dim, dim=10, allow_zero_size=False):
+ low = 0 if allow_zero_size else 1
+ return tuple(rnd.randint(low, dim+1, size=num_dim))
def rand_coord_2d(x_low, x_high, y_low, y_high):
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 35bd3ee..35362c2 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1580,3 +1580,12 @@ int MXStorageEmptyCache(int dev_type, int dev_id) {
Storage::Get()->ReleaseAll(ctx);
API_END();
}
+
+int MXShallowCopyNDArray(NDArrayHandle src_handle, NDArrayHandle* out) {
+ NDArray* copy = nullptr;  // handed to the caller via *out; deleted here only on error
+ API_BEGIN();
+ const NDArray& src = *static_cast<NDArray*>(src_handle);
+ copy = new NDArray(src);  // copies the NDArray object (this API calls it a shallow copy)
+ *out = copy;
+ API_END_HANDLE_ERROR(delete copy);
+}
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 013ecab..118341d 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -31,6 +31,7 @@
#include <mxnet/c_api.h>
#include <mxnet/c_api_error.h>
#include <mxnet/base.h>
+#include <mxnet/op_attr_types.h>
#include <nnvm/graph.h>
#include <vector>
#include <string>
@@ -162,4 +163,10 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx,
extern const std::vector<std::string> kHiddenKeys;
} // namespace mxnet
+inline bool IsNumpyCompatOp(const nnvm::Op* op) {  // TIsNumpyCompatible attr, default false
+ static const nnvm::OpMap<mxnet::TIsNumpyCompatible>& np_compat_map =
+ nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible");
+ return np_compat_map.get(op, false);
+}
+
#endif // MXNET_C_API_C_API_COMMON_H_
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index c9c6000..f65c804 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -378,3 +378,19 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
*out = reinterpret_cast<SymbolHandle>(sym);
API_END();
}
+
+// Sets *is_from_np_op to 1 if forward output `output_idx` of the cached op is produced by
+// an op registered with TIsNumpyCompatible = true, otherwise 0.
+int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
+ int output_idx,
+ int* is_from_np_op) {
+ API_BEGIN();
+ const CachedOpPtr& op = *static_cast<CachedOpPtr*>(handle);
+ const auto& output_entries = op->GetForwardSym().outputs;
+ CHECK_GE(output_idx, 0) << "output_idx must be non-negative";
+ CHECK_LT(output_idx, static_cast<int>(output_entries.size()));
+ const nnvm::NodePtr& node_ptr = output_entries[output_idx].node;
+ // Variable nodes carry no op, so they never count as numpy-compatible outputs.
+ *is_from_np_op = (!node_ptr->is_variable() && IsNumpyCompatOp(node_ptr->op())) ? 1 : 0;
+ API_END();
+}
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 80ae543..b63380d 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1059,11 +1059,20 @@ int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_h
API_BEGIN();
nnvm::Symbol *source = static_cast<nnvm::Symbol *>(sym_handle);
CHECK_EQ(source->outputs.size(), 1U)
- << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
- const auto& node = source->outputs[0];
+ << "Generating atomic symbol from other symbol only works for nongrouped symbol.";
+ const auto &node = source->outputs[0];
const auto *op = node.node->op();
const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow);
*s = nnvm::Symbol::CreateFunctor(op, attrs);
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
+
+int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) {
+ nnvm::Symbol* out_sym = nullptr;  // allocated inside API_BEGIN so failures are reported, not leaked
+ API_BEGIN();
+ const nnvm::Symbol* src_sym = static_cast<nnvm::Symbol*>(src);
+ out_sym = new nnvm::Symbol(*src_sym);  // copies the nnvm::Symbol object
+ *out = out_sym;
+ API_END_HANDLE_ERROR(delete out_sym);
+}
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index b867162..106a9e0 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -856,7 +856,6 @@ inline std::multimap<size_t, NDArray> AllocateMemory(
}
CHECK_EQ(stypes[i], kDefaultStorage);
if (mem_plan[i].root == i) {
- CHECK_GT(mem_plan[i].size, 0);
auto iter = pool.lower_bound(mem_plan[i].size);
if (iter != pool.end()) {
*arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index bee8bef..f883a35 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1205,7 +1205,10 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
<< "from.shape = " << from.shape() << " to.shape=" << to.shape();
CHECK(!mxnet::op::shape_is_none(from.shape()))
<< "source operands have undefined shape";
- if (from.shape().Size() == 0U) return;
+ // zero-size array, no need to copy
+ if (from.shape().Size() == 0U) {
+ return;
+ }
// important: callback must always capture by value
const Context from_ctx = from.ctx();
const int a = from_ctx.dev_mask();
@@ -1865,6 +1868,10 @@ void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
mxnet::TShape dshape = this->shape();
CHECK_EQ(dshape.Size(), size)
<< "Memory size do not match";
+ // zero-size array, no need to copy
+ if (size == 0U) {
+ return;
+ }
TBlob src((void*)data, dshape, cpu::kDevMask, this->dtype_, 0); // NOLINT(*)
if (this->ctx().dev_mask() == cpu::kDevMask) {
@@ -1996,6 +2003,10 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
mxnet::TShape dshape = this->shape();
CHECK_EQ(dshape.Size(), size)
<< "Memory size do not match";
+ // zero-size array, no need to copy
+ if (size == 0U) {
+ return;
+ }
TBlob dst(data, dshape, cpu::kDevMask, this->dtype_, 0); // NOLINT(*)
if (this->ctx().dev_mask() == cpu::kDevMask) {
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 6c81bf6..13b575a 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -65,7 +65,8 @@ NNVM_REGISTER_OP(_numpy_sum)
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"})
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
NNVM_REGISTER_OP(_backward_numpy_sum)
.set_num_outputs(1)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
new file mode 100644
index 0000000..e8988c8
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_elemwise_broadcast_op.cc
+ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator.
+ */
+
+#include "../tensor/elemwise_binary_broadcast_op.h"
+#include "../tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+
+// FInferType used by MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR: output dtype = float input dtype.
+bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ const int in_dtype = (*in_attrs)[0];
+ if (in_dtype == -1) return false;  // -1: input dtype still unknown
+ const bool is_float = in_dtype == mshadow::kFloat16 || in_dtype == mshadow::kFloat32 ||
+ in_dtype == mshadow::kFloat64;
+ CHECK(is_float) << "numpy binary scalar op currently only supports float dtype";
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+ return true;
+}
+
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) /* 1 tensor input + "scalar" attr */ \
+ NNVM_REGISTER_OP(name) \
+ .set_num_inputs(1) \
+ .set_num_outputs(1) \
+ .set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true) \
+ .set_attr_parser([](NodeAttrs* node_attrs) { \
+ node_attrs->parsed = std::stod(node_attrs->dict["scalar"]); \
+ }) \
+ .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
+ .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
+ .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
+ [](const NodeAttrs&) { \
+ return std::vector<std::pair<int, int>>{{0, 0}}; \
+ }) \
+ .add_argument("data", "NDArray-or-Symbol", "source input") \
+ .add_argument("scalar", "float", "scalar input")
+
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_add)  // numpy-compatible broadcasting add
+.describe(R"code(Add arguments element-wise with broadcasting if necessary.
+
+Example::
+
+ x = [[ 1., 1., 1.],
+ [ 1., 1., 1.]]
+
+ y = [[ 0.],
+ [ 1.]]
+
+ add(x, y) = [[ 1., 1., 1.],
+ [ 2., 2., 2.]]
+
+)code" ADD_FILELINE)
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::plus>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_subtract)  // numpy-compatible broadcasting subtract
+.describe(R"code(Subtract arguments element-wise with broadcasting if necessary.
+
+Example::
+
+ x = [[ 1., 1., 1.],
+ [ 1., 1., 1.]]
+
+ y = [[ 0.],
+ [ 1.]]
+
+ subtract(x, y) = [[ 1., 1., 1.],
+ [ 0., 0., 0.]]
+
+)code" ADD_FILELINE)
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::minus>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_multiply)  // numpy-compatible broadcasting multiply
+.describe(R"code(Multiply arguments with broadcasting if necessary.
+
+Example::
+
+ x = [[ 1., 1., 1.],
+ [ 1., 1., 1.]]
+
+ y = [[ 0.],
+ [ 1.]]
+
+ multiply(x, y) = [[ 0., 0., 0.],
+ [ 1., 1., 1.]]
+
+)code" ADD_FILELINE)
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mul>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_mod)  // numpy-compatible broadcasting mod
+.describe(R"code(Return element-wise remainder of division.
+Equivalent to the Python modulus operator ``x1 % x2``; the result has the sign of the divisor x2.
+
+Example::
+
+ x = [[ 8., 8., 8.],
+ [ 8., 8., 8.]]
+
+ y = [[ 2.],
+ [ 3.]]
+
+ mod(x, y) = [[ 0., 0., 0.],
+ [ 2., 2., 2.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"})
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_power)  // numpy-compatible broadcasting power
+.describe(R"code(First array elements raised to powers from second array, element-wise.
+
+Raise each base in x1 to the positionally-corresponding power in x2. x1 and x2 must be
+broadcastable to the same shape.
+
+Example::
+
+ x = [[ 2., 2., 2.],
+ [ 2., 2., 2.]]
+
+ y = [[ 1.],
+ [ 2.]]
+
+ power(x, y) = [[ 2., 2., 2.],
+ [ 4., 4., 4.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"})
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_add_scalar)  // x + scalar
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::plus>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_subtract_scalar)  // x - scalar
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::minus>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rsubtract_scalar)  // reversed operands (rminus)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"negative"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rminus>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_multiply_scalar)  // x * scalar
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_mul_scalar"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::mul>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_mod_scalar)  // x % scalar
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mod_scalar"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::mod>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rmod_scalar)  // reversed operands (rmod)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rmod_scalar"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rmod>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_power_scalar)  // x ** scalar
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_power_scalar"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::power>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rpower_scalar)  // reversed operands (rpower)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"})
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
new file mode 100644
index 0000000..186bd1b
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_elemwise_broadcast_op.cu
+ * \brief GPU Implementation of basic functions for elementwise binary broadcast operator.
+ */
+#include "../tensor/elemwise_binary_broadcast_op.h"
+#include "../tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+NNVM_REGISTER_OP(_np_add)  // broadcasting a + b, GPU kernel
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::plus>);
+
+NNVM_REGISTER_OP(_np_subtract)  // broadcasting a - b, GPU kernel
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::minus>);
+
+NNVM_REGISTER_OP(_np_multiply)  // broadcasting a * b, GPU kernel
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mul>);
+
+NNVM_REGISTER_OP(_np_mod)  // broadcasting a % b, GPU kernel
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
+
+NNVM_REGISTER_OP(_np_power)  // broadcasting a ** b, GPU kernel
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);
+
+NNVM_REGISTER_OP(_np_add_scalar)  // x + scalar
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::plus>);
+
+NNVM_REGISTER_OP(_np_subtract_scalar)  // x - scalar
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::minus>);
+
+NNVM_REGISTER_OP(_np_rsubtract_scalar)  // scalar - x
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rminus>);
+
+NNVM_REGISTER_OP(_np_multiply_scalar)  // x * scalar
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::ComputeEx<gpu, mshadow_op::mul>);  // NOTE(review): CPU registration sets no FComputeEx<cpu> - confirm
+
+NNVM_REGISTER_OP(_np_mod_scalar)  // x % scalar
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mod>);
+
+NNVM_REGISTER_OP(_np_rmod_scalar)  // scalar % x
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rmod>);
+
+NNVM_REGISTER_OP(_np_power_scalar)  // x ** scalar
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::power>);
+
+NNVM_REGISTER_OP(_np_rpower_scalar)  // scalar ** x
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc
new file mode 100644
index 0000000..0abd010
--- /dev/null
+++ b/src/operator/numpy/np_init_op.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_init_op.cc
+ * \brief CPU Implementation of numpy init op
+ */
+#include "../tensor/init_op.h"
+#include "../tensor/elemwise_unary_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_np_zeros)  // zero-input creation op; shape/dtype/context parsed into InitOpParam
+.describe("Return a new array of given shape, type, and context, filled with zeros.")
+.set_num_inputs(0)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InitOpParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", InitShape<InitOpParam>)
+.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>)
+.set_attr<FInferStorageType>("FInferStorageType", InitStorageType<InitOpParam, true, true>)  // NOTE(review): flags presumably admit sparse outputs, yet no FComputeEx is set - confirm
+.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)  // dense fill with 0
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.add_arguments(InitOpParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_np_ones)  // same as _np_zeros but fills with 1; no FInferStorageType is registered
+.describe("Return a new array of given shape, type, and context, filled with ones.")
+.set_num_inputs(0)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<InitOpParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", InitShape<InitOpParam>)
+.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>)
+.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 1>)  // dense fill with 1
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.add_arguments(InitOpParam::__FIELDS__());
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu
new file mode 100644
index 0000000..4e6f81d
--- /dev/null
+++ b/src/operator/numpy/np_init_op.cu
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_init_op.cu
+ * \brief GPU Implementation of numpy init op
+ */
+
+#include "../tensor/init_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_np_zeros)  // GPU kernel only; the other attributes are registered in np_init_op.cc
+.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);  // fill the output with 0
+
+NNVM_REGISTER_OP(_np_ones)  // GPU kernel only; the other attributes are registered in np_init_op.cc
+.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);  // fill the output with 1
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc
new file mode 100644
index 0000000..3bafa26
--- /dev/null
+++ b/src/operator/numpy/np_true_divide.cc
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_true_divide.cc
+ * \brief CPU Implementation of true_divide operator.
+ */
+#include "../tensor/elemwise_binary_broadcast_op.h"
+#include "../tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+// FInferType shared by _true_divide (num_inputs == 2) and the scalar variants (num_inputs == 1).
+template <int num_inputs>
+bool TrueDivideType(const nnvm::NodeAttrs& attrs,
+                    std::vector<int>* in_attrs,
+                    std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(num_inputs));
+  CHECK_EQ(out_attrs->size(), 1U);
+  for (const int dtype : *in_attrs) {
+    if (dtype == -1) return false;  // not inferable yet: wait until every input dtype is known
+  }
+  if (num_inputs == 2) {
+    const int lhs_dtype = in_attrs->at(0);
+    const int rhs_dtype = in_attrs->at(1);
+    CHECK_EQ(lhs_dtype, rhs_dtype)
+      << "_true_divide currently only supports same dtype for dividend and divisor";
+  }
+  auto is_float = [](const int dtype) {
+    return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
+  };
+
+  for (const int dtype : *in_attrs) {
+    CHECK(is_float(dtype)) << "_true_divide currently only supports float dtype";
+  }
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));  // output dtype = the (float) input dtype
+  return true;
+}
+
+NNVM_REGISTER_OP(_true_divide)
+.describe(R"code(
+Returns a true division of the inputs, element-wise.
+
+It currently only supports dtype float16, float32, and float64.
+
+Example::
+
+ x = [[ 6., 6., 6.],
+ [ 6., 6., 6.]]
+
+ y = [[ 2.],
+ [ 3.]]
+
+ _true_divide(x, y) = [[ 3., 3., 3.],
+ [ 2., 2., 2.]]
+
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"lhs", "rhs"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
+.set_attr<nnvm::FInferType>("FInferType", TrueDivideType<2>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"})  // both partials need lhs and rhs
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
+.add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
+
+NNVM_REGISTER_OP(_true_divide_scalar)  // data / scalar, float dtypes only (TrueDivideType<1>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser([](NodeAttrs* attrs) {
+    attrs->parsed = std::stod(attrs->dict["scalar"]);
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", TrueDivideType<1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::div>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"})  // d(x/s)/dx = 1/s: ograd only
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.add_argument("data", "NDArray-or-Symbol", "source input")
+.add_argument("scalar", "float", "scalar input");
+
+NNVM_REGISTER_OP(_rtrue_divide_scalar)  // scalar / data, float dtypes only (TrueDivideType<1>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser([](NodeAttrs* attrs) {
+    attrs->parsed = std::stod(attrs->dict["scalar"]);
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", TrueDivideType<1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rdiv>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rdiv_scalar"})  // d(s/x)/dx = -s/x^2 needs x
+.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
+.add_argument("data", "NDArray-or-Symbol", "source input")
+.add_argument("scalar", "float", "scalar input");
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu
new file mode 100644
index 0000000..cbc7cf9
--- /dev/null
+++ b/src/operator/numpy/np_true_divide.cu
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_true_divide.cu
+ * \brief GPU Implementation of true_divide operator.
+ */
+#include "../tensor/elemwise_binary_broadcast_op.h"
+#include "../tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_true_divide)  // GPU kernel only; the other attributes are registered in np_true_divide.cc
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::div>);
+
+NNVM_REGISTER_OP(_true_divide_scalar)  // GPU kernel for data / scalar
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::div>);
+
+NNVM_REGISTER_OP(_rtrue_divide_scalar)  // GPU kernel for scalar / data
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rdiv>);
+
+} // namespace op
+} // namespace mxnet
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 105b5aa..e95b677 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -36,6 +36,7 @@ from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_sa
from common import run_in_spawned_process
from test_operator import *
from test_numpy_op import *
+from test_numpy_ndarray import *
from test_optimizer import *
from test_random import *
from test_exc_handling import *
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
new file mode 100644
index 0000000..88e56ac
--- /dev/null
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -0,0 +1,358 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import absolute_import
+from __future__ import division
+import numpy as _np
+import mxnet as mx
+from mxnet import numpy as np
+from mxnet.gluon import HybridBlock
+from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
+from common import with_seed
+import random
+
+
+@with_seed()
+def test_array_creation():
+ dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
+ objects = [[], (), [[1, 2], [3, 4]],
+ _np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)),
+ mx.nd.array(_np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)))]
+ for dtype in dtypes:
+ for src in objects:
+ mx_arr = np.array(src, dtype=dtype)
+ assert mx_arr.context == mx.current_context()
+ if isinstance(src, mx.nd.NDArray):
+ np_arr = _np.array(src.asnumpy(), dtype=dtype)
+ else:
+ np_arr = _np.array(src, dtype=dtype)
+ assert same(mx_arr.asnumpy(), np_arr)
+ assert mx_arr.dtype == np_arr.dtype
+
+
+@with_seed()
+@mx.use_np_compat
+def test_zeros():
+ # test np.zeros in Gluon
+ class TestZeros(HybridBlock):
+ def __init__(self, shape, dtype=None):
+ super(TestZeros, self).__init__()
+ self._shape = shape
+ self._dtype = dtype
+
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ return x + F.np.zeros(shape, dtype)
+
+ class TestZerosOutputType(HybridBlock):
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ return x, F.np.zeros(shape=())
+
+ # test np.zeros in imperative
+ def check_zero_array_creation(shape, dtype):
+ np_out = _np.zeros(shape=shape, dtype=dtype)
+ mx_out = np.zeros(shape=shape, dtype=dtype)
+ assert same(mx_out.asnumpy(), np_out)
+ if dtype is None:
+ assert mx_out.dtype == _np.float32
+ assert np_out.dtype == _np.float64
+
+ shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()]
+ shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]
+ dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
+ for shape in shapes:
+ for dtype in dtypes:
+ check_zero_array_creation(shape, dtype)
+ x = mx.nd.array(_np.random.uniform(size=shape), dtype=dtype)
+ if dtype is None:
+ x = x.astype('float32')
+ for hybridize in [True, False]:
+ test_zeros = TestZeros(shape, dtype)
+ test_zeros_output_type = TestZerosOutputType()
+ if hybridize:
+ test_zeros.hybridize()
+ test_zeros_output_type.hybridize()
+ y = test_zeros(x)
+ assert type(y) == np.ndarray
+ assert same(x.asnumpy(), y.asnumpy())
+ y = test_zeros_output_type(x)
+ assert type(y[1]) == np.ndarray
+
+
+@with_seed()
+@mx.use_np_compat
+def test_ones():
+ # test np.ones in Gluon
+ class TestOnes(HybridBlock):
+ def __init__(self, shape, dtype=None):
+ super(TestOnes, self).__init__()
+ self._shape = shape
+ self._dtype = dtype
+
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ return x * F.np.ones(shape, dtype)
+
+ class TestOnesOutputType(HybridBlock):
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ return x, F.np.ones(shape=())
+
+ # test np.ones in imperative
+ def check_ones_array_creation(shape, dtype):
+ np_out = _np.ones(shape=shape, dtype=dtype)
+ mx_out = np.ones(shape=shape, dtype=dtype)
+ assert same(mx_out.asnumpy(), np_out)
+ if dtype is None:
+ assert mx_out.dtype == _np.float32
+ assert np_out.dtype == _np.float64
+
+ shapes = [(0,), (2, 0, 2), (0, 0, 0, 0), ()]
+ shapes += [rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]
+ dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
+ for shape in shapes:
+ for dtype in dtypes:
+ check_ones_array_creation(shape, dtype)
+ x = mx.nd.array(_np.random.uniform(size=shape), dtype=dtype).as_np_ndarray()
+ if dtype is None:
+ x = x.astype('float32')
+ for hybridize in [True, False]:
+ test_ones = TestOnes(shape, dtype)
+ test_ones_output_type = TestOnesOutputType()
+ if hybridize:
+ test_ones.hybridize()
+ test_ones_output_type.hybridize()
+ y = test_ones(x)
+ assert type(y) == np.ndarray
+ assert same(x.asnumpy(), y.asnumpy())
+ y = test_ones_output_type(x)
+ assert type(y[1]) == np.ndarray
+
+
+@with_seed()
+@mx.use_np_compat
+def test_ndarray_binary_element_wise_ops():
+    # Each operator is checked against official numpy. Comparison operators (>, >=, <, <=) are not tested because boolean arrays are not supported yet.
+    np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/': _np.divide,
+                 'mod': _np.mod, 'pow': _np.power,
+                 # '>': _np.greater, '>=': _np.greater_equal,
+                 # '<': _np.less, '<=': _np.less_equal
+                 }
+
+    def get_np_ret(x1, x2, op):  # reference result computed by official numpy
+        return np_op_map[op](x1, x2)
+
+    class TestBinaryElementWiseOp(HybridBlock):  # applies self._op to x and either args[0] or self._scalar
+        def __init__(self, op, scalar=None, reverse=False):
+            super(TestBinaryElementWiseOp, self).__init__()
+            self._op = op
+            self._scalar = scalar
+            self._reverse = reverse  # if false, scalar is the right operand.
+
+        def hybrid_forward(self, F, x, *args):  # args[0] is the second array operand when no scalar is set
+            if self._op == '+':
+                if self._scalar is not None:
+                    return x + self._scalar if not self._reverse else self._scalar + x
+                else:
+                    return x + args[0] if not self._reverse else args[0] + x
+            elif self._op == '*':
+                if self._scalar is not None:
+                    return x * self._scalar if not self._reverse else self._scalar * x
+                else:
+                    return x * args[0] if not self._reverse else args[0] * x
+            elif self._op == '-':
+                if self._scalar is not None:
+                    return x - self._scalar if not self._reverse else self._scalar - x
+                else:
+                    return x - args[0] if not self._reverse else args[0] - x
+            elif self._op == '/':
+                if self._scalar is not None:
+                    return x / self._scalar if not self._reverse else self._scalar / x
+                else:
+                    return x / args[0] if not self._reverse else args[0] / x
+            elif self._op == 'mod':
+                if self._scalar is not None:
+                    return x % self._scalar if not self._reverse else self._scalar % x
+                else:
+                    return x % args[0] if not self._reverse else args[0] % x
+            elif self._op == 'pow':
+                if self._scalar is not None:
+                    return x ** self._scalar if not self._reverse else self._scalar ** x
+                else:
+                    return x ** args[0] if not self._reverse else args[0] ** x
+            elif self._op == '>':  # comparison branches are unreachable while commented out of np_op_map
+                if self._scalar is not None:
+                    return x > self._scalar
+                else:
+                    return x > args[0]
+            elif self._op == '>=':
+                if self._scalar is not None:
+                    return x >= self._scalar
+                else:
+                    return x >= args[0]
+            elif self._op == '<':
+                if self._scalar is not None:
+                    return x < self._scalar
+                else:
+                    return x < args[0]
+            elif self._op == '<=':
+                if self._scalar is not None:
+                    return x <= self._scalar
+                else:
+                    return x <= args[0]
+            else:
+                print(self._op)
+                assert False
+
+    def check_binary_op_result(shape1, shape2, op, dtype=None):  # a None shape makes that operand a Python-level scalar
+        if shape1 is None:
+            mx_input1 = abs(_np.random.uniform()) + 1  # +1 keeps every operand >= 1 (safe divisor, mod and pow base)
+            np_input1 = mx_input1
+        else:
+            mx_input1 = rand_ndarray(shape1, dtype=dtype).abs() + 1
+            np_input1 = mx_input1.asnumpy()
+        if shape2 is None:
+            mx_input2 = abs(_np.random.uniform()) + 1
+            np_input2 = mx_input2
+        else:
+            mx_input2 = rand_ndarray(shape2, dtype=dtype).abs() + 1
+            np_input2 = mx_input2.asnumpy()
+
+        scalar = None
+        reverse = False
+        if isinstance(mx_input1, mx.nd.NDArray) and not isinstance(mx_input2, mx.nd.NDArray):  # array op scalar
+            scalar = mx_input2
+            reverse = False
+        elif isinstance(mx_input2, mx.nd.NDArray) and not isinstance(mx_input1, mx.nd.NDArray):  # scalar op array
+            scalar = mx_input1
+            reverse = True
+
+        np_out = get_np_ret(np_input1, np_input2, op)
+        for hybridize in [True, False]:  # a new block is built per pass, so each pass is independent
+            if scalar is None:
+                get_mx_ret = TestBinaryElementWiseOp(op)
+                if hybridize:
+                    get_mx_ret.hybridize()
+                mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray())
+                assert type(mx_out) == np.ndarray
+                assert np_out.shape == mx_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+
+                mx_out = get_mx_ret(mx_input1, mx_input2.as_np_ndarray())  # mx.nd lhs mixed with np rhs must still yield np.ndarray
+                assert type(mx_out) == np.ndarray
+                assert np_out.shape == mx_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+
+                mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2)  # np lhs mixed with mx.nd rhs
+                assert type(mx_out) == np.ndarray
+                assert np_out.shape == mx_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+            else:
+                get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse)
+                if hybridize:
+                    get_mx_ret.hybridize()
+                if reverse:
+                    mx_out = get_mx_ret(mx_input2.as_np_ndarray())
+                    assert type(mx_out) == np.ndarray
+                else:
+                    mx_out = get_mx_ret(mx_input1.as_np_ndarray())
+                    assert type(mx_out) == np.ndarray
+                assert np_out.shape == mx_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+
+    dtypes = [_np.float32, _np.float64, None]
+    ops = np_op_map.keys()
+    for dtype in dtypes:
+        for op in ops:  # cover same-shape, broadcasting, scalar, 0-d and zero-size operands
+            check_binary_op_result((3, 4), (3, 4), op, dtype)
+            check_binary_op_result(None, (3, 4), op, dtype)
+            check_binary_op_result((3, 4), None, op, dtype)
+            check_binary_op_result((1, 4), (3, 1), op, dtype)
+            check_binary_op_result(None, (3, 1), op, dtype)
+            check_binary_op_result((1, 4), None, op, dtype)
+            check_binary_op_result((1, 4), (3, 5, 4), op, dtype)
+            check_binary_op_result((), (3, 5, 4), op, dtype)
+            check_binary_op_result((), None, op, dtype)
+            check_binary_op_result(None, (), op, dtype)
+            check_binary_op_result((0, 2), (1, 1), op, dtype)
+            check_binary_op_result((0, 2), None, op, dtype)
+            check_binary_op_result(None, (0, 2), op, dtype)
+
+
+@with_seed()
+def test_np_op_output_type():
+ # test imperative invoke
+ data = np.array([1., 3.], dtype='float32')
+ ret = np.sum(data)
+ assert type(ret) == np.ndarray
+ ret = mx.nd.sin(data)
+ assert type(ret) == mx.nd.NDArray
+
+ # test cached op
+ class TestCachedOpOutputType(HybridBlock):
+ @mx.use_np_compat
+ def hybrid_forward(self, F, x, *args, **kwargs):
+ ret1 = F.sin(x)
+ ret2 = F.np.sum(x)
+ return ret1, ret2
+
+ net = TestCachedOpOutputType()
+ for hybridize in [True, False]:
+ if hybridize:
+ net.hybridize()
+ ret1, ret2 = net(data)
+ assert type(ret1) == mx.nd.NDArray
+ assert type(ret2) == np.ndarray
+
+
+@with_seed()
+def test_grad_ndarray_type():
+ data = np.array(2, dtype=_np.float32)
+ data.attach_grad()
+ assert type(data.grad) == np.ndarray
+ assert type(data.detach()) == np.ndarray
+
+
+@with_seed()
+def test_np_ndarray_astype():
+ mx_data = np.array([2, 3, 4, 5], dtype=_np.int32)
+ np_data = mx_data.asnumpy()
+
+ def check_astype_equal(dtype, copy, expect_zero_copy=False):
+ mx_ret = mx_data.astype(dtype=dtype, copy=copy)
+ np_ret = np_data.astype(dtype=dtype, copy=copy)
+ assert mx_ret.dtype == np_ret.dtype
+ assert same(mx_ret.asnumpy(), np_ret)
+ if expect_zero_copy:
+ assert id(mx_ret) == id(mx_data)
+ assert id(np_ret) == id(np_data)
+
+ for dtype in [_np.int8, _np.uint8, _np.int32, _np.float16, _np.float32, _np.float64]:
+ for copy in [True, False]:
+ check_astype_equal(dtype, copy, copy is False and mx_data.dtype == dtype)
+
+
+@with_seed()
+def test_np_ndarray_copy():
+ mx_data = np.array([2, 3, 4, 5], dtype=_np.int32)
+ assert_exception(mx_data.copy, NotImplementedError, order='F')
+ mx_ret = mx_data.copy()
+ np_ret = mx_data.asnumpy().copy()
+ assert same(mx_ret.asnumpy(), np_ret)
+
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()