You are viewing a plain text version of this content. The canonical link for it (a hyperlink on the original archive page) is not preserved in this copy.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2020/01/19 23:17:25 UTC

[incubator-mxnet] branch master updated: [numpy] add op random.exponential (#17280)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 28742cf  [numpy] add op random.exponential (#17280)
28742cf is described below

commit 28742cf6459f15dea9ff8e4b25ba5feb5e58efff
Author: Yiyan66 <57...@users.noreply.github.com>
AuthorDate: Mon Jan 20 07:16:37 2020 +0800

    [numpy] add op random.exponential (#17280)
    
    * C++ ok
    
    * before rebase
    
    * sanity
    
    * change sth
    
    * change sth
    
    * change sth
---
 python/mxnet/ndarray/numpy/random.py           |  30 ++++-
 python/mxnet/numpy/random.py                   |  25 ++++-
 python/mxnet/symbol/numpy/random.py            |  33 +++++-
 src/operator/numpy/random/np_exponential_op.cc |  72 ++++++++++++
 src/operator/numpy/random/np_exponential_op.cu |  35 ++++++
 src/operator/numpy/random/np_exponential_op.h  | 146 +++++++++++++++++++++++++
 tests/nightly/test_np_random.py                |  16 +++
 tests/python/unittest/test_numpy_op.py         |  30 +++++
 8 files changed, 384 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index 913ceaa..6d1bda9 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -23,7 +23,8 @@ from . import _internal as _npi
 from ..ndarray import NDArray
 
 
-__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle", 'gamma']
+__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle", 'gamma',
+           'exponential']
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -319,6 +319,34 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
             return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)
 
 
+def exponential(scale=1.0, size=None):
+    r"""Draw samples from an exponential distribution.
+
+    Parameters
+    ----------
+    scale : float or ndarray, optional
+        The scale parameter, :math:`\beta = 1/\lambda`. Must be non-negative.
+        Default is 1.0.
+    size : int or tuple of ints, optional
+        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
+        a single value is returned if ``scale`` is a scalar.  Otherwise,
+        ``np.array(scale).size`` samples are drawn.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Drawn samples from the parameterized exponential distribution.
+    """
+    from ...numpy import ndarray as np_ndarray
+    if size == ():
+        size = None  # an empty size tuple is treated like None
+    if isinstance(scale, np_ndarray):
+        # Tensor scale becomes the op input; the scalar attribute is cleared.
+        return _npi.exponential(scale, scale=None, size=size)
+    return _npi.exponential(scale=scale, size=size)
+
+
 def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
     """Draw samples from a Gamma distribution.
 
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index 198f2fc..fe98a12 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -20,8 +20,9 @@
 from __future__ import absolute_import
 from ..ndarray import numpy as _mx_nd_np
 
+# Public sampling API exported by ``from mxnet.numpy.random import *``.
 __all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "shuffle", "randn",
-           "gamma"]
+           "gamma", "exponential"]
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -324,6 +325,28 @@ def rand(*size, **kwargs):
     return _mx_nd_np.random.uniform(0, 1, size=output_shape, **kwargs)
 
 
+def exponential(scale=1.0, size=None):
+    r"""Sample from an exponential distribution with scale :math:`\beta = 1/\lambda`.
+
+    Parameters
+    ----------
+    scale : float or ndarray, optional
+        Scale of the distribution (the inverse of the rate). Must be
+        non-negative; defaults to 1.0.
+    size : int or tuple of ints, optional
+        Shape of the output. For ``(m, n, k)``, ``m * n * k`` samples are
+        drawn. When ``None`` (the default), a single value is returned for
+        a scalar ``scale``; otherwise ``scale.size`` values are drawn.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Samples drawn from the parameterized exponential distribution.
+    """
+    sampler = _mx_nd_np.random.exponential
+    return sampler(scale, size)
+
+
 def shuffle(x):
     """
     Modify a sequence in-place by shuffling its contents.
diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py
index c6b23b5..33e57f3 100644
--- a/python/mxnet/symbol/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -21,7 +21,8 @@ from __future__ import absolute_import
 from ...context import current_context
 from . import _internal as _npi
 
-__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle', 'gamma']
+# Public sampling API exported by ``from mxnet.symbol.numpy.random import *``.
+__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle', 'gamma', 'exponential']
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -347,6 +348,36 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
     raise ValueError("Distribution parameters must be either _Symbol or numbers")
 
 
