Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/10/27 16:49:30 UTC

[incubator-mxnet] branch master updated: [API] Add logaddexp (#20673)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 64999b4  [API] Add logaddexp (#20673)
64999b4 is described below

commit 64999b486f1c03cf1bf96b3e38337877c85afc49
Author: Zhenghui Jin <69...@users.noreply.github.com>
AuthorDate: Wed Oct 27 09:47:09 2021 -0700

    [API] Add logaddexp (#20673)
    
    * add logaddexp
    
    * update
    
    * fix lint
    
    * update operator_tune.cc
    
    * add tests
    
    * add tests
    
    * add comma
    
    * fix numpy op test
---
 docs/python_docs/python/api/np/routines.math.rst   |  2 +
 python/mxnet/amp/lists/symbol_fp16.py              |  2 +
 python/mxnet/ndarray/numpy/_op.py                  | 44 +++++++++++++++-
 python/mxnet/numpy/multiarray.py                   | 43 +++++++++++++++-
 python/mxnet/numpy_dispatch_protocol.py            |  1 +
 src/api/operator/numpy/np_elemwise_broadcast_op.cc |  8 +++
 src/common/cuda/rtc/backward_functions-inl.h       | 14 +++++
 src/common/cuda/rtc/forward_functions-inl.h        | 10 ++++
 src/operator/mshadow_op.h                          |  7 +++
 src/operator/numpy/np_elemwise_broadcast_op_lae.cc | 60 ++++++++++++++++++++++
 src/operator/numpy/np_elemwise_broadcast_op_lae.cu | 44 ++++++++++++++++
 src/operator/operator_tune.cc                      |  3 ++
 .../gpu/{test_numpy_op.py => test_numpy_einsum.py} |  0
 .../python/unittest/test_numpy_interoperability.py |  8 +++
 tests/python/unittest/test_numpy_op.py             |  2 +
 15 files changed, 246 insertions(+), 2 deletions(-)
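
For context, a minimal usage sketch of the new operator (an illustration only, assuming an MXNet build that includes this commit; mx.np.logaddexp mirrors NumPy's operator of the same name):

    from mxnet import np, npx
    npx.set_np()  # enable NumPy-compatible array semantics

    # log(exp(a) + exp(b)), elementwise, with the usual broadcasting rules
    a = np.log(np.array([1e-50, 1e-30]))
    b = np.log(np.array([2.5e-50, 2.5e-30]))
    print(np.logaddexp(a, b))  # log of the summed probabilities
    print(np.logaddexp(a, 2))  # a scalar operand routes to _npi_logaddexp_scalar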

diff --git a/docs/python_docs/python/api/np/routines.math.rst b/docs/python_docs/python/api/np/routines.math.rst
index bb1301b..c909a56 100644
--- a/docs/python_docs/python/api/np/routines.math.rst
+++ b/docs/python_docs/python/api/np/routines.math.rst
@@ -105,6 +105,7 @@ Exponents and logarithms
    log10
    log2
    log1p
+   logaddexp
 
 
 Other special functions
@@ -133,6 +134,7 @@ Rational routines
    :toctree: generated/
 
    lcm
+   gcd
 
 
 Arithmetic operations
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index b561b33..307336c 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -486,6 +486,8 @@ FP32_FUNCS = [
     '_npi_expm1',
     '_npi_ldexp',
     '_npi_ldexp_scalar',
+    '_npi_logaddexp',
+    '_npi_logaddexp_scalar',
     '_npi_log',
     '_npi_log10',
     '_npi_log1p',
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index cf1bd52..ef1c6b7 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -51,7 +51,7 @@ __all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_li
            'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
            'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
            'where', 'bincount', 'rollaxis', 'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'diag', 'diagonal',
-           'positive']
+           'positive', 'logaddexp']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -6915,6 +6915,48 @@ def ldexp(x1, x2, out=None, **kwargs):
 
 
 @set_module('mxnet.ndarray.numpy')
+@wrap_np_binary_func
+def logaddexp(x1, x2, out=None, **kwargs):
+    """
+    Logarithm of the sum of exponentiations of the inputs.
+
+    Calculates log(exp(x1) + exp(x2)). This function is useful in statistics where
+    the calculated probabilities of events may be so small as to exceed the range of
+    normal floating point numbers. In such cases the logarithm of the calculated
+    probability is stored. This function allows adding probabilities stored
+    in such a fashion.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Input values.
+    x2 : ndarray or scalar
+        Input values.
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        Logarithm of exp(x1) + exp(x2). This is a scalar if both x1 and x2 are scalars.
+
+    Examples
+    --------
+    >>> prob1 = np.log(1e-50)
+    >>> prob2 = np.log(2.5e-50)
+    >>> prob12 = np.logaddexp(prob1, prob2)
+    >>> prob12
+    -113.87649168120691
+    >>> np.exp(prob12)
+    3.5000000000000057e-50
+    """
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.logaddexp(x1, x2, out=out)
+    return _api_internal.logaddexp(x1, x2, out)
+
+
+@set_module('mxnet.ndarray.numpy')
 def vdot(a, b):
     r"""
     Return the dot product of two vectors.
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index a58f1fa..427f8ff 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -80,7 +80,8 @@ __all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape', 'median',
            'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
            'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
            'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
-           'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal', 'positive']
+           'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
+           'positive', 'logaddexp']
 
 __all__ += fallback.__all__
 
@@ -9505,6 +9506,46 @@ def ldexp(x1, x2, out=None, **kwargs):
 
 
 @set_module('mxnet.numpy')
+@wrap_np_binary_func
+def logaddexp(x1, x2, out=None, **kwargs):
+    """
+    Logarithm of the sum of exponentiations of the inputs.
+
+    Calculates log(exp(x1) + exp(x2)). This function is useful in statistics where
+    the calculated probabilities of events may be so small as to exceed the range of
+    normal floating point numbers. In such cases the logarithm of the calculated
+    probability is stored. This function allows adding probabilities stored
+    in such a fashion.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Input values.
+    x2 : ndarray or scalar
+        Input values.
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        Logarithm of exp(x1) + exp(x2). This is a scalar if both x1 and x2 are scalars.
+
+    Examples
+    --------
+    >>> prob1 = np.log(1e-50)
+    >>> prob2 = np.log(2.5e-50)
+    >>> prob12 = np.logaddexp(prob1, prob2)
+    >>> prob12
+    -113.87649168120691
+    >>> np.exp(prob12)
+    3.5000000000000057e-50
+    """
+    return _mx_nd_np.logaddexp(x1, x2, out)
+
+
+@set_module('mxnet.numpy')
 def vdot(a, b):
     r"""
     Return the dot product of two vectors.
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index f047076..c293621 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -251,6 +251,7 @@ _NUMPY_ARRAY_UFUNC_LIST = [
     'lcm',
     'gcd',
     # 'ldexp',
+    'logaddexp',
     'subtract',
     'multiply',
     'true_divide',
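
Note: adding 'logaddexp' to _NUMPY_ARRAY_UFUNC_LIST is what lets the official NumPy ufunc dispatch back to the MXNet implementation via __array_ufunc__. A minimal sketch of what this enables (again assuming a build that includes this commit):

    import numpy as onp
    from mxnet import np, npx
    npx.set_np()

    a = np.log(np.array([1e-50, 1e-30]))
    b = np.log(np.array([2.5e-50, 2.5e-30]))
    out = onp.logaddexp(a, b)  # NumPy's ufunc, dispatched to MXNet's kernel
    print(type(out))           # expected: mxnet.numpy.ndarray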
diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_op.cc
index 184a4e2..b9f1060 100644
--- a/src/api/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/api/operator/numpy/np_elemwise_broadcast_op.cc
@@ -139,6 +139,14 @@ MXNET_REGISTER_API("_npi.bitwise_and")
       UFuncHelper(args, ret, op, op_scalar, nullptr);
     });
 
+MXNET_REGISTER_API("_npi.logaddexp")
+    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+      using namespace runtime;
+      const nnvm::Op* op        = Op::Get("_npi_logaddexp");
+      const nnvm::Op* op_scalar = Op::Get("_npi_logaddexp_scalar");
+      UFuncHelper(args, ret, op, op_scalar, nullptr);
+    });
+
 MXNET_REGISTER_API("_npi.copysign")
     .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
       using namespace runtime;
diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index 50d8469..d28ce73 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -463,6 +463,20 @@ rldexp_grad(const DType val,
 }
 
 template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+logaddexp_grad(const DType val,
+               const DType2 val2) {
+  return op::exp(val) / (op::exp(val) + op::exp(val2));
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+logaddexp_rgrad(const DType val,
+                const DType2 val2) {
+  return op::exp(val2) / (op::exp(val) + op::exp(val2));
+}
+
+template <typename DType, typename DType2>
 __device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
   auto bsq = scalar * scalar;
   auto ibsq = 1.0f / bsq;
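
For reference, the two RTC kernels just added compute the partial derivatives of the forward expression; in LaTeX (a restatement of the code, nothing new):

    f(x_1, x_2) = \log\left(e^{x_1} + e^{x_2}\right), \qquad
    \frac{\partial f}{\partial x_1} = \frac{e^{x_1}}{e^{x_1} + e^{x_2}}, \qquad
    \frac{\partial f}{\partial x_2} = \frac{e^{x_2}}{e^{x_1} + e^{x_2}}

The two partials form a two-way softmax of the inputs and always sum to 1, which is a quick sanity check for the logaddexp_grad/logaddexp_rgrad pair.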
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index 7a886a0..333ae04 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -621,6 +621,16 @@ rldexp(const DType a, const DType2 b) {
   return ldexp(b, a);
 }
 
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+logaddexp(const DType a, const DType2 b) {
+  if (type_util::has_double_or_integral<DType, DType2>::value) {
+    return ::log(::exp(static_cast<double>(a)) + ::exp(static_cast<double>(b)));
+  } else {
+    return ::logf(::expf(static_cast<float>(a)) + ::expf(static_cast<float>(b)));
+  }
+}
+
 #undef DEFINE_BINARY_MATH_FUNC
 
 template <typename DType, typename DType2>
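
Note that this kernel (like the mshadow_op version below) evaluates log(exp(a) + exp(b)) directly, which overflows once exp(a) or exp(b) leaves floating-point range. NumPy's reference implementation instead uses the standard rewrite max(a, b) + log1p(exp(-|a - b|)). A minimal Python sketch of that rewrite, for comparison only (this is not what the commit implements):

    import math

    def logaddexp_stable(a, b):
        # log(exp(a) + exp(b)) without overflow for large |a|, |b|
        if a == b:
            return a + math.log(2.0)
        return max(a, b) + math.log1p(math.exp(-abs(a - b)))

    # the direct formula would overflow here; the rewrite does not
    assert abs(logaddexp_stable(1000.0, 1000.0) - (1000.0 + math.log(2.0))) < 1e-9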
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index a852fff..677d924 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -726,6 +726,13 @@ MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a));  // swap a and b
 
 MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));
 
+/*! \brief used for generate element of logaddexp */
+MXNET_BINARY_MATH_OP(logaddexp, math::log(math::exp(a) + math::exp(b)));
+
+MXNET_BINARY_MATH_OP(logaddexp_grad, math::exp(a) / (math::exp(a) + math::exp(b)));
+
+MXNET_BINARY_MATH_OP(logaddexp_rgrad, math::exp(b) / (math::exp(a) + math::exp(b)));
+
 /*! \brief used for generate element of round */
 MXNET_SIMPLE_UNARY_MATH_OP(round);
 
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_lae.cc b/src/operator/numpy/np_elemwise_broadcast_op_lae.cc
new file mode 100644
index 0000000..05d83d8
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_op_lae.cc
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_elemwise_broadcast_op_lae.cc
+ * \brief CPU Implementation of basic functions for elementwise numpy binary logaddexp.
+ */
+
+#include "./np_elemwise_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_logaddexp)
+    .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logaddexp>)
+    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_logaddexp"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_logaddexp_scalar)
+    .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logaddexp>)
+    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_logaddexp_scalar"});
+
+NNVM_REGISTER_OP(_backward_npi_logaddexp)
+    .set_num_inputs(3)
+    .set_num_outputs(2)
+    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int> >{{0, 1}};
+                                    })
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+    .set_attr<FCompute>(
+        "FCompute<cpu>",
+        BinaryBroadcastBackwardUseIn<cpu, mshadow_op::logaddexp_grad, mshadow_op::logaddexp_rgrad>);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_logaddexp_scalar)
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::logaddexp_grad>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_lae.cu b/src/operator/numpy/np_elemwise_broadcast_op_lae.cu
new file mode 100644
index 0000000..a8503a1
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_op_lae.cu
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_elemwise_broadcast_op_lae.cu
+ * \brief GPU Implementation of basic functions for elementwise binary broadcast logaddexp operator.
+ */
+
+#include "./np_elemwise_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_logaddexp)
+    .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"logaddexp"});
+
+NNVM_REGISTER_OP(_npi_logaddexp_scalar)
+    .set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCCompute{"logaddexp"});
+
+NNVM_REGISTER_OP(_backward_npi_logaddexp)
+    .set_attr<FCompute>("FCompute<gpu>",
+                        BinaryBroadcastRTCBackwardUseIn{"logaddexp_grad", "logaddexp_rgrad"});
+
+NNVM_REGISTER_OP(_backward_npi_logaddexp_scalar)
+    .set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCBackward{"logaddexp_grad"});
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index f83885d..02cf907 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -463,6 +463,9 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp);
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad);                  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad);                 // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad);                 // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logaddexp);                   // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logaddexp_grad);              // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logaddexp_rgrad);             // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::posone);                      // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::negone);                      // NOLINT()
 /*!
diff --git a/tests/python/gpu/test_numpy_op.py b/tests/python/gpu/test_numpy_einsum.py
similarity index 100%
rename from tests/python/gpu/test_numpy_op.py
rename to tests/python/gpu/test_numpy_einsum.py
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index db643d1..09deace 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1541,6 +1541,13 @@ def _add_workload_ldexp():
     OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(-9223372036854775808, np.int64))
 
 
+def _add_workload_logaddexp(array_pool):
+    OpArgMngr.add_workload('logaddexp', array_pool['4x1'], array_pool['1x2'])
+    OpArgMngr.add_workload('logaddexp', array_pool['4x1'], 2)
+    OpArgMngr.add_workload('logaddexp', 2, array_pool['4x1'])
+    OpArgMngr.add_workload('logaddexp', array_pool['4x1'], array_pool['1x1x0'])
+
+
 def _add_workload_subtract(array_pool):
     OpArgMngr.add_workload('subtract', array_pool['4x1'], array_pool['1x2'])
     OpArgMngr.add_workload('subtract', array_pool['4x1'], 2)
@@ -3082,6 +3089,7 @@ def _prepare_workloads():
     _add_workload_bitwise_xor()
     _add_workload_bitwise_or()
     _add_workload_ldexp()
+    _add_workload_logaddexp(array_pool)
     _add_workload_subtract(array_pool)
     _add_workload_multiply(array_pool)
     _add_workload_power(array_pool)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1010475..de0aeb1 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3113,6 +3113,8 @@ def test_np_binary_funcs():
         'hypot': (-1, 1, [lambda y, x1, x2: x1 / y],
                          [lambda y, x1, x2: x2 / y]),
         'ldexp': (-3, 3, [None], None, [[onp.int32]]),
+        'logaddexp': (-10, 10, [lambda y, x1, x2: onp.exp(x1) / (onp.exp(x1) + onp.exp(x2))],
+                               [lambda y, x1, x2: onp.exp(x2) / (onp.exp(x1) + onp.exp(x2))])
     }
     if is_op_runnable():
         funcs['logical_and'] = (-100, 100, [None], None, [[onp.float32, onp.float64]])
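
As an end-to-end check, the analytic gradient lambdas above can be compared against autograd; a minimal, hypothetical sketch (the snippet below is ours, not part of the test suite, and assumes a build that includes this commit):

    import numpy as onp
    from mxnet import np, npx, autograd
    npx.set_np()

    x1 = np.random.uniform(-10, 10, (4, 1))
    x2 = np.random.uniform(-10, 10, (1, 2))
    x1.attach_grad()
    x2.attach_grad()
    with autograd.record():
        y = np.logaddexp(x1, x2)
    y.backward()

    # closed-form partials from test_np_binary_funcs, summed over broadcast axes
    e1, e2 = onp.exp(x1.asnumpy()), onp.exp(x2.asnumpy())
    assert onp.allclose(x1.grad.asnumpy(), (e1 / (e1 + e2)).sum(axis=1, keepdims=True))
    assert onp.allclose(x2.grad.asnumpy(), (e2 / (e1 + e2)).sum(axis=0, keepdims=True))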