You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2022/02/05 16:55:02 UTC
[incubator-mxnet] branch master updated: [FEATURE] Add binomial sampling and fix multinomial sampling (#20734)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e9becb9 [FEATURE] Add binomial sampling and fix multinomial sampling (#20734)
e9becb9 is described below
commit e9becb91c82486723d288b572c6721eaa6f2fd21
Author: Vincenzo Territo <47...@users.noreply.github.com>
AuthorDate: Sat Feb 5 17:53:43 2022 +0100
[FEATURE] Add binomial sampling and fix multinomial sampling (#20734)
* implement binomial sampling
* add correct multinomial implementation
* small fix in binomial symbol api doc
* change npx_categorical to npx_multinomial
* small sanity fix
* fix python unit tests
* rename previous multinomial implementation to categorical
* small fix
---
python/mxnet/amp/lists/symbol_fp16.py | 4 +
python/mxnet/ndarray/random.py | 134 ++++++++++++++++--
python/mxnet/symbol/random.py | 82 +++++++++--
src/operator/random/multisample_op.cc | 38 ++++++
src/operator/random/multisample_op.cu | 3 +
src/operator/random/sample_multinomial_op.cc | 80 ++++++++---
src/operator/random/sample_multinomial_op.cu | 9 +-
src/operator/random/sample_multinomial_op.h | 194 ++++++++++++++++++++++++--
src/operator/random/sample_op.cc | 16 +++
src/operator/random/sample_op.cu | 2 +
src/operator/random/sample_op.h | 73 ++++++++++
src/operator/random/sampler.h | 195 +++++++++++++++++++++++++++
tests/python/unittest/test_numpy_op.py | 27 ++++
tests/python/unittest/test_random.py | 79 ++++++++---
14 files changed, 868 insertions(+), 68 deletions(-)
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index a6d8a53..3f3dd48 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -336,6 +336,8 @@ FP16_FP32_FUNCS = [
'_random_exponential_like',
'_random_gamma',
'_random_gamma_like',
+ '_random_binomial',
+ '_random_binomial_like',
'_random_generalized_negative_binomial',
'_random_generalized_negative_binomial_like',
'_random_negative_binomial',
@@ -353,7 +355,9 @@ FP16_FP32_FUNCS = [
'_rnn_param_concat',
'_sample_exponential',
'_sample_gamma',
+ '_sample_binomial',
'_sample_generalized_negative_binomial',
+ '_sample_categorical',
'_sample_multinomial',
'_sample_negative_binomial',
'_sample_normal',
diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index b0683b4..1e7955f 100644
--- a/python/mxnet/ndarray/random.py
+++ b/python/mxnet/ndarray/random.py
@@ -23,8 +23,8 @@ from . import _internal
from .ndarray import NDArray
-__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma',
- 'multinomial', 'negative_binomial', 'generalized_negative_binomial',
+__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'binomial',
+ 'categorical', 'multinomial', 'negative_binomial', 'generalized_negative_binomial',
'shuffle', 'randint']
@@ -383,6 +383,59 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg
[alpha, beta], shape, dtype, ctx, out, kwargs)
+def binomial(n=1, p=0.5, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    """Draw random samples from a binomial distribution.
+
+    Samples are distributed according to a binomial distribution parametrized
+    by *n* (number of trials) and *p* (success probability).
+
+    Parameters
+    ----------
+    n : float or NDArray, optional
+        Number of experiments, > 0.
+    p : float or NDArray, optional
+        Success probability in each experiment, >= 0 and <= 1.
+    shape : int or tuple of ints, optional
+        The number of samples to draw. If shape is, e.g., `(m, k)` and `n` and
+        `p` are scalars, output shape will be `(m, k)`. If `n` and `p`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, k)`, where `m*k` samples are drawn for each `[n, p]` pair.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context. Overridden by
+        `n.context` when `n` is an NDArray.
+    out : NDArray, optional
+        Store output to an existing NDArray.
+
+    Returns
+    -------
+    NDArray
+        If input `shape` is, e.g., `(m, k)` and `n` and `p` are scalars, output
+        shape will be `(m, k)`. If `n` and `p` are NDArrays with shape, e.g.,
+        `(x, y)`, then output will have shape `(x, y, m, k)`, where `m*k` samples are
+        drawn for each `[n, p]` pair.
+
+    Examples
+    --------
+    >>> mx.nd.random.binomial(10, 0.1)
+    [ 1.]
+    <NDArray 1 @cpu(0)>
+    >>> mx.nd.random.binomial(10, 0.6, shape=(2,))
+    [ 4. 6.]
+    <NDArray 2 @cpu(0)>
+    >>> n = mx.nd.array([10,2,3])
+    >>> p = mx.nd.array([0.2,0.3,0.4])
+    >>> mx.nd.random.binomial(n, p, shape=2)
+    [[ 1. 4.]
+     [ 0. 2.]
+     [ 1. 1.]]
+    <NDArray 3x2 @cpu(0)>
+    """
+    return _random_helper(_internal._random_binomial, _internal._sample_binomial,
+                          [n, p], shape, dtype, ctx, out, kwargs)
+
+
def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, ctx=None,
out=None, **kwargs):
"""Draw random samples from a negative binomial distribution.
@@ -496,9 +549,8 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N
_internal._sample_generalized_negative_binomial,
[mu, alpha], shape, dtype, ctx, out, kwargs)
-
-def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
- """Concurrent sampling from multiple multinomial distributions.
+def categorical(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
+ """Concurrent sampling from multiple categorical distributions.
.. note:: The input distribution must be normalized, i.e. `data` must sum to
1 along its last dimension.
@@ -507,8 +559,8 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
----------
data : NDArray
An *n* dimensional array whose last dimension has length `k`, where
- `k` is the number of possible outcomes of each multinomial distribution.
- For example, data with shape `(m, n, k)` specifies `m*n` multinomial
+ `k` is the number of possible outcomes of each categorical distribution.
+ For example, data with shape `(m, n, k)` specifies `m*n` categorical
distributions each with `k` possible outcomes.
shape : int or tuple of ints, optional
The number of samples to draw from each distribution. If shape is empty
@@ -530,7 +582,7 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
For input `data` with `n` dimensions and shape `(d1, d2, ..., dn-1, k)`, and input
`shape` with shape `(s1, s2, ..., sx)`, returns an NDArray with shape
`(d1, d2, ... dn-1, s1, s2, ..., sx)`. The `s1, s2, ... sx` dimensions of the
- returned NDArray consist of 0-indexed values sampled from each respective multinomial
+ returned NDArray consist of 0-indexed values sampled from each respective categorical
distribution provided in the `k` dimension of `data`.
For the case `n`=1, and `x`=1 (one shape dimension), returned NDArray has shape `(s1,)`.
@@ -542,24 +594,80 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
Examples
--------
>>> probs = mx.nd.array([0, 0.1, 0.2, 0.3, 0.4])
- >>> mx.nd.random.multinomial(probs)
+ >>> mx.nd.random.categorical(probs)
[3]
<NDArray 1 @cpu(0)>
>>> probs = mx.nd.array([[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]])
- >>> mx.nd.random.multinomial(probs)
+ >>> mx.nd.random.categorical(probs)
[3 1]
<NDArray 2 @cpu(0)>
- >>> mx.nd.random.multinomial(probs, shape=2)
+ >>> mx.nd.random.categorical(probs, shape=2)
[[4 4]
[1 2]]
<NDArray 2x2 @cpu(0)>
- >>> mx.nd.random.multinomial(probs, get_prob=True)
+ >>> mx.nd.random.categorical(probs, get_prob=True)
[3 2]
<NDArray 2 @cpu(0)>
[-1.20397282 -1.60943794]
<NDArray 2 @cpu(0)>
"""
- return _internal._sample_multinomial(data, shape, get_prob, out=out, dtype=dtype, **kwargs)
+ return _internal._sample_categorical(data, shape, get_prob, out=out, dtype=dtype, **kwargs)
+
+
+def multinomial(n=[1], p=[[1.0]], shape=_Null, dtype='float32', ctx=None, out=None, **kwargs):
+    """Concurrent sampling from multiple multinomial distributions.
+
+    .. note:: The input distribution must be normalized, i.e. `p` must sum to
+              1 along its last dimension.
+
+    Parameters
+    ----------
+    n : NDArray
+        An array of arbitrary dimension *d* containing the number of trials of each
+        multinomial distribution.
+    p : NDArray
+        A *(d+1)*-dimensional array containing the probabilities of each multinomial
+        distribution. Its last dimension has length `k`, where `k` is the number
+        of possible outcomes of each multinomial distribution.
+        For example, p with shape `(x, y, k)` specifies `x*y` multinomial
+        distributions each with `k` possible outcomes.
+    shape : int or tuple of ints, optional
+        The number of samples to draw from each distribution. If shape is empty
+        one sample will be drawn from each distribution.
+    out : NDArray, optional
+        Store output to an existing NDArray.
+    ctx : Context, optional
+        Device context of output. Default is current context. Overridden by
+        `n.context` when `n` is an NDArray.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+
+    Returns
+    -------
+    NDArray
+        If input `shape` is, e.g., `(m, s)` and `n` and `p` are a scalar and an array of length k
+        respectively, output shape will be `(m, s, k)`. If `n` and `p` are NDArrays with shape, e.g.,
+        `(x, y)` and `(x, y, k)`, then output will have shape `(x, y, m, s, k)`, where `m*s`
+        samples are drawn for each `[n, p]` pair.
+
+    Examples
+    --------
+    >>> mx.nd.random.multinomial(mx.nd.array([10]), mx.nd.array([[0.1, 0.9]]))
+    [[ 1. 9.]]
+    <NDArray 1x2 @cpu(0)>
+    >>> mx.nd.random.multinomial(mx.nd.array([10]), mx.nd.array([[0.6, 0.4]]), shape=(2,))
+    [[[ 5. 5.]
+      [ 6. 4.]]]
+    <NDArray 1x2x2 @cpu(0)>
+    >>> n = mx.nd.array([10, 2, 3])
+    >>> p = mx.nd.array([[0.2, 0.8], [0.3, 0.7], [0.4, 0.6]])
+    >>> mx.nd.random.multinomial(n, p)
+    [[ 2. 8.]
+     [ 1. 1.]
+     [ 1. 2.]]
+    <NDArray 3x2 @cpu(0)>
+    """
+    return _internal._sample_multinomial(n, p, shape=shape, out=out, ctx=ctx, dtype=dtype, **kwargs)
def shuffle(data, **kwargs):
diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py
index b2ff104..827ec40 100644
--- a/python/mxnet/symbol/random.py
+++ b/python/mxnet/symbol/random.py
@@ -22,8 +22,8 @@ from . import _internal
from .symbol import Symbol
-__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'multinomial',
- 'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint']
+__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'categorical', 'multinomial',
+ 'binomial', 'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint']
def _random_helper(random, sampler, params, shape, dtype, kwargs):
@@ -240,6 +240,38 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, **kwargs):
[alpha, beta], shape, dtype, kwargs)
+def binomial(n=1, p=0.5, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a binomial distribution.
+
+    Samples are distributed according to a binomial distribution parametrized
+    by *n* (number of trials) and *p* (success probability).
+
+    Parameters
+    ----------
+    n : float or Symbol, optional
+        Number of experiments, > 0.
+    p : float or Symbol, optional
+        Success probability in each experiment, >= 0 and <= 1.
+    shape : int or tuple of ints, optional
+        The number of samples to draw. If shape is, e.g., `(m, k)` and `n` and
+        `p` are scalars, output shape will be `(m, k)`. If `n` and `p`
+        are Symbols with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, k)`, where `m*k` samples are drawn for each `[n, p]` pair.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+
+    Returns
+    -------
+    Symbol
+        If input `shape` is, e.g., `(m, k)` and `n` and `p` are scalars, output
+        shape will be `(m, k)`. If `n` and `p` are Symbols with shape, e.g.,
+        `(x, y)`, then output will have shape `(x, y, m, k)`, where `m*k` samples are
+        drawn for each `[n, p]` pair.
+    """
+    return _random_helper(_internal._random_binomial, _internal._sample_binomial,
+                          [n, p], shape, dtype, kwargs)
+
+
def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, **kwargs):
"""Draw random samples from a negative binomial distribution.
@@ -311,8 +343,8 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa
[mu, alpha], shape, dtype, kwargs)
-def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
- """Concurrent sampling from multiple multinomial distributions.
+def categorical(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
+ """Concurrent sampling from multiple categorical distributions.
.. note:: The input distribution must be normalized, i.e. `data` must sum to
1 along its last dimension.
@@ -321,8 +353,8 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
----------
data : Symbol
An *n* dimensional array whose last dimension has length `k`, where
- `k` is the number of possible outcomes of each multinomial distribution.
- For example, data with shape `(m, n, k)` specifies `m*n` multinomial
+ `k` is the number of possible outcomes of each categorical distribution.
+ For example, data with shape `(m, n, k)` specifies `m*n` categorical
distributions each with `k` possible outcomes.
shape : int or tuple of ints, optional
The number of samples to draw from each distribution. If shape is empty
@@ -343,7 +375,7 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
`shape` with shape `(s1, s2, ..., sx)`, returns a Symbol that resovles to shape
`(d1, d2, ... dn-1, s1, s2, ..., sx)`. The `s1, s2, ... sx` dimensions of the
returned Symbol's resolved value will consist of 0-indexed values sampled from each
- respective multinomial distribution provided in the `k` dimension of `data`.
+ respective categorical distribution provided in the `k` dimension of `data`.
For the case `n`=1, and `x`=1 (one shape dimension), returned Symbol will resolve to
shape `(s1,)`.
@@ -352,7 +384,41 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
outputs: `[ndarray_output, log_likelihood_output]`, where `log_likelihood_output` will resolve
to the same shape as the sampled outputs in ndarray_output.
"""
- return _internal._sample_multinomial(data, shape, get_prob, dtype=dtype, **kwargs)
+ return _internal._sample_categorical(data, shape, get_prob, dtype=dtype, **kwargs)
+
+
+def multinomial(n=[1], p=[[1.0]], shape=_Null, dtype='float32', **kwargs):
+    """Concurrent sampling from multiple multinomial distributions.
+
+    .. note:: The input distribution must be normalized, i.e. `p` must sum to
+              1 along its last dimension.
+
+    Parameters
+    ----------
+    n : Symbol
+        An array of arbitrary dimension *d* containing the number of trials of each
+        multinomial distribution.
+    p : Symbol
+        A *(d+1)*-dimensional array containing the probabilities of each multinomial
+        distribution. Its last dimension has length `k`, where `k` is the number
+        of possible outcomes of each multinomial distribution.
+        For example, p with shape `(x, y, k)` specifies `x*y` multinomial
+        distributions each with `k` possible outcomes.
+    shape : int or tuple of ints, optional
+        The number of samples to draw from each distribution. If shape is empty
+        one sample will be drawn from each distribution.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+
+    Returns
+    -------
+    Symbol
+        If input `shape` is, e.g., `(m, s)` and `n` and `p` are a scalar and an array of length k
+        respectively, output shape will be `(m, s, k)`. If `n` and `p` are Symbols with shape, e.g.,
+        `(x, y)` and `(x, y, k)`, then output will have shape `(x, y, m, s, k)`, where `m*s`
+        samples are drawn for each `[n, p]` pair.
+    """
+    return _internal._sample_multinomial(n, p, shape, dtype=dtype, **kwargs)
def shuffle(data, **kwargs):
diff --git a/src/operator/random/multisample_op.cc b/src/operator/random/multisample_op.cc
index b7f2214..7b0a02d 100644
--- a/src/operator/random/multisample_op.cc
+++ b/src/operator/random/multisample_op.cc
@@ -218,6 +218,37 @@ Examples::
)code");
}
+inline std::string binomial_desc() {
+ return std::string(R"code(Concurrent sampling from multiple
+binomial distributions with parameters *n* (number of trials) and *p* (success probability).
+
+The parameters of the distributions are provided as input arrays.
+Let *[s]* be the shape of the input arrays, *n* be the dimension of *[s]*, *[t]*
+be the shape specified as the parameter of the operator, and *m* be the dimension
+of *[t]*. Then the output will be a *(n+m)*-dimensional array with shape *[s]x[t]*.
+
+For any valid *n*-dimensional index *i* with respect to the input arrays, *output[i]*
+will be an *m*-dimensional array that holds randomly drawn samples from the distribution
+which is parameterized by the input values at index *i*. If the shape parameter of the
+operator is not set, then one sample will be drawn per distribution and the output array
+has the same shape as the input arrays.
+
+Samples will always be returned as a floating point data type.
+
+Examples::
+
+ n = [ 20, 49 ]
+ p = [ 0.4 , 0.77 ]
+
+ // Draw a single sample for each distribution
+ sample_binomial(n, p) = [ 5., 36.]
+
+ // Draw a vector containing two samples for each distribution
+ sample_binomial(n, p, shape=(2)) = [[ 5., 40.],
+ [ 11., 35.]]
+)code");
+}
+
inline std::string negative_binomial_desc() {
return std::string(R"code(Concurrent sampling from multiple
negative binomial distributions with parameters *k* (failure limit) and *p* (failure probability).
@@ -312,6 +343,13 @@ MXNET_OPERATOR_REGISTER_SAMPLING1(poisson,
"Lambda (rate) parameters of the distributions.",
poisson_desc)
.add_alias("_npx_tensor_poisson");
+MXNET_OPERATOR_REGISTER_SAMPLING2(binomial,
+ BinomialSampler<cpu>,
+ "n",
+ "p",
+ "Number of experiments.",
+ "Success probabilities in each experiment.",
+ binomial_desc);
MXNET_OPERATOR_REGISTER_SAMPLING2(negative_binomial,
NegativeBinomialSampler<cpu>,
"k",
diff --git a/src/operator/random/multisample_op.cu b/src/operator/random/multisample_op.cu
index 4c571e7..e7756c6 100644
--- a/src/operator/random/multisample_op.cu
+++ b/src/operator/random/multisample_op.cu
@@ -42,6 +42,9 @@ NNVM_REGISTER_OP(_sample_exponential)
NNVM_REGISTER_OP(_sample_poisson)
.set_attr<FCompute>("FCompute<gpu>", MultiSampleOpForward<gpu, PoissonSampler<gpu>, 1>);
+NNVM_REGISTER_OP(_sample_binomial)
+ .set_attr<FCompute>("FCompute<gpu>", MultiSampleOpForward<gpu, BinomialSampler<gpu>, 2>);
+
NNVM_REGISTER_OP(_sample_negative_binomial)
.set_attr<FCompute>("FCompute<gpu>",
MultiSampleOpForward<gpu, NegativeBinomialSampler<gpu>, 2>);
diff --git a/src/operator/random/sample_multinomial_op.cc b/src/operator/random/sample_multinomial_op.cc
index 01d66e2..5b3edaa 100644
--- a/src/operator/random/sample_multinomial_op.cc
+++ b/src/operator/random/sample_multinomial_op.cc
@@ -26,15 +26,16 @@
namespace mxnet {
namespace op {
+DMLC_REGISTER_PARAMETER(SampleCategoricalParam);
DMLC_REGISTER_PARAMETER(SampleMultinomialParam);
-NNVM_REGISTER_OP(_sample_multinomial)
- .add_alias("sample_multinomial")
+NNVM_REGISTER_OP(_sample_categorical)
+ .add_alias("sample_categorical")
.add_alias("_npx__random_categorical")
- .describe(R"code(Concurrent sampling from multiple multinomial distributions.
+ .describe(R"code(Concurrent sampling from multiple categorical distributions.
*data* is an *n* dimensional array whose last dimension has length *k*, where
-*k* is the number of possible outcomes of each multinomial distribution. This
+*k* is the number of possible outcomes of each categorical distribution. This
operator will draw *shape* samples from each distribution. If shape is empty
one sample will be drawn from each distribution.
@@ -51,23 +52,23 @@ Examples::
probs = [[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]]
// Draw a single sample for each distribution
- sample_multinomial(probs) = [3, 0]
+ sample_categorical(probs) = [3, 0]
// Draw a vector containing two samples for each distribution
- sample_multinomial(probs, shape=(2)) = [[4, 2],
+ sample_categorical(probs, shape=(2)) = [[4, 2],
[0, 0]]
// requests log likelihood
- sample_multinomial(probs, get_prob=True) = [2, 1], [0.2, 0.3]
+ sample_categorical(probs, get_prob=True) = [2, 1], [0.2, 0.3]
)code")
.set_num_inputs(1)
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
- const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+ const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
return param.get_prob ? 2U : 1U;
})
- .set_attr_parser(ParamParser<SampleMultinomialParam>)
- .set_attr<mxnet::FInferShape>("FInferShape", SampleMultinomialOpShape)
- .set_attr<nnvm::FInferType>("FInferType", SampleMultinomialOpType)
+ .set_attr_parser(ParamParser<SampleCategoricalParam>)
+ .set_attr<mxnet::FInferShape>("FInferShape", SampleCategoricalOpShape)
+ .set_attr<nnvm::FInferType>("FInferType", SampleCategoricalOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kRandom,
@@ -76,9 +77,9 @@ Examples::
.set_attr<nnvm::FGradient>(
"FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
- const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(n->attrs.parsed);
+ const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(n->attrs.parsed);
if (param.get_prob) {
- return MakeGradNode("_backward_sample_multinomial",
+ return MakeGradNode("_backward_sample_categorical",
n,
{ograds[1], n->inputs[0], nnvm::NodeEntry{n, 0, 0}},
std::unordered_map<std::string, std::string>());
@@ -86,13 +87,58 @@ Examples::
return MakeZeroGradNodes(n, ograds);
}
})
- .set_attr<FCompute>("FCompute<cpu>", SampleMultinomialForward<cpu>)
+ .set_attr<FCompute>("FCompute<cpu>", SampleCategoricalForward<cpu>)
.add_argument("data",
"NDArray-or-Symbol",
"Distribution probabilities. Must sum to one on the last axis.")
+ .add_arguments(SampleCategoricalParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_sample_multinomial)
+ .add_alias("sample_multinomial")
+ .add_alias("_npx__random_multinomial")
+ .describe(R"code(Concurrent sampling from multiple multinomial distributions.
+
+Samples are distributed according to a multinomial distribution parametrized by
+*n* (number of experiments) and *p* (success probabilities of the k possible outcomes
+in each experiment). Samples will always be returned as a floating point data type.
+
+Note that the input distribution must be normalized, i.e. *p* must sum to
+1 along its last axis.
+
+Examples::
+
+ n = [5., 6.]
+ probs = [[0., 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.]]
+
+ multinomial(n, probs) = [[0., 0., 0., 3., 2.],
+ [0., 3., 1., 2., 0.]]
+)code")
+ .set_num_inputs(2)
+ .set_num_outputs(1)
+ .set_attr_parser(ParamParser<SampleMultinomialParam>)
+ .set_attr<mxnet::FInferShape>("FInferShape", SampleMultinomialOpShape)
+ .set_attr<nnvm::FInferType>("FInferType", SampleMultinomialOpType)
+ .set_attr<FResourceRequest>("FResourceRequest",
+ [](const nnvm::NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{
+ ResourceRequest::kParallelRandom,
+ ResourceRequest::kTempSpace};
+ })
+ .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+ .set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ std::vector<std::string> v = {"n", "p"};
+ v.resize(2);
+ return v;
+ })
+ .set_attr<FCompute>("FCompute<cpu>", SampleMultinomialForward<cpu>)
+ .add_argument("n", "NDArray-or-Symbol", "Number of experiments")
+ .add_argument("p",
+ "NDArray-or-Symbol",
+ "Probability of every outcome in each experiment. Must sum to 1 on the last axis")
.add_arguments(SampleMultinomialParam::__FIELDS__());
-struct SampleMultinomialBackwardCPUKernel {
+struct SampleCategoricalBackwardCPUKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void
Map(int i, index_t K, index_t M, DType* ograd, DType* dist, IType* out, DType* igrad) {
@@ -103,12 +149,12 @@ struct SampleMultinomialBackwardCPUKernel {
}
};
-NNVM_REGISTER_OP(_backward_sample_multinomial)
+NNVM_REGISTER_OP(_backward_sample_categorical)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>",
- SampleMultinomialBackward<SampleMultinomialBackwardCPUKernel, cpu>);
+ SampleCategoricalBackward<SampleCategoricalBackwardCPUKernel, cpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/random/sample_multinomial_op.cu b/src/operator/random/sample_multinomial_op.cu
index 79b45ff..3123d7b 100644
--- a/src/operator/random/sample_multinomial_op.cu
+++ b/src/operator/random/sample_multinomial_op.cu
@@ -26,10 +26,13 @@
namespace mxnet {
namespace op {
+NNVM_REGISTER_OP(_sample_categorical)
+ .set_attr<FCompute>("FCompute<gpu>", SampleCategoricalForward<gpu>);
+
NNVM_REGISTER_OP(_sample_multinomial)
.set_attr<FCompute>("FCompute<gpu>", SampleMultinomialForward<gpu>);
-struct SampleMultinomialBackwardGPUKernel {
+struct SampleCategoricalBackwardGPUKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void
Map(int i, index_t K, index_t M, DType* ograd, DType* dist, IType* out, DType* igrad) {
@@ -40,9 +43,9 @@ struct SampleMultinomialBackwardGPUKernel {
}
};
-NNVM_REGISTER_OP(_backward_sample_multinomial)
+NNVM_REGISTER_OP(_backward_sample_categorical)
.set_attr<FCompute>("FCompute<gpu>",
- SampleMultinomialBackward<SampleMultinomialBackwardGPUKernel, gpu>);
+ SampleCategoricalBackward<SampleCategoricalBackwardGPUKernel, gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h
index 9cc1de6..3346ae4 100644
--- a/src/operator/random/sample_multinomial_op.h
+++ b/src/operator/random/sample_multinomial_op.h
@@ -26,19 +26,21 @@
#include <mxnet/operator_util.h>
#include <vector>
+#include <string>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
+#include "./sampler.h"
namespace mxnet {
namespace op {
-struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
+struct SampleCategoricalParam : public dmlc::Parameter<SampleCategoricalParam> {
mxnet::TShape shape;
bool get_prob;
int dtype;
- DMLC_DECLARE_PARAMETER(SampleMultinomialParam) {
+ DMLC_DECLARE_PARAMETER(SampleCategoricalParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(mxnet::TShape(0, 1))
.describe("Shape to be sampled from each random distribution.");
@@ -57,10 +59,32 @@ struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
}
};
-inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
+struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
+ mxnet::TShape shape;
+ std::string ctx;
+ int dtype;
+ DMLC_DECLARE_PARAMETER(SampleMultinomialParam) {
+ DMLC_DECLARE_FIELD(shape)
+ .set_default(mxnet::TShape(0, 1))
+ .describe("Shape to be sampled from each random distribution.");
+ DMLC_DECLARE_FIELD(ctx).set_default("").describe(
+ "Context of output, in format [cpu|gpu|cpu_pinned](n)."
+ " Only used for imperative calls.");
+ DMLC_DECLARE_FIELD(dtype)
+ .add_enum("uint8", mshadow::kUint8)
+ .add_enum("int32", mshadow::kInt32)
+ .add_enum("float16", mshadow::kFloat16)
+ .add_enum("float32", mshadow::kFloat32)
+ .add_enum("float64", mshadow::kFloat64)
+ .set_default(mshadow::kInt32)
+ .describe("DType of the output in case this can't be inferred.");
+ }
+};
+
+inline bool SampleCategoricalOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
- const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+ const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
@@ -98,10 +122,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
return true;
}
-inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
+inline bool SampleCategoricalOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
- const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+ const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
@@ -116,7 +140,65 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
return true;
}
-struct SampleMultinomialKernel {
+inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,  // output: n.shape + shape + (k,)
+                                     mxnet::ShapeVector* in_attrs,
+                                     mxnet::ShapeVector* out_attrs) {
+  const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  const mxnet::TShape& n_shape = (*in_attrs)[0];  // batch shape of trial counts
+  const mxnet::TShape& p_shape = (*in_attrs)[1];  // batch shape followed by k outcomes
+  if (!ndim_is_known(n_shape) || !ndim_is_known(p_shape) || n_shape.ndim() + 1 != p_shape.ndim())
+    return false;  // NOTE(review): a known-but-wrong rank also returns false, not a shape error
+
+  mxnet::TShape oshape(p_shape.ndim() + param.shape.ndim(), -1);
+  for (int i = 0; i < p_shape.ndim() - 1; ++i) {  // copy the shared batch dims
+    if (n_shape[i] != p_shape[i])
+      return false;  // NOTE(review): batch-dim mismatch is not reported as an error
+    oshape[i] = p_shape[i];
+  }
+  for (int i = 0; i < param.shape.ndim(); ++i) {  // then the requested sample shape
+    oshape[i + p_shape.ndim() - 1] = param.shape[i];
+  }
+  oshape[p_shape.ndim() + param.shape.ndim() - 1] = p_shape[p_shape.ndim() - 1];  // last axis: k
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+
+  return true;
+}
+
+inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,  // infers the output dtype only
+                                    std::vector<int>* in_attrs,
+                                    std::vector<int>* out_attrs) {
+  const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  int dtype = -1;
+  int dtype_n = (*in_attrs)[0];  // dtype of n; p's dtype (in_attrs[1]) is not read here
+  int dtype_out = (*out_attrs)[0];
+
+  if (dtype_out != -1) {  // output dtype already known
+    dtype = dtype_out;
+    if (param.dtype != -1) {  // NOTE(review): param.dtype defaults to kInt32 and has no -1 enum, so always true
+      CHECK_EQ(dtype_out, param.dtype)
+          << "Output type does not match requested type: " << dtype_out << " vs " << param.dtype;
+    }
+  } else {
+    if (dtype_n != -1) {  // NOTE(review): param.dtype is not consulted on this path — confirm intended
+      dtype = dtype_n;
+    } else {
+      dtype = mxnet::common::GetDefaultDtype();
+    }
+  }
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, dtype);
+
+  return true;
+}
+
+struct SampleCategoricalKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i,
index_t K,
@@ -152,15 +234,60 @@ struct SampleMultinomialKernel {
}
};
+template <typename xpu, typename NType, typename PType, typename OType>
+MSHADOW_XINLINE void SampleMultinomial(NType N,  // draws one k-outcome sample into out[0..K-1]
+                                       const PType* p,
+                                       OType* out,
+                                       index_t K,
+                                       typename RandGenerator<xpu, float>::Impl* gen) {
+  PType remaining_p = 1.0;  // probability mass not yet consumed by outcomes before j
+  NType dN = N;  // trials not yet assigned to an outcome
+
+  int j;
+  for (j = 0; j < K - 1; ++j) {  // sequential conditional binomials over the first K-1 outcomes
+    out[j] = SampleBinomial<xpu, PType, OType>(static_cast<PType>(dN), p[j] / remaining_p, gen);
+    dN = dN - out[j];
+
+    if (dN <= 0)  // every trial assigned; later outcomes get 0 below
+      break;
+    remaining_p -= p[j];  // NOTE(review): accumulated rounding may push p[j] / remaining_p above 1
+  }
+  for (j = j + 1; j < K; ++j)  // zero outcomes after the break point (no-op if loop completed)
+    out[j] = 0;
+  if (dN > 0)  // leftover trials go to the last outcome
+    out[K - 1] = dN;  // NOTE(review): when K == 1 and N <= 0, out[0] is never written
+}
+
template <typename xpu>
-void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
+struct SampleMultinomialKernel {
+ template <typename NType, typename PType, typename OType>
+ MSHADOW_XINLINE static void Map(index_t id,
+ RandGenerator<xpu, float> gen,
+ const index_t N,
+ const index_t step,
+ index_t nParm,
+ index_t nSample,
+ index_t K,
+ const NType* n,
+ const PType* p,
+ OType* out) {
+ RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+ index_t nBatch(1 + (nSample - 1) / nParm);
+ SampleMultinomial<xpu, NType, PType, OType>(
+ n[i / nBatch], &p[(i / nBatch) * K], &out[i * K], K, &genImpl);
+ })
+ }
+};
+
+template <typename xpu>
+void SampleCategoricalForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
- const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+ const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
index_t K = inputs[0].shape_[inputs[0].ndim() - 1];
index_t N = inputs[0].Size() / K;
@@ -174,7 +301,7 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, float> uniform(workspace.dptr_, Shape1(N * M));
prnd->SampleUniform(&uniform, 0, 1);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
- Kernel<SampleMultinomialKernel, xpu>::Launch(
+ Kernel<SampleCategoricalKernel, xpu>::Launch(
s,
N,
K,
@@ -188,8 +315,53 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
});
}
+/*!
+ * \brief Fill *outputs with multinomial counts drawn from (*num, *prob).
+ *
+ * The last axis of *prob holds the K category probabilities. The flattened
+ * output is treated as out.size(0) / K count vectors of length K.
+ */
+template <typename xpu>
+static inline void multinomial_op(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const OpReqType& req,
+                                  TBlob* num,
+                                  TBlob* prob,
+                                  TBlob* outputs) {
+  // Nothing requested or nothing to draw. Skipping here also keeps the
+  // divisions by K (host) and by n.size(0) (kernel) away from empty inputs.
+  if (req == kNullOp || outputs->Size() == 0)
+    return;
+  CHECK_GT(prob->ndim(), 0) << "p in multinomial distribution must have at least one dimension";
+  // K = number of categories = length of the last axis of p.
+  const index_t K = prob->shape_[prob->ndim() - 1];
+  CHECK_GT(K, 0) << "p in multinomial distribution must have at least one category";
+  CHECK_GT(num->Size(), 0U) << "n in multinomial distribution must not be empty";
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(num[0].type_flag_, NType, {
+    MSHADOW_REAL_TYPE_SWITCH(prob[0].type_flag_, PType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        // SampleMultinomialKernel consumes a float generator whatever OType is.
+        RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
+        RandGenerator<xpu, float>* gen = reinterpret_cast<RandGenerator<xpu, float>*>(pgen);
+
+        Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
+        Tensor<xpu, 1, NType> n = num->FlatTo1D<xpu, NType>(s);
+        Tensor<xpu, 1, PType> p = prob->FlatTo1D<xpu, PType>(s);
+        // One kernel item per count vector of length K.
+        const index_t num_draws = out.size(0) / K;
+
+        LaunchRNG<SampleMultinomialKernel<xpu>, xpu>(
+            s, gen, num_draws, n.size(0), num_draws, K, n.dptr_, p.dptr_, out.dptr_);
+      })
+    })
+  });
+}
+
+/*!
+ * \brief Forward pass of the multinomial sampling op.
+ *
+ * inputs[0] holds the trial counts n and inputs[1] the probabilities p.
+ * outputs[0] receives the counts. multinomial_op takes mutable TBlob
+ * pointers, so it is handed local copies of the three blobs.
+ */
+template <typename xpu>
+void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  TBlob trials(inputs[0]);
+  TBlob probabilities(inputs[1]);
+  TBlob samples(outputs[0]);
+  multinomial_op<xpu>(attrs, ctx, req[0], &trials, &probabilities, &samples);
+}
+
template <typename kernel, typename xpu>
-void SampleMultinomialBackward(const nnvm::NodeAttrs& attrs,
+void SampleCategoricalBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc
index 02d91f2..8d9f5d8 100644
--- a/src/operator/random/sample_op.cc
+++ b/src/operator/random/sample_op.cc
@@ -33,6 +33,7 @@ DMLC_REGISTER_PARAMETER(SampleNormalParam);
DMLC_REGISTER_PARAMETER(SampleGammaParam);
DMLC_REGISTER_PARAMETER(SampleExponentialParam);
DMLC_REGISTER_PARAMETER(SamplePoissonParam);
+// Parameter struct of _random_binomial (n, p, shape, ctx, dtype).
+DMLC_REGISTER_PARAMETER(SampleBinomialParam);
DMLC_REGISTER_PARAMETER(SampleNegBinomialParam);
DMLC_REGISTER_PARAMETER(SampleGenNegBinomialParam);
DMLC_REGISTER_PARAMETER(SampleRandIntParam);
@@ -42,6 +43,7 @@ DMLC_REGISTER_PARAMETER(SampleNormalLikeParam);
DMLC_REGISTER_PARAMETER(SampleGammaLikeParam);
DMLC_REGISTER_PARAMETER(SampleExponentialLikeParam);
DMLC_REGISTER_PARAMETER(SamplePoissonLikeParam);
+// Parameter struct of _random_binomial_like (n, p).
+DMLC_REGISTER_PARAMETER(SampleBinomialLikeParam);
DMLC_REGISTER_PARAMETER(SampleNegBinomialLikeParam);
DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam);
@@ -149,6 +151,20 @@ Example::
[ 4., 6.]]
)code" ADD_FILELINE);
+// CPU registration of _random_binomial (alias random_binomial), configured by
+// SampleBinomialParam. The GPU counterpart is registered in sample_op.cu.
+MXNET_OPERATOR_REGISTER_SAMPLE(_random_binomial, SampleBinomialParam)
+ .add_alias("random_binomial")
+ .describe(R"code(Draw random samples from a binomial distribution.
+
+Samples are distributed according to a binomial distribution parametrized by
+*n* (number of experiments) and *p* (success probability in each experiment).
+Samples will always be returned as a floating point data type.
+
+Example::
+
+ binomial(n=3, p=0.4, shape=(2,2)) = [[ 1., 0.],
+ [ 1., 2.]]
+)code" ADD_FILELINE);
+
MXNET_OPERATOR_REGISTER_SAMPLE(_random_negative_binomial, SampleNegBinomialParam)
.add_alias("random_negative_binomial")
.describe(R"code(Draw random samples from a negative binomial distribution.
diff --git a/src/operator/random/sample_op.cu b/src/operator/random/sample_op.cu
index 8ffc633..1fc6239 100644
--- a/src/operator/random/sample_op.cu
+++ b/src/operator/random/sample_op.cu
@@ -36,6 +36,7 @@ MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_normal, SampleNormalParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_gamma, SampleGammaParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_exponential, SampleExponentialParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_poisson, SamplePoissonParam)
+// GPU registration of _random_binomial (CPU registration lives in sample_op.cc).
+MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_binomial, SampleBinomialParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_negative_binomial, SampleNegBinomialParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_generalized_negative_binomial, SampleGenNegBinomialParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_randint, SampleRandIntParam)
@@ -44,6 +45,7 @@ MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_normal_like, SampleNormalLikeParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_gamma_like, SampleGammaLikeParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_exponential_like, SampleExponentialLikeParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_poisson_like, SamplePoissonLikeParam)
+// GPU registration of _random_binomial_like.
+MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_binomial_like, SampleBinomialLikeParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_negative_binomial_like, SampleNegBinomialLikeParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_generalized_negative_binomial_like,
SampleGenNegBinomialLikeParam)
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index 05d4363..cfff87c 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -67,6 +67,11 @@ struct PoissonParam {
float lam;
};
+// Distribution parameters shared by SampleBinomialParam and
+// SampleBinomialLikeParam.
+struct BinomialParam {
+ int n;    // number of experiments (trials)
+ float p;  // success probability of each experiment
+};
+
struct NegBinomialParam {
int k;
float p;
@@ -190,6 +195,28 @@ struct SamplePoissonParam : public dmlc::Parameter<SamplePoissonParam>,
}
};
+// Parameters of _random_binomial: n and p (inherited from BinomialParam) plus
+// the output shape, context and dtype declared below.
+// NOTE(review): no bounds are declared for n or p here; binomial_op checks
+// them at run time.
+struct SampleBinomialParam : public dmlc::Parameter<SampleBinomialParam>,
+ BinomialParam,
+ SampleOpParam {
+ DMLC_DECLARE_PARAMETER(SampleBinomialParam) {
+ DMLC_DECLARE_FIELD(n).set_default(1).describe("number of experiments.");
+ DMLC_DECLARE_FIELD(p).set_default(1.0f).describe("success probability in each experiment.");
+ DMLC_DECLARE_FIELD(shape).set_default(mxnet::TShape()).describe("Shape of the output.");
+ DMLC_DECLARE_FIELD(ctx).set_default("").describe(
+ "Context of output, in format [cpu|gpu|cpu_pinned](n)."
+ " Only used for imperative calls.");
+ // dtype -1 ("None") leaves the output type to inference.
+ DMLC_DECLARE_FIELD(dtype)
+ .add_enum("None", -1)
+ .add_enum("float32", mshadow::kFloat32)
+ .add_enum("float64", mshadow::kFloat64)
+ .add_enum("float16", mshadow::kFloat16)
+ .set_default(-1)
+ .describe(
+ "DType of the output in case this can't be inferred. "
+ "Defaults to float32 if not defined (dtype=None).");
+ }
+};
+
struct SampleNegBinomialParam : public dmlc::Parameter<SampleNegBinomialParam>,
NegBinomialParam,
SampleOpParam {
@@ -306,6 +333,13 @@ struct SamplePoissonLikeParam : public dmlc::Parameter<SamplePoissonLikeParam>,
}
};
+// Parameters of _random_binomial_like: only n and p. Unlike SampleBinomialParam
+// there is no shape/ctx/dtype field (presumably taken from the "like" input --
+// TODO confirm against the op registration).
+struct SampleBinomialLikeParam : public dmlc::Parameter<SampleBinomialLikeParam>, BinomialParam {
+ DMLC_DECLARE_PARAMETER(SampleBinomialLikeParam) {
+ DMLC_DECLARE_FIELD(n).set_default(1).describe("Number of experiments.");
+ DMLC_DECLARE_FIELD(p).set_default(1.0f).describe("success probability in each experiment.");
+ }
+};
+
struct SampleNegBinomialLikeParam : public dmlc::Parameter<SampleNegBinomialLikeParam>,
NegBinomialParam {
DMLC_DECLARE_PARAMETER(SampleNegBinomialLikeParam) {
@@ -443,6 +477,25 @@ static inline void poisson_op(const nnvm::NodeAttrs& attrs,
}
template <typename xpu, typename ParamType>
+static inline void binomial_op(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const OpReqType& req,
+                               TBlob* outputs) {
+  // Fills outputs[0] with Binomial(n, p) samples, where n and p are the scalar
+  // parameters of the op's ParamType.
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  const BinomialParam& param = nnvm::get<ParamType>(attrs.parsed);
+  CHECK_GE(param.n, 0) << "n parameter in binomial distribution has to be non-negative";
+  CHECK_GE(param.p, 0) << "p parameter in binomial distribution has to be non-negative";
+  // SampleBinomial returns n for every p >= 1, so p > 1 would otherwise
+  // silently produce a constant output instead of an error.
+  CHECK_LE(param.p, 1) << "p parameter in binomial distribution has to be at most 1";
+  // Copy the scalar (n, p) into temporary float tensors for the sampler.
+  Tensor<xpu, 1, float> n, p;
+  GetSamplingTempData<xpu, float>(param.n, param.p, ctx, &n, &p);
+  BinomialSampler<xpu> sampler;
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
+    Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
+    sampler.Sample(n, p, out, pgen, s);
+  });
+}
+
+template <typename xpu, typename ParamType>
static inline void neg_binomial_op(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const OpReqType& req,
@@ -604,6 +657,26 @@ struct SampleMaster<xpu, SamplePoissonLikeParam> {
};
template <typename xpu>
+struct SampleMaster<xpu, SampleBinomialParam> {
+ // SampleBinomialParam (parameters of _random_binomial): forward to
+ // binomial_op, which reads n and p from this parameter type.
+ static inline void op(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const OpReqType& req,
+ TBlob* outputs) {
+ binomial_op<xpu, SampleBinomialParam>(attrs, ctx, req, outputs);
+ }
+};
+
+// SampleMaster specialization for SampleBinomialLikeParam (parameters of
+// _random_binomial_like): forwards to binomial_op.
+template <typename xpu>
+struct SampleMaster<xpu, SampleBinomialLikeParam> {
+ static inline void op(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const OpReqType& req,
+ TBlob* outputs) {
+ binomial_op<xpu, SampleBinomialLikeParam>(attrs, ctx, req, outputs);
+ }
+};
+
+template <typename xpu>
struct SampleMaster<xpu, SampleNegBinomialParam> {
static inline void op(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h
index 296833c..0d896b1 100644
--- a/src/operator/random/sampler.h
+++ b/src/operator/random/sampler.h
@@ -318,6 +318,201 @@ struct PoissonSampler {
}
};
+/*!
+ * \brief Correction term fc(k) = ln(k!) - [(k + 1/2) ln(k + 1) - (k + 1) + ln(2 pi) / 2].
+ *
+ * For k <= 9 the value is looked up in a precomputed table (k is expected to
+ * hold an integer value). Larger k use the truncated series
+ * 1 / (12 (k+1)) - 1 / (360 (k+1)^3) + 1 / (1260 (k+1)^5).
+ */
+MSHADOW_XINLINE double stirling_approximation(double k) {
+  static const double kSmallArgTail[] = {0.08106146679532726,
+                                         0.04134069595540929,
+                                         0.02767792568499834,
+                                         0.02079067210376509,
+                                         0.01664469118982119,
+                                         0.01387612882307075,
+                                         0.01189670994589177,
+                                         0.01041126526197209,
+                                         0.009255462182712733,
+                                         0.008330563433362871};
+
+  if (k <= 9)
+    return kSmallArgTail[static_cast<int>(k)];
+
+  const double k_plus_1 = k + 1;
+  const double k_plus_1_sq = k_plus_1 * k_plus_1;
+  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / k_plus_1_sq) / k_plus_1_sq) / k_plus_1;
+}
+
+// Sample Binomial(N, p) with the BTRD algorithm (transformed rejection with
+// decomposition). The algorithm is explained in
+// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
+// SampleBinomial only calls this with p <= 0.5 and N * p >= 10.
+// NOTE(review): every intermediate is held in OType, so a float16 output type
+// runs the whole rejection loop in half precision -- consider a wider compute type.
+template <typename xpu, typename IType, typename OType>
+MSHADOW_XINLINE OType _sample_binomial_btrd(IType N,
+                                            IType p,
+                                            typename RandGenerator<xpu, float>::Impl* gen) {
+  // Set-up: m = floor((N + 1) p) is the mode, r = p / q, nr = (N + 1) r.
+  OType m = floor((N + 1) * p);
+  OType r = p / (1 - p);
+  OType nr = (N + 1) * r;
+  OType npq = N * p * (1 - p);
+
+  // Constants of the transformed-rejection hat function.
+  OType b = 1.15 + 2.53 * sqrt(npq);
+  OType a = -0.0873 + 0.0248 * b + 0.01 * p;
+  OType c = N * p + 0.5;
+  OType alpha = (2.83 + 5.1 / b) * sqrt(npq);
+
+  OType v_r = 0.92 - 4.2 / b;
+  OType u_r__v_r = 0.86 * v_r;
+
+  while (true) {
+    // Fast path: accept the transformed value without evaluating the density.
+    OType v = gen->uniform();
+    if (v <= u_r__v_r) {
+      OType u = v / v_r - 0.43;
+
+      return floor((2 * a / (0.5 - abs(u)) + b) * u + c);
+    }
+
+    OType u;
+    if (v >= v_r) {
+      u = gen->uniform() - 0.5;
+    } else {
+      u = v / v_r - 0.93;
+      OType sgn_u = ((0 < u) - (u < 0));
+      u = sgn_u * 0.5 - u;
+
+      v = gen->uniform() * v_r;
+    }
+
+    // Candidate k; anything outside the support [0, N] is rejected outright.
+    OType us = 0.5 - abs(u);
+    OType k = floor((2 * a / us + b) * u + c);
+    if (k < 0 || k > N) {
+      continue;
+    }
+
+    v = v * alpha / (a / (us * us) + b);
+
+    OType km = abs(k - m);
+    if (km <= 15) {
+      // Near the mode: compare v against f(k) / f(m), built recursively.
+      // Consecutive binomial probabilities satisfy f(i) / f(i - 1) = nr / i - r,
+      // so the product runs over i = m + 1 .. k (or over i = k + 1 .. m, applied
+      // to v, when k < m). Hence the (i + 1) in both loops.
+      OType f = 1;
+      for (double i = m; i < k; ++i)
+        f = f * (nr / (i + 1) - r);
+      for (double i = k; i < m; ++i)
+        v = v * (nr / (i + 1) - r);
+
+      if (v <= f) {
+        return k;
+      }
+
+      continue;
+    }
+
+    // Far from the mode: squeeze acceptance / rejection on log(v).
+    v = log(v);
+    OType rho = (km / npq) * (((km / 3 + 0.625) * km + 1.0 / 6) / npq + 0.5);
+    OType t = -km * km / (2 * npq);
+    if (v < t - rho) {
+      return k;
+    }
+
+    if (v > t + rho) {
+      continue;
+    }
+
+    // Exact test: log-factorials expressed via Stirling's formula plus the
+    // stirling_approximation correction terms.
+    OType nm = N - m + 1;
+    OType h = (m + 0.5) * log((m + 1) / (r * nm)) + stirling_approximation(m) +
+              stirling_approximation(N - m);
+
+    OType nk = N - k + 1;
+    OType tmp = h + (N + 1) * log(nm / nk) + (k + 0.5) * log(nk * r / (k + 1)) -
+                stirling_approximation(k) - stirling_approximation(N - k);
+    if (v <= tmp) {
+      return k;
+    }
+  }
+}
+
+// Sequential-search inversion sampler for Binomial(n, p). It draws U and
+// returns the smallest X with U <= P(Bin <= X), capping the search at X = N.
+// Probability ratios are accumulated via f(x) / f(x - 1) = C / x - B.
+// p is first reflected to q = min(p, 1 - p) and the result reflected back.
+// SampleBinomial only passes a probability <= 0.5 with n * prob < 10, so the
+// reflection is not triggered from there.
+template <typename xpu, typename IType, typename OType>
+MSHADOW_XINLINE OType _sample_binomial_inversion(IType n,
+ IType p,
+ typename RandGenerator<xpu, float>::Impl* gen) {
+ OType N = static_cast<OType>(n);
+ OType q = static_cast<OType>(p);
+ if (q > 0.5)
+ q = 1 - q;
+
+ OType s = 1 - q;
+
+ // A: running sum of f(x) / f(0); D: current term f(X) / f(0);
+ // B = q / s and C = (N + 1) * B give the term ratio C / X - B.
+ OType A = 1;
+ OType B = q / s;
+ OType C = (N + 1) * B;
+ OType D = A;
+ OType X = 0;
+
+ // Dividing U by f(0) = s^N lets it be compared with A directly.
+ OType U = gen->uniform();
+ OType V = U / pow(s, N);
+
+ // Advance X until the cumulative probability reaches U (or X hits N).
+ do {
+ if (V <= A)
+ break;
+ X = X + 1;
+ D = D * (C / X - B);
+ A = A + D;
+ } while (X < N);
+
+ // Undo the q = 1 - p reflection.
+ if (p > 0.5)
+ return N - X;
+
+ return X;
+}
+
+/*!
+ * \brief Generate one sample of Binomial(n, p).
+ *
+ * p >= 1 short-circuits to n. Otherwise the sampler works with the smaller of
+ * p and 1 - p, mirroring the result as n - x when 1 - p was used. BTRD is chosen
+ * when the expected count n * prob is at least 10, and sequential inversion
+ * otherwise.
+ */
+template <typename xpu, typename IType, typename OType>
+MSHADOW_XINLINE OType SampleBinomial(IType n,
+                                     IType p,
+                                     typename RandGenerator<xpu, float>::Impl* gen) {
+  if (p >= 1)
+    return static_cast<OType>(n);
+
+  // Any p that is not <= 0.5 (NaN included) goes through q = 1 - p.
+  const bool mirrored = !(p <= 0.5);
+  if (!mirrored) {
+    return (n * p >= 10.0) ? _sample_binomial_btrd<xpu, IType, OType>(n, p, gen)
+                           : _sample_binomial_inversion<xpu, IType, OType>(n, p, gen);
+  }
+
+  IType q = 1.0 - p;
+  if (n * q >= 10.0)
+    return n - _sample_binomial_btrd<xpu, IType, OType>(n, q, gen);
+  return n - _sample_binomial_inversion<xpu, IType, OType>(n, q, gen);
+}
+
+/*!
+ * \brief Kernel functor writing one binomial draw to out[i] per loop index.
+ *
+ * Consecutive indices share one (n, p) pair: each pair serves
+ * 1 + (nSample - 1) / nParm outputs. The element index i and the generator
+ * genImpl come from the RNG_KERNEL_LOOP macro (defined elsewhere).
+ */
+template <typename xpu>
+struct SampleBinomialKernel {
+  template <typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t id,
+                                  RandGenerator<xpu, float> gen,
+                                  const index_t N,
+                                  const index_t step,
+                                  index_t nParm,
+                                  index_t nSample,
+                                  const IType* n,
+                                  const IType* p,
+                                  OType* out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      const index_t draws_per_param = 1 + (nSample - 1) / nParm;
+      const index_t param_idx = i / draws_per_param;
+      out[i] = SampleBinomial<xpu, IType, OType>(n[param_idx], p[param_idx], &genImpl);
+    });
+  }
+};
+
+/*!
+ * \brief Launches SampleBinomialKernel to fill out with Binomial(n, p) draws.
+ *
+ * Every element of out gets one draw. n and p hold n.size(0) parameter pairs
+ * that the kernel spreads over the outputs.
+ */
+template <typename xpu>
+struct BinomialSampler {
+  template <typename IType, typename OType>
+  MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& n,
+                                   const Tensor<xpu, 1, IType>& p,
+                                   const Tensor<xpu, 1, OType>& out,
+                                   RandGenerator<xpu, OType>* pgen,
+                                   Stream<xpu>* s) {
+    // The kernel consumes a float generator whatever the output type is.
+    RandGenerator<xpu, float>* float_gen = reinterpret_cast<RandGenerator<xpu, float>*>(pgen);
+    const index_t num_out = out.size(0);
+    LaunchRNG<SampleBinomialKernel<xpu>, xpu>(
+        s, float_gen, num_out, n.size(0), num_out, n.dptr_, p.dptr_, out.dptr_);
+  }
+};
+
template <typename xpu>
struct SampleNegativeBinomialKernel {
template <typename IType, typename OType>
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index bd3624b..8008c05 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -5503,6 +5503,33 @@ def test_npx_categorical():
@use_np
+def test_npx_multinomial():
+ class TestNumpyMultinomial(HybridBlock):
+ def __init__(self, size=None):
+ super(TestNumpyMultinomial, self).__init__()
+ self.size = size
+
+ def forward(self, n, prob):
+ if self.size is None:
+ return npx.random.multinomial(n, prob)
+ return npx.random.multinomial(n, prob, shape=self.size)
+
+ batch_sizes = [(2,), (2, 3)]
+ event_shapes = [None, (10,), (10, 12)]
+ num_event = [2, 4, 10]
+ for batch_size, num_event, event_shape in itertools.product(batch_sizes, num_event, event_shapes):
+ for hybridize in [True, False]:
+ n = np.ones(batch_size)
+ prob = np.ones(batch_size + (num_event,)) / num_event
+ net = TestNumpyMultinomial(event_shape)
+ if hybridize:
+ net.hybridize()
+ mx_out = net(n, prob)
+ desired_shape = batch_size + event_shape + (num_event,) if event_shape is not None else batch_size + (num_event,)
+ assert mx_out.shape == desired_shape
+
+
+@use_np
def test_random_seed():
for seed in [234, 594, 7240, 20394]:
ret = []
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 3fc061b..cb3eae3 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -425,16 +425,16 @@ def test_random_seed_setting():
num_samples = 100000
for dtype in ['float16', 'float32', 'float64']:
seed = set_seed_variously(1, num_temp_seeds, seed_to_test)
- samples1 = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
+ samples1 = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
shape=num_samples)
seed = set_seed_variously(seed, num_temp_seeds, seed_to_test)
- samples2 = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
+ samples2 = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
shape=num_samples)
samples1np = samples1.asnumpy()
set_seed_variously(seed, num_temp_seeds, seed_to_test+1)
samples2np = samples2.asnumpy()
assert same(samples1np, samples2np), \
- "seed-setting test: `multinomial` should give the same result with the same seed"
+ "seed-setting test: `categorical` should give the same result with the same seed"
# Tests that seed setting of parallel rng is synchronous w.r.t. rng use before and after.
@@ -504,13 +504,13 @@ def test_random_seed_setting_for_context():
ctx = mx.context.current_context()
seed = set_seed_variously_for_context(ctx, 1, num_temp_seeds, seed_to_test)
- # Check imperative. `multinomial` uses non-parallel rng.
- rnds = mx.nd.random.multinomial(data=mx.nd.array(probs, dtype=dtype), shape=num_samples)
+ # Check imperative. `categorical` uses non-parallel rng.
+ rnds = mx.nd.random.categorical(data=mx.nd.array(probs, dtype=dtype), shape=num_samples)
samples_imp.append(rnds.asnumpy())
- # Check symbolic. `multinomial` uses non-parallel rng.
+ # Check symbolic. `categorical` uses non-parallel rng.
P = mx.sym.Variable("P")
- X = mx.sym.random.multinomial(data=P, shape=num_samples, get_prob=False)
+ X = mx.sym.random.categorical(data=P, shape=num_samples, get_prob=False)
exe = X._bind(ctx, {"P": mx.nd.array(probs, dtype=dtype)})
set_seed_variously_for_context(ctx, seed, num_temp_seeds, seed_to_test)
exe.forward()
@@ -566,14 +566,14 @@ def test_parallel_random_seed_setting_for_context():
@pytest.mark.parametrize('dtype', ['uint8', 'int32', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('x', [[[0,1,2,3,4],[4,3,2,1,0]], [0,1,2,3,4]])
@pytest.mark.serial
-def test_sample_multinomial(dtype, x):
+def test_sample_categorical(dtype, x):
x = mx.nd.array(x) / 10.0
dx = mx.nd.ones_like(x)
mx.autograd.mark_variables([x], [dx])
# Adding rtol and increasing samples needed to pass with seed 2951820647
samples = 10000
with mx.autograd.record():
- y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True, dtype=dtype)
+ y, prob = mx.nd.random.categorical(x, shape=samples, get_prob=True, dtype=dtype)
r = prob * 5
r.backward()
@@ -685,6 +685,24 @@ def test_poisson_generator():
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
@pytest.mark.serial
+def test_binomial_generator():
+ ctx = mx.context.current_context()
+ for dtype in ['float16', 'float32', 'float64']:
+ trials_num = 10000
+ success_prob = 0.25
+
+ buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.binom.ppf(x, trials_num, success_prob), 10)
+ generator_mx = lambda x: mx.nd.random.binomial(trials_num, success_prob,
+ shape=x, ctx=ctx, dtype=dtype).asnumpy()
+ nsamples = 1000
+ verify_generator(generator=generator_mx, buckets=buckets, probs=probs, nsamples=nsamples)
+ generator_mx_same_seed = \
+ lambda x: np.concatenate(
+ [mx.nd.random.binomial(trials_num, success_prob, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+ for _ in range(10)])
+ verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs, nsamples=nsamples)
+
+@pytest.mark.serial
def test_negative_binomial_generator():
ctx = mx.context.current_context()
for dtype in ['float16', 'float32', 'float64']:
@@ -714,7 +732,7 @@ def test_negative_binomial_generator():
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
@pytest.mark.serial
-def test_multinomial_generator():
+def test_categorical_generator():
# This test fails with dtype float16 if the probabilities themselves cannot be
# well-represented in float16. When the float16 random picks are assigned to buckets,
# only certain bucket-probabilities are possible. Here we map the desired probabilites
@@ -739,7 +757,7 @@ def test_multinomial_generator():
buckets = list(range(6))
for dtype in ['float16', 'float32', 'float64']:
quantized_probs = quantize_probs(probs, dtype)
- generator_mx = lambda x: mx.nd.random.multinomial(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
+ generator_mx = lambda x: mx.nd.random.categorical(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
shape=x).asnumpy()
# success_rate was set to 0.15 since PR #13498 and became flaky
# both of previous issues(#14457, #14158) failed with success_rate 0.25
@@ -750,7 +768,7 @@ def test_multinomial_generator():
nsamples=samples, nrepeat=trials, success_rate=0.20)
generator_mx_same_seed = \
lambda x: np.concatenate(
- [mx.nd.random.multinomial(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
+ [mx.nd.random.categorical(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
shape=x // 10).asnumpy()
for _ in range(10)])
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=quantized_probs,
@@ -758,6 +776,35 @@ def test_multinomial_generator():
@pytest.mark.serial
+def test_multinomial_generator():
+ def repeat_i(arr):
+ """
+ Return an array containing ordered values from 0 to arr.size()-1,
+ where each value i is repeated arr[i] times.
+
+ Example:
+ >>> repeat_i([3, 1, 2, 1])
+ [0, 0, 0, 1, 2, 2, 3]
+ """
+ ind = mx.nd.expand_dims(mx.nd.cumsum(mx.nd.concat(mx.nd.array([0]), arr[:arr.size-1], dim=0)), axis=0)
+ data = mx.nd.ones((arr.size,))
+ shape = (int(mx.nd.sum(arr).asscalar()),)
+ return mx.nd.cumsum(mx.nd.scatter_nd(data, ind, shape)) - 1
+
+ ctx = mx.context.current_context()
+ probs = np.array([0.1, 0.2, 0.3, 0.05, 0.15, 0.2])
+
+ buckets = list(range(6))
+ for dtype in ['float16', 'float32', 'float64']:
+ generator_mx = lambda x: repeat_i(mx.nd.random.multinomial(n=mx.nd.array([x]), p=mx.nd.array([probs]), ctx=ctx)[0]).asnumpy()
+ verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+
+ generator_mx_same_seed = \
+ lambda x: np.concatenate([generator_mx(x // 10) for _ in range(10)])
+ verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+
+@pytest.mark.serial
def test_with_random_seed():
ctx = mx.context.current_context()
size = 100
@@ -1022,12 +1069,12 @@ def test_randint_without_dtype():
@pytest.mark.serial
-def test_sample_multinomial_num_outputs():
+def test_sample_categorical_num_outputs():
ctx = mx.context.current_context()
probs = [[0.125, 0.25, 0.25], [0.0625, 0.125, 0.1875]]
- out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=False)
+ out = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=False)
assert isinstance(out, mx.nd.NDArray)
- out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=True)
+ out = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=True)
assert isinstance(out, list)
assert len(out) == 2
@@ -1079,4 +1126,4 @@ def test_poisson_zero_size_dim():
assertRaises(MXNetError, mx.nd.op.random_pdf_poisson, sample, lam)
test_valid_zero_dim()
- test_invalid_zero_dim()
\ No newline at end of file
+ test_invalid_zero_dim()