+def exponential(scale=1.0, size=None):
+    r"""Draw samples from an exponential distribution (symbolic front end).
+
+    Parameters
+    ----------
+    scale : float or _Symbol, optional
+        The scale parameter, :math:`\beta = 1/\lambda`, i.e. the inverse of
+        the rate. Must be non-negative. Defaults to 1.0.
+    size : int or tuple of ints, optional
+        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
+        a single value is produced for a scalar ``scale``; otherwise the
+        output takes the shape of ``scale``.
+
+    Returns
+    -------
+    out : _Symbol
+        Symbol that evaluates to samples from the exponential distribution.
+    """
+    from ..numpy import _Symbol as np_symbol
+
+    # An empty size tuple is treated exactly like None.
+    size = None if size == () else size
+    if not isinstance(scale, np_symbol):
+        # Scalar scale travels as an operator attribute; the op has no inputs.
+        return _npi.exponential(scale=scale, size=size)
+    # Symbolic scale is the operator's input, so the scalar attribute is cleared.
+    return _npi.exponential(scale, scale=None, size=size)
+
+
 def shuffle(x):
     """
     Modify a sequence in-place by shuffling its contents.
diff --git a/src/operator/numpy/random/np_exponential_op.cc b/src/operator/numpy/random/np_exponential_op.cc
new file mode 100644
index 0000000..cc79fd8
--- /dev/null
+++ b/src/operator/numpy/random/np_exponential_op.cc
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_exponential_op.cc
+ * \brief Operator for numpy sampling from exponential distributions
+ */
+
+#include "./np_exponential_op.h"
+#include "./dist_common.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyExponentialParam);
+// _npi_exponential backs np.random.exponential; scale is an attribute or a tensor input.
+NNVM_REGISTER_OP(_npi_exponential)
+.set_num_inputs(  // 0 inputs for a scalar scale, 1 ("input1") when param.scale is None
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
+    int num_inputs = 1;
+    if (param.scale.has_value()) {
+      num_inputs -= 1;
+    }
+    return num_inputs;
+  })
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",  // must agree with set_num_inputs
+  [](const NodeAttrs& attrs) {
+    const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
+    int num_inputs = 1;
+    if (param.scale.has_value()) {
+      num_inputs -= 1;
+    }
+    return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
+  })
+.set_attr_parser(ParamParser<NumpyExponentialParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyExponentialParam>)
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs,  std::vector<int> *out_attrs) {
+    (*out_attrs)[0] = mshadow::kFloat32;  // always float32; input dtype is not consulted
+    return true;
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+      return std::vector<ResourceRequest>{
+        ResourceRequest::kRandom, ResourceRequest::kTempSpace};  // ctx.requested[0], [1]
+  })
+.set_attr<FCompute>("FCompute<cpu>", NumpyExponentialForward<cpu>)  // GPU: the .cu file
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)  // zero grads for the input
+.add_argument("input1", "NDArray-or-Symbol", "Source input")  // present iff scale is None
+.add_arguments(NumpyExponentialParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/random/np_exponential_op.cu b/src/operator/numpy/random/np_exponential_op.cu
new file mode 100644
index 0000000..1c0ff12
--- /dev/null
+++ b/src/operator/numpy/random/np_exponential_op.cu
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_exponential_op.cu
+ * \brief Operator for numpy sampling from exponential distributions
+ */
+
+#include "./np_exponential_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_exponential)  // op itself is registered in np_exponential_op.cc
+.set_attr<FCompute>("FCompute<gpu>", NumpyExponentialForward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/random/np_exponential_op.h b/src/operator/numpy/random/np_exponential_op.h
new file mode 100644
index 0000000..6f64429
--- /dev/null
+++ b/src/operator/numpy/random/np_exponential_op.h
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_exponential_op.h
+ * \brief Operator for numpy sampling from exponential distribution.
+ */
+
+#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_
+#define MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_
+
+#include <mxnet/operator_util.h>
+#include <algorithm>
+#include <string>
+#include <vector>
+#include <cmath>
+#include "../../elemwise_op_common.h"
+#include "../../mshadow_op.h"
+#include "../../mxnet_op.h"
+#include "../../operator_common.h"
+#include "../../tensor/elemwise_binary_broadcast_op.h"
+#include "./dist_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct NumpyExponentialParam : public dmlc::Parameter<NumpyExponentialParam> {
+  dmlc::optional<float> scale;  // None when the scale is passed as tensor input "input1"
+  dmlc::optional<mxnet::Tuple<int>> size;
+  DMLC_DECLARE_PARAMETER(NumpyExponentialParam) {
+      DMLC_DECLARE_FIELD(scale).set_default(dmlc::optional<float>(1.0))
+      .describe("Scale (1 / rate), must be >= 0. None means scale is an input tensor.");
+      DMLC_DECLARE_FIELD(size)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Output shape. If the given shape is, "
+          "e.g., (m, n, k), then m * n * k samples are drawn. "
+          "Default is None: a single value for a scalar scale, else the scale's shape.");
+  }
+};
+
+template <typename DType>
+struct scalar_exponential_kernel {  // out[i] = -scale * log(u_i) for one scalar scale
+  MSHADOW_XINLINE static void Map(index_t i, float scale, float *threshold,  // uniforms
+                                  DType *out) {
+    out[i] = -scale * log(threshold[i]);  // -log(U) ~ Exp(1) for U ~ U(0, 1); U == 0 -> inf
+  }
+};
+
+namespace mxnet_op {
+// Writes -1 into flag[0] if any scale value is negative; flag[0] is otherwise left untouched.
+template <typename IType>
+struct check_legal_scale_kernel {
+  MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) {
+    if (scalar[i] < 0.0) {  // NaN compares false, so NaN scales pass this check
+      flag[0] = -1.0;
+    }
+  }
+};
+
+// Tensor-scale sampler: output element i reads its (broadcast) scale at scales[idx].
+template <int ndim, typename IType, typename OType>
+struct exponential_kernel {
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  const Shape<ndim> &stride,
+                                  const Shape<ndim> &oshape,
+                                  IType *scales, float* threshold, OType *out) {
+    Shape<ndim> coord = unravel(i, oshape);
+    auto idx = static_cast<index_t>(dot(coord, stride));  // flat index into scales
+    out[i] =  -scales[idx] * log(threshold[i]);
+  }
+};
+
+}  // namespace mxnet_op
+
+template <typename xpu>
+void NumpyExponentialForward(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  const NumpyExponentialParam &param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  index_t output_len = outputs[0].Size();  // workspace: output_len samples + 1 flag
+  Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
+  Tensor<xpu, 1, float> workspace =
+      ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
+  Tensor<xpu, 1, float> uniform_tensor = workspace.Slice(0, output_len);  // samples
+  Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
+  float indicator_host = 1.0;
+  float *indicator_device_ptr = indicator_device.dptr_;
+  prnd->SampleUniform(&workspace, 0.0, 1.0);  // also writes the flag slot, so...
+  Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);  // ...clear the flag after
+  if (param.scale.has_value()) {  // scalar scale attribute, no tensor input
+    CHECK_GE(param.scale.value(), 0.0) << "ValueError: expect scale >= 0";
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      Kernel<scalar_exponential_kernel<DType>, xpu>::Launch(
+        s, outputs[0].Size(), param.scale.value(),
+        uniform_tensor.dptr_, outputs[0].dptr<DType>());
+    });
+  } else {  // scale tensor in inputs[0], broadcast against the output shape
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
+      Kernel<check_legal_scale_kernel<IType>, xpu>::Launch(
+      s, inputs[0].Size(), inputs[0].dptr<IType>(), indicator_device_ptr);
+    });
+    _copy<xpu>(s, &indicator_host, indicator_device_ptr);  // read the flag on host
+    CHECK_GE(indicator_host, 0.0) << "ValueError: expect scale >= 0";
+    mxnet::TShape new_lshape, new_oshape;
+    int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_,
+                         &new_lshape, &new_lshape, &new_oshape);
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+          Shape<NDim> oshape = new_oshape.get<NDim>();
+          Shape<NDim> stride = calc_stride(new_lshape.get<NDim>());
+          Kernel<exponential_kernel<NDim, IType, OType>, xpu>::Launch(
+              s, outputs[0].Size(), stride, oshape, inputs[0].dptr<IType>(),
+              uniform_tensor.dptr_, outputs[0].dptr<OType>());
+        });
+      });
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_
diff --git a/tests/nightly/test_np_random.py b/tests/nightly/test_np_random.py
index d086ac4..8b60426 100644
--- a/tests/nightly/test_np_random.py
+++ b/tests/nightly/test_np_random.py
@@ -42,6 +42,22 @@ import scipy.stats as ss
 @retry(5)
 @with_seed()
 @use_np
+def test_np_exponential():
+    samples = 1000000
+    # Generation test
+    trials = 8
+    num_buckets = 5
+    for scale in [1.0, 5.0]:
+        buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.expon.ppf(x, scale=scale), num_buckets)
+        # Keep the ppf-derived probabilities; sample with the same scale the buckets use.
+        buckets = np.array(buckets, dtype="float32").tolist()
+        generator_mx_np = lambda x, s=scale: mx.np.random.exponential(s, size=x).asnumpy()
+        verify_generator(generator=generator_mx_np, buckets=buckets, probs=probs, nsamples=samples, nrepeat=trials)
+
+
+@retry(5)
+@with_seed()
+@use_np
 def test_np_uniform():
     types = [None, "float32", "float64"]
     ctx = mx.context.current_context()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 8d54a53..273e520 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3433,6 +3433,36 @@ def test_np_random():
 
 @with_seed()
 @use_np
+def test_np_exponential():
+    class TestRandomExp(HybridBlock):
+        def __init__(self, shape):
+            super(TestRandomExp, self).__init__()
+            self._shape = shape
+
+        def hybrid_forward(self, F, scale):
+            return F.np.random.exponential(scale, self._shape)
+
+    shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None]
+    for hybridize in [False, True]:
+        for shape in shapes:
+            test_exponential = TestRandomExp(shape)
+            if hybridize:
+                test_exponential.hybridize()
+            expected = np.random.exponential(np.array([1]), shape)  # imperative reference
+            mx_out = test_exponential(np.array([1]))
+            assert mx_out.shape == expected.shape
+    for shape in shapes:
+        mx_out = np.random.exponential(np.array([1]), shape)
+        np_out = _np.random.exponential(np.array([1]).asnumpy(), shape)
+        assert shape == () or mx_out.shape == np_out.shape  # MXNet maps size=() to None
+
+    def _test_exponential_exception(scale):
+        np.random.exponential(scale=scale).asnumpy()
+    assertRaises(ValueError, _test_exponential_exception, -1)
+
+
+@with_seed()
+@use_np
 def test_np_randn():
     # Test shapes.
     shapes = [