You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2019/09/24 20:56:06 UTC
[incubator-mxnet] branch numpy_prs updated: Numpy Operators: Inner,
Outer, vdot (#15846)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch numpy_prs
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/numpy_prs by this push:
new f3f6b1d Numpy Operators: Inner, Outer, vdot (#15846)
f3f6b1d is described below
commit f3f6b1da3366ec3dbf537110688d9a7aafd6a8ce
Author: ckt624 <ck...@gmail.com>
AuthorDate: Tue Sep 24 16:55:21 2019 -0400
Numpy Operators: Inner, Outer, vdot (#15846)
* Implements tensordot and dot.
Change tests.
Add spaces.
Reorganize codes.
Implements inner, outer, vdot.
Remove spaces.
Change tests.
change test format.
Change indent.
Change styles
* Fix
---
python/mxnet/ndarray/numpy/_op.py | 163 +++++++++++++++++++++++++--
python/mxnet/numpy/multiarray.py | 149 ++++++++++++++++++++++++-
python/mxnet/symbol/numpy/_symbol.py | 155 +++++++++++++++++++++++++-
tests/python/unittest/test_numpy_op.py | 195 +++++++++++++++++++++++++++++++++
4 files changed, 651 insertions(+), 11 deletions(-)
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 99ef61b..7afe337 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -35,7 +35,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
- 'unique', 'ldexp']
+ 'unique', 'ldexp', 'vdot', 'inner', 'outer']
@set_module('mxnet.ndarray.numpy')
@@ -1702,7 +1702,7 @@ def tan(x, out=None, where=True, **kwargs):
Parameters:
----------
- x : array_like
+ x : ndarray
Input array.
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided,
@@ -2122,7 +2122,7 @@ def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
- a1, a2, ... : sequence of array_like
+ a1, a2, ... : sequence of ndarray
The arrays must have the same shape, except in the dimension
corresponding to `axis` (the first, by default).
axis : int, optional
@@ -2147,7 +2147,7 @@ def stack(arrays, axis=0, out=None):
For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
Parameters
----------
- arrays : sequence of array_like
+ arrays : sequence of ndarray
Each array must have the same shape.
axis : int, optional
The axis in the result array along which the input arrays are stacked.
@@ -2309,7 +2309,7 @@ def clip(a, a_min, a_max, out=None):
Notes
-----
- array_like `a_min` and `a_max` are not supported.
+ ndarray `a_min` and `a_max` are not supported.
Examples
--------
@@ -2468,7 +2468,7 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
Parameters
----------
- a : array_like
+ a : ndarray
Calculate the standard deviation of these values.
axis : None or int or tuple of ints, optional
Axis or axes along which the standard deviation is computed. The
@@ -2535,7 +2535,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
Parameters
----------
- a : array_like
+ a : ndarray
Array containing numbers whose variance is desired. If `a` is not an
array, a conversion is attempted.
axis : None or int or tuple of ints, optional
@@ -3229,7 +3229,7 @@ def hypot(x1, x2, out=None):
Parameters
----------
- x1, x2 : array_like
+ x1, x2 : ndarray
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have
@@ -3301,3 +3301,150 @@ def ldexp(x1, x2, out=None):
array([ 5., 10., 20., 40.])
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def inner(a, b):
+    r"""
+    Inner product of two arrays.
+
+    Ordinary inner product of vectors for 1-D arrays (without complex
+    conjugation), in higher dimensions a sum product over the last axes.
+
+    Parameters
+    ----------
+    a, b : ndarray or scalar
+        If `a` and `b` are nonscalar, their last dimensions must match.
+
+    Returns
+    -------
+    out : ndarray
+        If both inputs are nonscalar, `out.shape = a.shape[:-1] + b.shape[:-1]`.
+        If either input is a scalar (or a 0-d array), `out` is ``a * b``.
+
+    See Also
+    --------
+    tensordot : Sum products over arbitrary axes.
+    dot : Generalised matrix product, using second last dimension of `b`.
+
+    Notes
+    -----
+    For vectors (1-D arrays) it computes the ordinary inner-product::
+
+        np.inner(a, b) = sum(a[:]*b[:])
+
+    More generally, if `ndim(a) = r > 0` and `ndim(b) = s > 0`::
+
+        np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
+
+    or explicitly::
+
+        np.inner(a, b)[i0,...,ir-1,j0,...,js-1]
+             = sum(a[i0,...,ir-1,:]*b[j0,...,js-1,:])
+
+    In addition `a` or `b` may be scalars, in which case::
+
+       np.inner(a,b) = a*b
+
+    If the last dimensions of nonscalar `a` and `b` differ in size, the
+    underlying `tensordot` operator rejects the shapes and an error is raised.
+
+    Examples
+    --------
+    Ordinary inner product for vectors:
+
+    >>> a = np.array([1,2,3])
+    >>> b = np.array([0,1,0])
+    >>> np.inner(a, b)
+    array(2.)
+
+    A multidimensional example:
+
+    >>> a = np.arange(24).reshape((2,3,4))
+    >>> b = np.arange(4)
+    >>> np.inner(a, b)
+    array([[ 14.,  38.,  62.],
+           [ 86., 110., 134.]])
+    """
+    # A 0-d operand has no axis -1 for `tensordot` to contract, so scalar
+    # inputs take the element-wise product path documented above (as NumPy does).
+    if getattr(a, 'ndim', 0) == 0 or getattr(b, 'ndim', 0) == 0:
+        return multiply(a, b)
+    return tensordot(a, b, [-1, -1])
+
+
+@set_module('mxnet.ndarray.numpy')
+def outer(a, b):
+    r"""
+    Compute the outer product of two vectors.
+
+    Given two vectors, ``a = [a0, a1, ..., aM]`` and
+    ``b = [b0, b1, ..., bN]``,
+    the outer product [1]_ is::
+
+      [[a0*b0  a0*b1 ... a0*bN ]
+       [a1*b0    .
+       [ ...          .
+       [aM*b0            aM*bN ]]
+
+    Parameters
+    ----------
+    a : (M,) ndarray
+        First input vector. Input is flattened if
+        not already 1-dimensional.
+    b : (N,) ndarray
+        Second input vector. Input is flattened if
+        not already 1-dimensional.
+
+    Returns
+    -------
+    out : (M, N) ndarray
+        ``out[i, j] = a[i] * b[j]``
+
+    See also
+    --------
+    inner
+    tensordot : ``tensordot(a.flatten(), b.flatten(), 0)`` is the equivalent.
+
+    References
+    ----------
+    .. [1] : G. H. Golub and C. F. Van Loan, *Matrix Computations*, 3rd
+             ed., Baltimore, MD, Johns Hopkins University Press, 1996,
+             pg. 8.
+
+    Examples
+    --------
+    Make a (*very* coarse) grid for computing a Mandelbrot set:
+
+    >>> rl = np.outer(np.ones((5,)), np.linspace(-2, 2, 5))
+    >>> rl
+    array([[-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.]])
+    """
+    # Flatten both inputs so N-d arguments behave as vectors; an integer
+    # `axes=0` contracts nothing, which yields the outer product.
+    return tensordot(a.flatten(), b.flatten(), 0)
+
+
+@set_module('mxnet.ndarray.numpy')
+def vdot(a, b):
+    r"""
+    Return the dot product of two vectors.
+
+    Note that `vdot` handles multidimensional arrays differently than `dot`:
+    it does *not* perform a matrix product, but flattens input arguments
+    to 1-D vectors first. Consequently, it should only be used for vectors.
+    No complex conjugation is applied to `a`.
+
+    Parameters
+    ----------
+    a : ndarray
+        First argument to the dot product.
+    b : ndarray
+        Second argument to the dot product. Must contain the same number
+        of elements as `a`.
+
+    Returns
+    -------
+    output : ndarray
+        Dot product of `a` and `b`, as a 0-d array.
+
+    See Also
+    --------
+    dot : Generalised matrix product; does not flatten its inputs.
+
+    Examples
+    --------
+    Note that higher-dimensional arrays are flattened!
+
+    >>> a = np.array([[1, 4], [5, 6]])
+    >>> b = np.array([[4, 1], [2, 2]])
+    >>> np.vdot(a, b)
+    array(30.)
+    >>> np.vdot(b, a)
+    array(30.)
+    >>> 1*4 + 4*1 + 5*2 + 6*2
+    30
+    """
+    # Both operands are flattened, then contracted over their single axis.
+    return tensordot(a.flatten(), b.flatten(), 1)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 3fd5801..7a68c16 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -54,7 +54,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
- 'rad2deg', 'deg2rad', 'unique', 'ldexp']
+ 'rad2deg', 'deg2rad', 'unique', 'ldexp', 'vdot', 'inner', 'outer']
# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -4830,3 +4830,150 @@ def ldexp(x1, x2, out=None):
array([ 5., 10., 20., 40.])
"""
return _mx_nd_np.ldexp(x1, x2, out)
+
+
+@set_module('mxnet.numpy')
+def inner(a, b):
+    r"""Inner product of two arrays.
+
+    Ordinary inner product of vectors for 1-D arrays (without complex
+    conjugation), in higher dimensions a sum product over the last axes.
+
+    Parameters
+    ----------
+    a, b : ndarray or scalar
+        If `a` and `b` are nonscalar, their last dimensions must match.
+
+    Returns
+    -------
+    out : ndarray
+        If both inputs are nonscalar, `out.shape = a.shape[:-1] + b.shape[:-1]`.
+        If either input is a scalar (or a 0-d array), `out` is ``a * b``.
+
+    See Also
+    --------
+    tensordot : Sum products over arbitrary axes.
+    dot : Generalised matrix product, using second last dimension of `b`.
+
+    Notes
+    -----
+    For vectors (1-D arrays) it computes the ordinary inner-product::
+
+        np.inner(a, b) = sum(a[:]*b[:])
+
+    More generally, if `ndim(a) = r > 0` and `ndim(b) = s > 0`::
+
+        np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
+
+    or explicitly::
+
+        np.inner(a, b)[i0,...,ir-1,j0,...,js-1]
+             = sum(a[i0,...,ir-1,:]*b[j0,...,js-1,:])
+
+    In addition `a` or `b` may be scalars, in which case::
+
+       np.inner(a,b) = a*b
+
+    If the last dimensions of nonscalar `a` and `b` differ in size, the
+    underlying `tensordot` operator rejects the shapes and an error is raised.
+
+    Examples
+    --------
+    Ordinary inner product for vectors:
+
+    >>> a = np.array([1,2,3])
+    >>> b = np.array([0,1,0])
+    >>> np.inner(a, b)
+    array(2.)
+
+    A multidimensional example:
+
+    >>> a = np.arange(24).reshape((2,3,4))
+    >>> b = np.arange(4)
+    >>> np.inner(a, b)
+    array([[ 14.,  38.,  62.],
+           [ 86., 110., 134.]])
+    """
+    # A 0-d operand has no axis -1 for `tensordot` to contract, so scalar
+    # inputs take the element-wise product path documented above (as NumPy does).
+    if getattr(a, 'ndim', 0) == 0 or getattr(b, 'ndim', 0) == 0:
+        return multiply(a, b)
+    return tensordot(a, b, [-1, -1])
+
+
+@set_module('mxnet.numpy')
+def outer(a, b):
+    r"""Compute the outer product of two vectors.
+
+    Given two vectors, ``a = [a0, a1, ..., aM]`` and
+    ``b = [b0, b1, ..., bN]``,
+    the outer product [1]_ is::
+
+      [[a0*b0  a0*b1 ... a0*bN ]
+       [a1*b0    .
+       [ ...          .
+       [aM*b0            aM*bN ]]
+
+    Parameters
+    ----------
+    a : (M,) ndarray
+        First input vector. Input is flattened if
+        not already 1-dimensional.
+    b : (N,) ndarray
+        Second input vector. Input is flattened if
+        not already 1-dimensional.
+
+    Returns
+    -------
+    out : (M, N) ndarray
+        ``out[i, j] = a[i] * b[j]``
+
+    See also
+    --------
+    inner
+    tensordot : ``tensordot(a.flatten(), b.flatten(), 0)`` is the equivalent.
+
+    References
+    ----------
+    .. [1] : G. H. Golub and C. F. Van Loan, *Matrix Computations*, 3rd
+             ed., Baltimore, MD, Johns Hopkins University Press, 1996,
+             pg. 8.
+
+    Examples
+    --------
+    Make a (*very* coarse) grid for computing a Mandelbrot set:
+
+    >>> rl = np.outer(np.ones((5,)), np.linspace(-2, 2, 5))
+    >>> rl
+    array([[-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.]])
+    """
+    # Flatten both inputs so N-d arguments behave as vectors; an integer
+    # `axes=0` contracts nothing, which yields the outer product.
+    return tensordot(a.flatten(), b.flatten(), 0)
+
+
+@set_module('mxnet.numpy')
+def vdot(a, b):
+    r"""
+    Return the dot product of two vectors.
+
+    Note that `vdot` handles multidimensional arrays differently than `dot`:
+    it does *not* perform a matrix product, but flattens input arguments
+    to 1-D vectors first. Consequently, it should only be used for vectors.
+    No complex conjugation is applied to `a`.
+
+    Parameters
+    ----------
+    a : ndarray
+        First argument to the dot product.
+    b : ndarray
+        Second argument to the dot product. Must contain the same number
+        of elements as `a`.
+
+    Returns
+    -------
+    output : ndarray
+        Dot product of `a` and `b`, as a 0-d array.
+
+    See Also
+    --------
+    dot : Generalised matrix product; does not flatten its inputs.
+
+    Examples
+    --------
+    Note that higher-dimensional arrays are flattened!
+
+    >>> a = np.array([[1, 4], [5, 6]])
+    >>> b = np.array([[4, 1], [2, 2]])
+    >>> np.vdot(a, b)
+    array(30.)
+    >>> np.vdot(b, a)
+    array(30.)
+    >>> 1*4 + 4*1 + 5*2 + 6*2
+    30
+    """
+    # Both operands are flattened, then contracted over their single axis.
+    return tensordot(a.flatten(), b.flatten(), 1)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index af1eaed..73a6726 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -37,7 +37,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
- 'unique', 'ldexp']
+ 'unique', 'ldexp', 'vdot', 'inner', 'outer']
def _num_outputs(sym):
@@ -1469,6 +1469,7 @@ def absolute(x, out=None, **kwargs):
r"""
Calculate the absolute value element-wise.
np.abs is a shorthand for this function.
+
Parameters
----------
x : _Symbol
@@ -3397,10 +3398,10 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, ax
@set_module('mxnet.symbol.numpy')
def ldexp(x1, x2, out=None):
"""
- ldexp(x1, x2, out=None)
Returns x1 * 2**x2, element-wise.
The mantissas `x1` and twos exponents `x2` are used to construct
floating point numbers ``x1 * 2**x2``.
+
Parameters
----------
x1 : _Symbol
@@ -3409,10 +3410,12 @@ def ldexp(x1, x2, out=None):
Array of twos exponents.
out : _Symbol or None
Dummy parameter to keep the consistency with the ndarray counterpart.
+
Returns
-------
y : _Symbol
The result of ``x1 * 2**x2``.
+
Notes
-----
Complex dtypes are not supported, they will raise a TypeError.
@@ -3423,4 +3426,152 @@ def ldexp(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
+@set_module('mxnet.symbol.numpy')
+def inner(a, b):
+    r"""Inner product of two arrays.
+
+    Ordinary inner product of vectors for 1-D arrays (without complex
+    conjugation), in higher dimensions a sum product over the last axes.
+
+    Parameters
+    ----------
+    a, b : _Symbol
+        Their last dimensions must match. Scalar (0-d) operands are not
+        supported by this symbolic version, because it always contracts
+        the last axis of both inputs.
+
+    Returns
+    -------
+    out : _Symbol
+        `out.shape = a.shape[:-1] + b.shape[:-1]`
+
+    See Also
+    --------
+    tensordot : Sum products over arbitrary axes.
+    dot : Generalised matrix product, using second last dimension of `b`.
+
+    Notes
+    -----
+    For vectors (1-D arrays) it computes the ordinary inner-product::
+
+        np.inner(a, b) = sum(a[:]*b[:])
+
+    More generally, if `ndim(a) = r > 0` and `ndim(b) = s > 0`::
+
+        np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
+
+    or explicitly::
+
+        np.inner(a, b)[i0,...,ir-1,j0,...,js-1]
+             = sum(a[i0,...,ir-1,:]*b[j0,...,js-1,:])
+
+    If the last dimensions of `a` and `b` differ in size, the underlying
+    `tensordot` operator rejects the shapes and an error is raised.
+
+    Examples
+    --------
+    Ordinary inner product for vectors:
+
+    >>> a = np.array([1,2,3])
+    >>> b = np.array([0,1,0])
+    >>> np.inner(a, b)
+    array(2.)
+
+    A multidimensional example:
+
+    >>> a = np.arange(24).reshape((2,3,4))
+    >>> b = np.arange(4)
+    >>> np.inner(a, b)
+    array([[ 14.,  38.,  62.],
+           [ 86., 110., 134.]])
+    """
+    # Contract the last axis of `a` with the last axis of `b`.
+    return tensordot(a, b, [-1, -1])
+
+
+@set_module('mxnet.symbol.numpy')
+def outer(a, b):
+    r"""Compute the outer product of two vectors.
+
+    Given two vectors, ``a = [a0, a1, ..., aM]`` and
+    ``b = [b0, b1, ..., bN]``,
+    the outer product [1]_ is::
+
+      [[a0*b0  a0*b1 ... a0*bN ]
+       [a1*b0    .
+       [ ...          .
+       [aM*b0            aM*bN ]]
+
+    Parameters
+    ----------
+    a : (M,) _Symbol
+        First input vector. Input is flattened if
+        not already 1-dimensional.
+    b : (N,) _Symbol
+        Second input vector. Input is flattened if
+        not already 1-dimensional.
+
+    Returns
+    -------
+    out : (M, N) _Symbol
+        ``out[i, j] = a[i] * b[j]``
+
+    See also
+    --------
+    inner
+    tensordot : ``tensordot(a.flatten(), b.flatten(), 0)`` is the equivalent.
+
+    References
+    ----------
+    .. [1] : G. H. Golub and C. F. Van Loan, *Matrix Computations*, 3rd
+             ed., Baltimore, MD, Johns Hopkins University Press, 1996,
+             pg. 8.
+
+    Examples
+    --------
+    Make a (*very* coarse) grid for computing a Mandelbrot set:
+
+    >>> rl = np.outer(np.ones((5,)), np.linspace(-2, 2, 5))
+    >>> rl
+    array([[-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.]])
+    """
+    # Flatten both inputs so N-d arguments behave as vectors; an integer
+    # `axes=0` contracts nothing, which yields the outer product.
+    return tensordot(a.flatten(), b.flatten(), 0)
+
+
+@set_module('mxnet.symbol.numpy')
+def vdot(a, b):
+    r"""
+    Return the dot product of two vectors.
+
+    Note that `vdot` handles multidimensional arrays differently than `dot`:
+    it does *not* perform a matrix product, but flattens input arguments
+    to 1-D vectors first. Consequently, it should only be used for vectors.
+    No complex conjugation is applied to `a`.
+
+    Parameters
+    ----------
+    a : _Symbol
+        First argument to the dot product.
+    b : _Symbol
+        Second argument to the dot product. Must contain the same number
+        of elements as `a`.
+
+    Returns
+    -------
+    output : _Symbol
+        Dot product of `a` and `b`, as a 0-d array.
+
+    See Also
+    --------
+    dot : Generalised matrix product; does not flatten its inputs.
+
+    Examples
+    --------
+    Note that higher-dimensional arrays are flattened!
+
+    >>> a = np.array([[1, 4], [5, 6]])
+    >>> b = np.array([[4, 1], [2, 2]])
+    >>> np.vdot(a, b)
+    array(30.)
+    >>> np.vdot(b, a)
+    array(30.)
+    >>> 1*4 + 4*1 + 5*2 + 6*2
+    30
+    """
+    # Both operands are flattened, then contracted over their single axis.
+    return tensordot(a.flatten(), b.flatten(), 1)
+
+
_set_np_symbol_class(_Symbol)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index af3b430..264f7c0 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -281,6 +281,201 @@ def test_np_ldexp():
@with_seed()
@use_np
+def test_np_vdot():
+    class TestVdot(HybridBlock):
+        def __init__(self):
+            super(TestVdot, self).__init__()
+
+        def hybrid_forward(self, F, a, b):
+            return F.np.vdot(a, b)
+
+    def vdot_backward(a, b):
+        # d(sum(a * b)) / da == b and d(sum(a * b)) / db == a
+        return [b, a]
+
+    # scalar, vector and matrix inputs (matrices are flattened by vdot)
+    tensor_shapes = [(), (5,), (3, 3)]
+
+    for hybridize in [True, False]:
+        for shape in tensor_shapes:
+            for dtype in [_np.float32, _np.float64]:
+                block = TestVdot()
+                if hybridize:
+                    block.hybridize()
+                a = rand_ndarray(shape=shape, dtype=dtype).as_np_ndarray()
+                b = rand_ndarray(shape=shape, dtype=dtype).as_np_ndarray()
+                a.attach_grad()
+                b.attach_grad()
+
+                a_host, b_host = a.asnumpy(), b.asnumpy()
+                expected = _np.vdot(a_host, b_host)
+                with mx.autograd.record():
+                    result = block(a, b)
+                assert result.shape == expected.shape
+                assert_almost_equal(result.asnumpy(), expected, rtol=1e-3, atol=1e-5)
+                result.backward()
+                grad_a_ref, grad_b_ref = vdot_backward(a_host, b_host)
+                assert_almost_equal(a.grad.asnumpy(), grad_a_ref, rtol=1e-2, atol=1e-2)
+                assert_almost_equal(b.grad.asnumpy(), grad_b_ref, rtol=1e-2, atol=1e-2)
+
+                # imperative call through mx.np, outside of the HybridBlock
+                imperative = np.vdot(a, b)
+                assert_almost_equal(imperative.asnumpy(), _np.vdot(a_host, b_host),
+                                    rtol=1e-3, atol=1e-5)
+
+                # numeric gradient check; skipped for 0-d and empty inputs
+                if len(shape) == 0 or _np.prod(shape) <= 0:
+                    continue
+                a_sym = mx.sym.Variable("a").as_np_ndarray()
+                b_sym = mx.sym.Variable("b").as_np_ndarray()
+                vdot_sym = mx.sym.np.vdot(a_sym, b_sym).as_nd_ndarray()
+                check_numeric_gradient(vdot_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
+                                       rtol=1e-1, atol=1e-1, dtype=dtype)
+
+
+@with_seed()
+@use_np
+def test_np_inner():
+    class TestInner(HybridBlock):
+        def __init__(self):
+            super(TestInner, self).__init__()
+
+        def hybrid_forward(self, F, a, b):
+            return F.np.inner(a, b)
+
+    def inner_backward(a, b):
+        """Reference gradients of sum(inner(a, b)) w.r.t. `a` and `b`.
+
+        inner(a, b)[I, J] = sum_k a[I, k] * b[J, k], so with an all-ones head
+        gradient the gradient w.r.t. a[I, k] is sum_J b[J, k] (the same for
+        every leading index I of `a`), and symmetrically for `b`.
+        """
+        grad_a = _np.broadcast_to(b.reshape(-1, b.shape[-1]).sum(axis=0), a.shape)
+        grad_b = _np.broadcast_to(a.reshape(-1, a.shape[-1]).sum(axis=0), b.shape)
+        return [grad_a, grad_b]
+
+    # non-zero-size inputs of various ranks; the last dimensions must match
+    tensor_shapes = [
+        ((3,), (3,)),
+        ((2, 3), (3,)),
+        ((3,), (2, 3)),
+        ((2, 3, 4), (5, 4)),
+    ]
+
+    for hybridize in [True, False]:
+        for a_shape, b_shape in tensor_shapes:
+            for dtype in [_np.float32, _np.float64]:
+                test_inner = TestInner()
+                if hybridize:
+                    test_inner.hybridize()
+                a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+                b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
+                a.attach_grad()
+                b.attach_grad()
+
+                np_out = _np.inner(a.asnumpy(), b.asnumpy())
+                with mx.autograd.record():
+                    mx_out = test_inner(a, b)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                mx_out.backward()
+                np_backward = inner_backward(a.asnumpy(), b.asnumpy())
+                assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol=1e-2, atol=1e-2)
+                assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol=1e-2, atol=1e-2)
+
+                # Test imperative once again
+                mx_out = np.inner(a, b)
+                np_out = _np.inner(a.asnumpy(), b.asnumpy())
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                # test numeric gradient
+                a_sym = mx.sym.Variable("a").as_np_ndarray()
+                b_sym = mx.sym.Variable("b").as_np_ndarray()
+                mx_sym = mx.sym.np.inner(a_sym, b_sym).as_nd_ndarray()
+                check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
+                                       rtol=1e-1, atol=1e-1, dtype=dtype)
+
+
+@with_seed()
+@use_np
+def test_np_outer():
+    class TestOuter(HybridBlock):
+        def __init__(self):
+            super(TestOuter, self).__init__()
+
+        def hybrid_forward(self, F, a, b):
+            return F.np.outer(a, b)
+
+    def outer_backward(a, b):
+        """Reference gradients of sum(outer(a, b)) w.r.t. `a` and `b`.
+
+        outer(a, b)[i, j] = a.flat[i] * b.flat[j], so with an all-ones head
+        gradient every element of `a` receives sum(b) and every element of
+        `b` receives sum(a); the gradients keep the (unflattened) input shapes.
+        """
+        return [_np.full_like(a, b.sum()), _np.full_like(b, a.sum())]
+
+    # non-zero-size inputs; N-d inputs are flattened by outer
+    tensor_shapes = [
+        ((3,), (3,)),
+        ((2, 3), (6,)),
+        ((6,), (2, 3)),
+    ]
+
+    for hybridize in [True, False]:
+        for a_shape, b_shape in tensor_shapes:
+            for dtype in [_np.float32, _np.float64]:
+                test_outer = TestOuter()
+                if hybridize:
+                    test_outer.hybridize()
+                a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+                b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
+                a.attach_grad()
+                b.attach_grad()
+
+                np_out = _np.outer(a.asnumpy(), b.asnumpy())
+                with mx.autograd.record():
+                    mx_out = test_outer(a, b)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                mx_out.backward()
+                # the gradients computed by backward() were previously never checked
+                np_backward = outer_backward(a.asnumpy(), b.asnumpy())
+                assert_almost_equal(a.grad.asnumpy(), np_backward[0], rtol=1e-2, atol=1e-2)
+                assert_almost_equal(b.grad.asnumpy(), np_backward[1], rtol=1e-2, atol=1e-2)
+
+                # Test imperative once again
+                mx_out = np.outer(a, b)
+                np_out = _np.outer(a.asnumpy(), b.asnumpy())
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                # test numeric gradient
+                a_sym = mx.sym.Variable("a").as_np_ndarray()
+                b_sym = mx.sym.Variable("b").as_np_ndarray()
+                mx_sym = mx.sym.np.outer(a_sym, b_sym).as_nd_ndarray()
+                check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), b.as_nd_ndarray()],
+                                       rtol=1e-1, atol=1e-1, dtype=dtype)
+
+
+@with_seed()
+@use_np
def test_np_sum():
class TestSum(HybridBlock):
def __init__(self, axis=None, dtype=None, keepdims=False):