You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2019/09/24 19:58:09 UTC

[incubator-mxnet] branch numpy_prs updated: Implements ldexp. (#15845)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy_prs
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy_prs by this push:
     new 3e4a2a4  Implements ldexp. (#15845)
3e4a2a4 is described below

commit 3e4a2a4d254cd17247ca27f2d622fcb1de271537
Author: ckt624 <ck...@gmail.com>
AuthorDate: Tue Sep 24 15:57:21 2019 -0400

    Implements ldexp. (#15845)
    
    Remove spaces.
    
    Change tests.
    
    Reorganize files.
    
    Change styles.
    
    Add spaces
---
 python/mxnet/ndarray/numpy/_op.py              | 42 +++++++++++++++++-
 python/mxnet/numpy/multiarray.py               | 40 ++++++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py           | 31 ++++++++++++-
 src/operator/mshadow_op.h                      | 11 +++++
 src/operator/numpy/np_elemwise_broadcast_op.cc | 37 ++++++++++++++++
 src/operator/numpy/np_elemwise_broadcast_op.cu | 19 ++++++++
 src/operator/operator_tune.cc                  |  5 +++
 tests/python/unittest/test_numpy_op.py         | 61 ++++++++++++++++++++++++++
 8 files changed, 242 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index a3b4a27..99ef61b 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -35,7 +35,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
-           'unique']
+           'unique', 'ldexp']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -3246,7 +3246,7 @@ def hypot(x1, x2, out=None):
     Notes
     -----
     This function differs from the original numpy.arange in the following aspects:
-	    - Only support float16, float32 and float64.
+        - Only support float16, float32 and float64.
 
     Examples
     --------
@@ -3263,3 +3263,41 @@ def hypot(x1, x2, out=None):
            [ 5.,  5.,  5.]])
     """
     return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def ldexp(x1, x2, out=None):
+    """
+    Returns x1 * 2**x2, element-wise.
+    The mantissas `x1` and twos exponents `x2` are used to construct
+    floating point numbers ``x1 * 2**x2``.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Array of multipliers.
+    x2 : ndarray or scalar (int or float)
+        Array of twos exponents.
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The result of ``x1 * 2**x2``.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    Complex dtypes are not supported; they will raise a TypeError.
+    Unlike NumPy, `x2` may be a float as well as an int.
+    `ldexp` is useful as the inverse of `frexp`; if used by itself, it is
+    clearer to simply use the expression ``x1 * 2**x2``.
+
+    Examples
+    --------
+    >>> np.ldexp(5, np.arange(4))
+    array([  5.,  10.,  20.,  40.])
+    """
+    return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 4972bda..3fd5801 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -54,7 +54,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
            'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
            'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
            'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
-           'rad2deg', 'deg2rad', 'unique']
+           'rad2deg', 'deg2rad', 'unique', 'ldexp']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -4792,3 +4792,41 @@ def hypot(x1, x2, out=None):
            [ 5.,  5.,  5.]])
     """
     return _mx_nd_np.hypot(x1, x2, out=out)
+
+
+@set_module('mxnet.numpy')
+def ldexp(x1, x2, out=None):
+    """
+    Returns x1 * 2**x2, element-wise.
+    The mantissas `x1` and twos exponents `x2` are used to construct
+    floating point numbers ``x1 * 2**x2``.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Array of multipliers.
+    x2 : ndarray or scalar (int or float)
+        Array of twos exponents.
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The result of ``x1 * 2**x2``.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    Complex dtypes are not supported; they will raise a TypeError.
+    Unlike NumPy, `x2` may be a float as well as an int.
+    `ldexp` is useful as the inverse of `frexp`; if used by itself, it is
+    clearer to simply use the expression ``x1 * 2**x2``.
+
+    Examples
+    --------
+    >>> np.ldexp(5, np.arange(4))
+    array([  5.,  10.,  20.,  40.])
+    """
+    return _mx_nd_np.ldexp(x1, x2, out)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 57b18ec..af1eaed 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -37,7 +37,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
-           'unique']
+           'unique', 'ldexp']
 
 
 def _num_outputs(sym):
@@ -3394,4 +3394,33 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, ax
     return _npi.unique(ar, return_index, return_inverse, return_counts, axis)
 
 
+@set_module('mxnet.symbol.numpy')
+def ldexp(x1, x2, out=None):
+    """
+    ldexp(x1, x2, out=None)
+    Returns x1 * 2**x2, element-wise.
+    The mantissas `x1` and twos exponents `x2` are used to construct
+    floating point numbers ``x1 * 2**x2``.
+    Parameters
+    ----------
+    x1 : _Symbol
+        Array of multipliers.
+    x2 : _Symbol
+        Array of twos exponents.
+    out : _Symbol or None
+        Dummy parameter kept for consistency with the ndarray counterpart.
+    Returns
+    -------
+    y : _Symbol
+        The result of ``x1 * 2**x2``.
+    Notes
+    -----
+    Complex dtypes are not supported; they will raise a TypeError.
+    Unlike NumPy, `x2` may be a float as well as an int.
+    `ldexp` is useful as the inverse of `frexp`; if used by itself, it is
+    clearer to simply use the expression ``x1 * 2**x2``.
+    """
+    return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 6261638..cf18f82 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -367,6 +367,17 @@ MXNET_UNARY_MATH_OP(reciprocal_cube_root, 1.0f / math::cbrt(a));
 
 MXNET_UNARY_MATH_OP(reciprocal_cube_root_grad, -1.0f / (3.0f * math::cbrt(a) * math::id(a)));
 
