Posted to commits@mxnet.apache.org by sx...@apache.org on 2019/09/15 18:58:46 UTC

[incubator-mxnet] branch master updated: [Numpy] Numpy copysign (#15851)

This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 90091b1  [Numpy] Numpy copysign (#15851)
90091b1 is described below

commit 90091b155d6f53c070e3c406f9edc69f38d02e96
Author: Haozheng Fan <fh...@gmail.com>
AuthorDate: Mon Sep 16 02:57:51 2019 +0800

    [Numpy] Numpy copysign (#15851)
    
    * add numpy compatible copysign
    
    * fix scalar op registration error
    
    * add test
---
 python/mxnet/ndarray/numpy/_op.py              |  53 ++++++++++++-
 python/mxnet/numpy/multiarray.py               |  53 ++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py           |  36 ++++++++-
 src/operator/mshadow_op.h                      |  10 +++
 src/operator/numpy/np_elemwise_broadcast_op.cc |  36 +++++++++
 src/operator/numpy/np_elemwise_broadcast_op.cu |  21 +++++
 src/operator/operator_tune.cc                  |   5 ++
 tests/python/unittest/test_numpy_op.py         | 105 +++++++++++++++++++++++++
 8 files changed, 316 insertions(+), 3 deletions(-)
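
For reference, a minimal usage sketch of the new operator. This is an
illustration only, not part of the commit: the `from mxnet import np, npx`
and `npx.set_np()` setup is the usual pattern for the numpy-compatible
frontend, and the expected outputs follow the docstring examples added in
the diff below.

    # Usage sketch for the new mxnet.numpy copysign operator
    # (assumes a build of MXNet that already includes this commit).
    from mxnet import np, npx
    npx.set_np()  # enable numpy-compatible array semantics

    a = np.array([-1.0, 0.0, 1.0])
    print(np.copysign(a, -1.1))               # array([-1., -0., -1.])
    print(np.copysign(a, np.arange(3) - 1))   # array([-1.,  0.,  1.])
    print(np.copysign(1.3, -1))               # -1.3 (two scalars give a scalar)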

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 671345c..b8e4f3f 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
-           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
+           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -2432,3 +2432,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
     else:
         raise ValueError("The dimensions must be sequence of ints")
 # pylint: enable=redefined-outer-name
+
+
+@set_module('mxnet.ndarray.numpy')
+def copysign(x1, x2, out=None):
+    r"""copysign(x1, x2, out=None)
+
+    Change the sign of x1 to that of x2, element-wise.
+
+    If `x2` is a scalar, its sign will be copied to all elements of `x1`.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Values to change the sign of.
+    x2 : ndarray or scalar
+        The sign of `x2` is copied to `x1`.
+    out : ndarray or None, optional
+        A location into which the result is stored. It must be of the
+        right shape and right type to hold the output. If not provided
+        or `None`, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        The values of `x1` with the sign of `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    This function differs from the original `numpy.copysign
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
+    the following aspects:
+
+    - ``where`` param is not supported.
+
+    Examples
+    --------
+    >>> np.copysign(1.3, -1)
+    -1.3
+    >>> 1/np.copysign(0, 1)
+    inf
+    >>> 1/np.copysign(0, -1)
+    -inf
+
+    >>> a = np.array([-1, 0, 1])
+    >>> np.copysign(a, -1.1)
+    array([-1., -0., -1.])
+    >>> np.copysign(a, np.arange(3)-1)
+    array([-1.,  0.,  1.])
+    """
+    return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 1f8aa92..632cfad 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -52,7 +52,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
            'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
-           'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
+           'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3935,3 +3935,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
     """
     return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
 # pylint: enable=redefined-outer-name
+
+
+@set_module('mxnet.numpy')
+def copysign(x1, x2, out=None):
+    r"""copysign(x1, x2, out=None)
+
+    Change the sign of x1 to that of x2, element-wise.
+
+    If `x2` is a scalar, its sign will be copied to all elements of `x1`.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Values to change the sign of.
+    x2 : ndarray or scalar
+        The sign of `x2` is copied to `x1`.
+    out : ndarray or None, optional
+        A location into which the result is stored. It must be of the
+        right shape and right type to hold the output. If not provided
+        or `None`, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        The values of `x1` with the sign of `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    This function differs from the original `numpy.copysign
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
+    the following aspects:
+
+    - ``where`` param is not supported.
+
+    Examples
+    --------
+    >>> np.copysign(1.3, -1)
+    -1.3
+    >>> 1/np.copysign(0, 1)
+    inf
+    >>> 1/np.copysign(0, -1)
+    -inf
+
+    >>> a = np.array([-1, 0, 1])
+    >>> np.copysign(a, -1.1)
+    array([-1., -0., -1.])
+    >>> np.copysign(a, np.arange(3)-1)
+    array([-1.,  0.,  1.])
+    """
+    return _mx_nd_np.copysign(x1, x2, out=out)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 077008a..5a38f81 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -35,7 +35,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
-           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
+           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
 
 
 def _num_outputs(sym):
@@ -2744,4 +2744,38 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
 # pylint: enable=redefined-outer-name
 
 
+@set_module('mxnet.symbol.numpy')
+def copysign(x1, x2, out=None):
+    r"""copysign(x1, x2, out=None)
+
+    Change the sign of x1 to that of x2, element-wise.
+
+    If `x2` is a scalar, its sign will be copied to all elements of `x1`.
+
+    Parameters
+    ----------
+    x1 : _Symbol or scalar
+        Values to change the sign of.
+    x2 : _Symbol or scalar
+        The sign of `x2` is copied to `x1`.
+    out : _Symbol or None
+        Dummy parameter to keep the API consistent with the ndarray counterpart.
+
+    Returns
+    -------
+    out : _Symbol
+        The values of `x1` with the sign of `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    Notes
+    -----
+    This function differs from the original `numpy.copysign
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
+    the following aspects:
+
+    - ``where`` param is not supported.
+    """
+    return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 616192e..f3d24b2 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -417,6 +417,16 @@ MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));
 
 MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));
 
+MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);
+
+MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1 : -1);
+
+MXNET_BINARY_MATH_OP(copysign_rgrad, 0);
+
+MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);
+
+MXNET_BINARY_MATH_OP(rcopysign_grad, 0);
+
 struct mod : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 697657d..a9254e8 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -76,6 +76,26 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});
 
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::copysign>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"});
+
+NNVM_REGISTER_OP(_backward_npi_copysign)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
+                                                                  mshadow_op::copysign_rgrad>);
+
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
@@ -108,5 +128,21 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});
 
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::copysign>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rcopysign>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar)
+.set_attr<FCompute>("FCompute<cpu>",
+                    BinaryScalarOp::Backward<cpu, mshadow_op::copysign_grad>);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
+.set_attr<FCompute>("FCompute<cpu>",
+                    BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index ac8def2..ecf8e85 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -42,6 +42,13 @@ NNVM_REGISTER_OP(_npi_mod)
 NNVM_REGISTER_OP(_npi_power)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);
 
+NNVM_REGISTER_OP(_npi_copysign)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::copysign>);
+
+NNVM_REGISTER_OP(_backward_npi_copysign)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::copysign_grad,
+                                                                  mshadow_op::copysign_rgrad>);
+
 NNVM_REGISTER_OP(_npi_add_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
 
@@ -66,5 +73,19 @@ NNVM_REGISTER_OP(_npi_power_scalar)
 NNVM_REGISTER_OP(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
 
+NNVM_REGISTER_OP(_npi_copysign_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::copysign>);
+
+NNVM_REGISTER_OP(_npi_rcopysign_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rcopysign>);
+
+NNVM_REGISTER_OP(_backward_npi_copysign_scalar)
+.set_attr<FCompute>("FCompute<gpu>",
+                    BinaryScalarOp::Backward<gpu, mshadow_op::copysign_grad>);
+
+NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar)
+.set_attr<FCompute>("FCompute<gpu>",
+                    BinaryScalarOp::Backward<gpu, mshadow_op::rcopysign_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 98ce14e..5159525 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -328,6 +328,11 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::copysign);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index c5b0907..1f2af8d 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1853,6 +1853,111 @@ def test_np_linalg_norm():
                     assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)
 
 
+@with_seed()
+@use_np
+def test_np_copysign():
+    class TestCopysign(HybridBlock):
+        def __init__(self):
+            super(TestCopysign, self).__init__()
+
+        def hybrid_forward(self, F, a1, a2):
+            return F.np.copysign(a1, a2)
+
+    def get_grad(a1, a2):
+        sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
+                              _np.logical_and(a1 >= 0, a2 >= 0))
+        sign = 2 * sign.astype(int) - 1
+        sign = sign.reshape(-1, *a1.shape)
+        sign = _np.sum(sign, axis=0)
+        return sign, _np.zeros_like(a2)
+    
+    def get_grad_left(a1, a2):
+        sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
+                              _np.logical_and(a1 >= 0, a2 >= 0))
+        sign = 2 * sign.astype(int) - 1
+        sign = sign.reshape(a1.shape)
+        return sign
+    
+    def get_grad_right(a1, a2):
+        return _np.zeros_like(a2)
+        
+    shapes = [
+        (),
+        (1,),
+        (2, 1),
+        (3, 2, 1),
+        (4, 3, 2, 1),
+        (2, 4, 3, 2, 1)
+    ]
+    types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
+    for a1shape in shapes:
+        for a2shape in shapes:
+            for hybridize in [True, False]:
+                for dtype in types:
+                    test_copysign = TestCopysign()
+                    if hybridize:
+                        test_copysign.hybridize()
+                    rtol = 1e-3
+                    atol = 1e-5
+                    a1_np = _np.array(_np.random.uniform(-1.0, 1.0, a1shape), dtype=dtype)
+                    a2_np = _np.array(_np.random.uniform(-1.0, 1.0, a2shape), dtype=dtype)
+                    a1 = np.array(a1_np, dtype=dtype)
+                    a2 = np.array(a2_np, dtype=dtype)
+                    a1.attach_grad()
+                    a2.attach_grad()
+                    expected_np = _np.copysign(a1_np, a2_np)
+                    with mx.autograd.record():
+                        mx_out = test_copysign(a1, a2)
+                    assert mx_out.shape == expected_np.shape
+                    assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+                    # Test gradient
+                    mx_out.backward()
+                    a1_grad, a2_grad = get_grad(a1_np, a2_np)
+                    assert_almost_equal(a1.grad.asnumpy(), a1_grad, rtol=rtol, atol=atol)
+                    assert_almost_equal(a2.grad.asnumpy(), a2_grad, rtol=rtol, atol=atol)
+
+                    # Test imperative once again
+                    mx_out = np.copysign(a1, a2)
+                    expected_np = _np.copysign(a1_np, a2_np)
+                    assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+    
+    types = ['float16', 'float32', 'float64']
+    for x_shape in shapes:
+        for dtype in types:
+            # Test left
+            x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
+            scalar = _np.random.uniform(-2.0, 2.0)
+            x = np.array(x_np, dtype=dtype)
+            x.attach_grad()
+            expected_np = _np.copysign(x_np, scalar)
+            with mx.autograd.record():
+                mx_out = np.copysign(x, scalar)
+            assert mx_out.shape == expected_np.shape
+            assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+            
+            # Test gradient
+            mx_out.backward()
+            x_grad = get_grad_left(x_np, scalar)
+            assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
+            
+            # Test right
+            x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
+            scalar = _np.random.uniform(-2.0, 2.0)
+            x = np.array(x_np, dtype=dtype)
+            x.attach_grad()
+            expected_np = _np.copysign(scalar, x_np)
+            with mx.autograd.record():
+                mx_out = np.copysign(scalar, x)
+            assert mx_out.shape == expected_np.shape
+            assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+            
+            # Test gradient
+            mx_out.backward()
+            x_grad = get_grad_right(scalar, x_np)
+            assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
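
A NumPy-only sketch of the gradient convention encoded by the new
mshadow_op::copysign_grad / copysign_rgrad kernels above (an illustration,
not part of the commit): the gradient with respect to x1 is +1 where the
signs of x1 and x2 already agree and -1 where they do not, and the gradient
with respect to x2 is zero everywhere, which is what get_grad_left and
get_grad_right in the new test compute.

    import numpy as _np

    def copysign_grads(x1, x2):
        # +1 where sign(x1) already matches sign(x2) (treating 0 as non-negative),
        # -1 where copysign has to flip the sign of x1
        same_sign = _np.logical_or(_np.logical_and(x1 < 0, x2 < 0),
                                   _np.logical_and(x1 >= 0, x2 >= 0))
        dx1 = 2 * same_sign.astype(x1.dtype) - 1
        # the output never depends on the magnitude of x2, so its gradient is 0
        dx2 = _np.zeros_like(x2)
        return dx1, dx2

    x1 = _np.array([-1.5, 0.0, 2.0])
    x2 = _np.array([1.0, -1.0, -3.0])
    print(_np.copysign(x1, x2))     # [ 1.5 -0.  -2. ]
    print(copysign_grads(x1, x2))   # (array([-1., -1., -1.]), array([0., 0., 0.]))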