+/*! \brief used to generate elements of ldexp, i.e. a * 2^b */
+MXNET_BINARY_MATH_OP(ldexp, math::id(a) * math::pow(2.0f, b));
+
+MXNET_BINARY_MATH_OP(ldexp_grad, math::pow(2.0f, b));  // d/da of a * 2^b
+
+MXNET_BINARY_MATH_OP(ldexp_rgrad, math::id(a) * math::pow(2.0f, b) * math::log(2.0f));  // d/db
+
+MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a));  // operands swapped for scalar lhs
+
+MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));  // d/da
+
 /*! \brief used for generate element of round */
 MXNET_SIMPLE_UNARY_MATH_OP(round);
 
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 16d4ef8..ed891a9 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -263,5 +263,42 @@ NNVM_REGISTER_OP(_backward_npi_hypot)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::hypot_grad_left,
                                                                   mshadow_op::hypot_grad_right>);
 
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp)  // x1 * 2^x2 with broadcasting
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::ldexp>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar)  // tensor-scalar variant
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::ldexp>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar)  // swapped-operand variant
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rldexp>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"});
+
+NNVM_REGISTER_OP(_backward_npi_ldexp)  // gradient op referenced by _npi_ldexp
+.set_num_inputs(3)
+.set_num_outputs(2)  // NOTE(review): presumably the grads for x1 and x2 - confirm
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::ldexp_grad,
+                                                                  mshadow_op::ldexp_rgrad>);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)  // gradient of _npi_ldexp_scalar
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ldexp_grad>);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)  // gradient of _npi_rldexp_scalar
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rldexp_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index 77525ce..c91b5e9 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -112,5 +112,24 @@ NNVM_REGISTER_OP(_npi_rarctan2_scalar)
 NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rarctan2_grad>);
 
+NNVM_REGISTER_OP(_npi_ldexp)  // GPU kernels for the ldexp ops registered in the .cc file
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::ldexp>);
+
+NNVM_REGISTER_OP(_npi_ldexp_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::ldexp>);
+
+NNVM_REGISTER_OP(_npi_rldexp_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rldexp>);
+
+NNVM_REGISTER_OP(_backward_npi_ldexp)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::ldexp_grad,
+                                                                  mshadow_op::ldexp_rgrad>);
+
+NNVM_REGISTER_OP(_backward_npi_ldexp_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::ldexp_grad>);
+
+NNVM_REGISTER_OP(_backward_npi_rldexp_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::rldexp_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 1d64438..73a1e00 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -372,6 +372,11 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NO
 IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ldexp);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad);  // NOLINT()
 /*!
  * \brief Tuner objects, *not* automatically generated
  */
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index d2dc6ab..af3b430 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -220,6 +220,67 @@ def test_np_dot():
 
 @with_seed()
 @use_np
+def test_np_ldexp():
+    class TestLdexp(HybridBlock):
+        def __init__(self):
+            super(TestLdexp, self).__init__()
+
+        def hybrid_forward(self, F, x1, x2):
+            return F.np.ldexp(x1, x2)
+
+    def _np_ldexp(x1, x2):  # NumPy reference built from power(), so x2 may be float
+        return x1 * _np.power(2.0, x2)
+
+    def dldx(x1, x2):  # reference gradients of x1 * 2**x2 w.r.t. x1 and x2
+        grad_a = _np.power(2.0, x2)
+        grad_b = _np_ldexp(x1, x2) * _np.log(2.0)
+        if len(x1) == 1:  # length-1 operand: its gradient is summed to one value
+            grad_a = _np.sum(grad_a)
+        if len(x2) == 1:
+            grad_b = _np.sum(grad_b)
+        return [grad_a, grad_b]
+
+    shapes = [
+        ((3, 1), (3, 1)),
+        ((3, 1, 2), (3, 1, 2)),
+        ((1, ), (1, )),
+        ((1, ), (2, )),
+        ((3, ), (1, )),
+        ((3, 0), (3, 0)),  # zero-size shape
+        ((0, 1), (0, 1)),  # zero-size shape
+        ((2, 0, 2), (2, 0, 2)),  # zero-size shape
+        ]
+
+    for hybridize in [True, False]:
+        for shape1, shape2 in shapes:
+            for dtype in [_np.float16, _np.float32, _np.float64]:
+                test_ldexp = TestLdexp()
+                if hybridize:
+                    test_ldexp.hybridize()
+                x1 = rand_ndarray(shape=shape1, dtype=dtype).as_np_ndarray()
+                x1.attach_grad()
+                x2 = rand_ndarray(shape=shape2, dtype=dtype).as_np_ndarray()
+                x2.attach_grad()
+
+                np_out = _np_ldexp(x1.asnumpy(), x2.asnumpy())
+                with mx.autograd.record():
+                    mx_out = test_ldexp(x1, x2)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)
+
+                mx_out.backward()
+                np_backward = dldx(x1.asnumpy(), x2.asnumpy())
+                assert_almost_equal(x1.grad.asnumpy(), np_backward[0], atol=1e-1, rtol=1e-1)
+                assert_almost_equal(x2.grad.asnumpy(), np_backward[1], atol=1e-1, rtol=1e-1)
+
+                # Test imperative once again
+                mx_out = np.ldexp(x1, x2)
+                np_out = _np_ldexp(x1.asnumpy(), x2.asnumpy())
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)
+
+
+@with_seed()
+@use_np
 def test_np_sum():
     class TestSum(HybridBlock):
         def __init__(self, axis=None, dtype=None, keepdims=False):