You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2022/02/05 16:55:02 UTC

[incubator-mxnet] branch master updated: [FEATURE] Add binomial sampling and fix multinomial sampling (#20734)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e9becb9  [FEATURE] Add binomial sampling and fix multinomial sampling (#20734)
e9becb9 is described below

commit e9becb91c82486723d288b572c6721eaa6f2fd21
Author: Vincenzo Territo <47...@users.noreply.github.com>
AuthorDate: Sat Feb 5 17:53:43 2022 +0100

    [FEATURE] Add binomial sampling and fix multinomial sampling (#20734)
    
    * implement binomial sampling
    
    * add correct multinomial implementation
    
    * small fix in binomial symbol api doc
    
    * change npx_categorical to npx_multinomial
    
    * small sanity fix
    
    * fix python unit tests
    
    * rename previous multinomial implementation to categorical
    
    * small fix
---
 python/mxnet/amp/lists/symbol_fp16.py        |   4 +
 python/mxnet/ndarray/random.py               | 134 ++++++++++++++++--
 python/mxnet/symbol/random.py                |  82 +++++++++--
 src/operator/random/multisample_op.cc        |  38 ++++++
 src/operator/random/multisample_op.cu        |   3 +
 src/operator/random/sample_multinomial_op.cc |  80 ++++++++---
 src/operator/random/sample_multinomial_op.cu |   9 +-
 src/operator/random/sample_multinomial_op.h  | 194 ++++++++++++++++++++++++--
 src/operator/random/sample_op.cc             |  16 +++
 src/operator/random/sample_op.cu             |   2 +
 src/operator/random/sample_op.h              |  73 ++++++++++
 src/operator/random/sampler.h                | 195 +++++++++++++++++++++++++++
 tests/python/unittest/test_numpy_op.py       |  27 ++++
 tests/python/unittest/test_random.py         |  79 ++++++++---
 14 files changed, 868 insertions(+), 68 deletions(-)

diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index a6d8a53..3f3dd48 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -336,6 +336,8 @@ FP16_FP32_FUNCS = [
     '_random_exponential_like',
     '_random_gamma',
     '_random_gamma_like',
+    '_random_binomial',
+    '_random_binomial_like',
     '_random_generalized_negative_binomial',
     '_random_generalized_negative_binomial_like',
     '_random_negative_binomial',
@@ -353,7 +355,9 @@ FP16_FP32_FUNCS = [
     '_rnn_param_concat',
     '_sample_exponential',
     '_sample_gamma',
+    '_sample_binomial',
     '_sample_generalized_negative_binomial',
+    '_sample_categorical',
     '_sample_multinomial',
     '_sample_negative_binomial',
     '_sample_normal',
diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index b0683b4..1e7955f 100644
--- a/python/mxnet/ndarray/random.py
+++ b/python/mxnet/ndarray/random.py
@@ -23,8 +23,8 @@ from . import _internal
 from .ndarray import NDArray
 
 
-__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma',
-           'multinomial', 'negative_binomial', 'generalized_negative_binomial',
+__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'binomial',
+           'categorical', 'multinomial', 'negative_binomial', 'generalized_negative_binomial',
            'shuffle', 'randint']
 
 
@@ -383,6 +383,59 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg
                           [alpha, beta], shape, dtype, ctx, out, kwargs)
 
 
+def binomial(n=1, p=0.5, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
+    """Draw random samples from a binomial distribution.
+
+    Samples are distributed according to a binomial distribution parametrized
+    by *n* (number of trials) and *p* (success probability).
+
+    Parameters
+    ----------
+    n : float or NDArray, optional
+        Number of experiments, > 0.
+    p : float or NDArray, optional
+        Success probability in each experiment, >= 0 and <= 1.
+    shape : int or tuple of ints, optional
+        The number of samples to draw. If shape is, e.g., `(m, k)` and `n` and
+        `p` are scalars, output shape will be `(m, k)`. If `n` and `p`
+        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, k)`, where `m*k` samples are drawn for each `[n, p]` pair.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+    ctx : Context, optional
+        Device context of output. Default is current context. Overridden by
+        `n.context` when `n` is an NDArray.
+    out : NDArray, optional
+        Store output to an existing NDArray.
+
+    Returns
+    -------
+    NDArray
+        If input `shape` has shape, e.g., `(m, k)` and `n` and `p` are scalars, output
+        shape will be `(m, k)`. If `n` and `p` are NDArrays with shape, e.g.,
+        `(x, y)`, then output will have shape `(x, y, m, k)`, where `m*k` samples are
+        drawn for each `[n, p]` pair.
+
+    Examples
+    --------
+    >>> mx.nd.random.binomial(10, 0.1)
+    [ 1.]
+    <NDArray 1 @cpu(0)>
+    >>> mx.nd.random.binomial(10, 0.6, shape=(2,))
+    [ 4. 6.]
+    <NDArray 2 @cpu(0)>
+    >>> n = mx.nd.array([10,2,3])
+    >>> p = mx.nd.array([0.2,0.3,0.4])
+    >>> mx.nd.random.binomial(n, p, shape=2)
+    [[  1. 4.]
+     [  0. 2.]
+     [  1. 1.]]
+    <NDArray 3x2 @cpu(0)>
+    """
+    return _random_helper(_internal._random_binomial, _internal._sample_binomial,
+                          [n, p], shape, dtype, ctx, out, kwargs)
+
+
 def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, ctx=None,
                       out=None, **kwargs):
     """Draw random samples from a negative binomial distribution.
@@ -496,9 +549,8 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N
                           _internal._sample_generalized_negative_binomial,
                           [mu, alpha], shape, dtype, ctx, out, kwargs)
 
-
-def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
-    """Concurrent sampling from multiple multinomial distributions.
+def categorical(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
+    """Concurrent sampling from multiple categorical distributions.
 
     .. note:: The input distribution must be normalized, i.e. `data` must sum to
               1 along its last dimension.
@@ -507,8 +559,8 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
     ----------
     data : NDArray
         An *n* dimensional array whose last dimension has length `k`, where
-        `k` is the number of possible outcomes of each multinomial distribution.
-        For example, data with shape `(m, n, k)` specifies `m*n` multinomial
+        `k` is the number of possible outcomes of each categorical distribution.
+        For example, data with shape `(m, n, k)` specifies `m*n` categorical
         distributions each with `k` possible outcomes.
     shape : int or tuple of ints, optional
         The number of samples to draw from each distribution. If shape is empty
@@ -530,7 +582,7 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
         For input `data` with `n` dimensions and shape `(d1, d2, ..., dn-1, k)`, and input
         `shape` with shape `(s1, s2, ..., sx)`, returns an NDArray with shape
         `(d1, d2, ... dn-1, s1, s2, ..., sx)`. The `s1, s2, ... sx` dimensions of the
-        returned NDArray consist of 0-indexed values sampled from each respective multinomial
+        returned NDArray consist of 0-indexed values sampled from each respective categorical
         distribution provided in the `k` dimension of `data`.
 
         For the case `n`=1, and `x`=1 (one shape dimension), returned NDArray has shape `(s1,)`.
@@ -542,24 +594,80 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
     Examples
     --------
     >>> probs = mx.nd.array([0, 0.1, 0.2, 0.3, 0.4])
-    >>> mx.nd.random.multinomial(probs)
+    >>> mx.nd.random.categorical(probs)
     [3]
     <NDArray 1 @cpu(0)>
     >>> probs = mx.nd.array([[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]])
-    >>> mx.nd.random.multinomial(probs)
+    >>> mx.nd.random.categorical(probs)
     [3 1]
     <NDArray 2 @cpu(0)>
-    >>> mx.nd.random.multinomial(probs, shape=2)
+    >>> mx.nd.random.categorical(probs, shape=2)
     [[4 4]
      [1 2]]
     <NDArray 2x2 @cpu(0)>
-    >>> mx.nd.random.multinomial(probs, get_prob=True)
+    >>> mx.nd.random.categorical(probs, get_prob=True)
     [3 2]
     <NDArray 2 @cpu(0)>
     [-1.20397282 -1.60943794]
     <NDArray 2 @cpu(0)>
     """
-    return _internal._sample_multinomial(data, shape, get_prob, out=out, dtype=dtype, **kwargs)
+    return _internal._sample_categorical(data, shape, get_prob, out=out, dtype=dtype, **kwargs)
+
+
+def multinomial(n=[1], p=[[1.0]], shape=_Null, dtype='float32', ctx=None, out=None, **kwargs):
+    """Concurrent sampling from multiple multinomial distributions.
+
+    .. note:: The input distribution must be normalized, i.e. `p` must sum to
+              1 along its last dimension.
+
+    Parameters
+    ----------
+    n : NDArray
+        A *d*-dimensional array containing the number of trials of each
+        multinomial distribution.
+    p : NDArray
+        A *(d+1)*-dimensional array containing the probabilities of each multinomial
+        distribution. Its last dimension has length `k`, where `k` is the number
+        of possible outcomes of each multinomial distribution.
+        For example, p with shape `(a, b, k)` specifies `a*b` multinomial
+        distributions each with `k` possible outcomes.
+    shape : int or tuple of ints, optional
+        The number of samples to draw from each distribution. If shape is empty
+        one sample will be drawn from each distribution.
+    out : NDArray, optional
+        Store output to an existing NDArray.
+    ctx : Context, optional
+        Device context of output. Default is current context. Overridden by
+        `n.context` when `n` is an NDArray.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+
+    Returns
+    -------
+    NDArray
+        If `n` has shape `[s]`, `p` has shape `[s] + (k,)` and `shape` is `[t]`, the
+        output has shape `[s] + [t] + (k,)`; e.g. `n` of shape `(x, y)`, `p` of shape
+        `(x, y, k)` and `shape=(m, q)` give shape `(x, y, m, q, k)`. Each length-`k`
+        slice holds the outcome counts of one sample from the matching `[n, p]` pair.
+
+    Examples
+    --------
+    >>> mx.nd.random.multinomial(mx.nd.array([10]), mx.nd.array([[0.1, 0.9]]))
+    [[ 1. 9.]]
+    <NDArray 1x2 @cpu(0)>
+    >>> mx.nd.random.multinomial(mx.nd.array([10]), mx.nd.array([[0.6, 0.4]]), shape=(2,))
+    [[[ 5. 5.]
+      [ 6. 4.]]]
+    <NDArray 1x2x2 @cpu(0)>
+    >>> n = mx.nd.array([10, 2, 3])
+    >>> p = mx.nd.array([[0.2, 0.8], [0.3, 0.7], [0.4, 0.6]])
+    >>> mx.nd.random.multinomial(n, p)
+    [[  2. 8.]
+     [  1. 1.]
+     [  1. 2.]]
+    <NDArray 3x2 @cpu(0)>
+    """
+    return _internal._sample_multinomial(n, p, shape=shape, out=out, ctx=ctx, dtype=dtype, **kwargs)
 
 
 def shuffle(data, **kwargs):
diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py
index b2ff104..827ec40 100644
--- a/python/mxnet/symbol/random.py
+++ b/python/mxnet/symbol/random.py
@@ -22,8 +22,8 @@ from . import _internal
 from .symbol import Symbol
 
 
-__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'multinomial',
-           'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint']
+__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'categorical', 'multinomial',
+           'binomial', 'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint']
 
 
 def _random_helper(random, sampler, params, shape, dtype, kwargs):
@@ -240,6 +240,38 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, **kwargs):
                           [alpha, beta], shape, dtype, kwargs)
 
 
+def binomial(n=1, p=0.5, shape=_Null, dtype=_Null, **kwargs):
+    """Draw random samples from a binomial distribution.
+
+    Samples are distributed according to a binomial distribution parametrized
+    by *n* (number of trials) and *p* (success probability).
+
+    Parameters
+    ----------
+    n : float or Symbol, optional
+        Number of experiments, > 0.
+    p : float or Symbol, optional
+        Success probability in each experiment, >= 0 and <= 1.
+    shape : int or tuple of ints, optional
+        The number of samples to draw. If shape is, e.g., `(m, k)` and `n` and
+        `p` are scalars, output shape will be `(m, k)`. If `n` and `p`
+        are Symbols with shape, e.g., `(x, y)`, then output will have shape
+        `(x, y, m, k)`, where `m*k` samples are drawn for each `[n, p]` pair.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+
+    Returns
+    -------
+    Symbol
+        If input `shape` has shape, e.g., `(m, k)` and `n` and `p` are scalars, output
+        shape will be `(m, k)`. If `n` and `p` are Symbols with shape, e.g.,
+        `(x, y)`, then output will have shape `(x, y, m, k)`, where `m*k` samples are
+        drawn for each `[n, p]` pair.
+    """
+    return _random_helper(_internal._random_binomial, _internal._sample_binomial,
+                          [n, p], shape, dtype, kwargs)
+
+
 def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, **kwargs):
     """Draw random samples from a negative binomial distribution.
 
@@ -311,8 +343,8 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa
                           [mu, alpha], shape, dtype, kwargs)
 
 
-def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
-    """Concurrent sampling from multiple multinomial distributions.
+def categorical(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
+    """Concurrent sampling from multiple categorical distributions.
 
     .. note:: The input distribution must be normalized, i.e. `data` must sum to
               1 along its last dimension.
@@ -321,8 +353,8 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
     ----------
     data : Symbol
         An *n* dimensional array whose last dimension has length `k`, where
-        `k` is the number of possible outcomes of each multinomial distribution.
-        For example, data with shape `(m, n, k)` specifies `m*n` multinomial
+        `k` is the number of possible outcomes of each categorical distribution.
+        For example, data with shape `(m, n, k)` specifies `m*n` categorical
         distributions each with `k` possible outcomes.
     shape : int or tuple of ints, optional
         The number of samples to draw from each distribution. If shape is empty
@@ -343,7 +375,7 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
         `shape` with shape `(s1, s2, ..., sx)`, returns a Symbol that resovles to shape
         `(d1, d2, ... dn-1, s1, s2, ..., sx)`. The `s1, s2, ... sx` dimensions of the
         returned Symbol's resolved value will consist of 0-indexed values sampled from each
-        respective multinomial distribution provided in the `k` dimension of `data`.
+        respective categorical distribution provided in the `k` dimension of `data`.
 
         For the case `n`=1, and `x`=1 (one shape dimension), returned Symbol will resolve to
         shape `(s1,)`.
@@ -352,7 +384,41 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
         outputs: `[ndarray_output, log_likelihood_output]`, where `log_likelihood_output` will resolve
         to the same shape as the sampled outputs in ndarray_output.
     """
-    return _internal._sample_multinomial(data, shape, get_prob, dtype=dtype, **kwargs)
+    return _internal._sample_categorical(data, shape, get_prob, dtype=dtype, **kwargs)
+
+
+def multinomial(n=[1], p=[[1.0]], shape=_Null, dtype='float32', **kwargs):
+    """Concurrent sampling from multiple multinomial distributions.
+
+    .. note:: The input distribution must be normalized, i.e. `p` must sum to
+              1 along its last dimension.
+
+    Parameters
+    ----------
+    n : Symbol
+        A *d*-dimensional array containing the number of trials of each
+        multinomial distribution.
+    p : Symbol
+        A *(d+1)*-dimensional array containing the probabilities of each multinomial
+        distribution. Its last dimension has length `k`, where `k` is the number
+        of possible outcomes of each multinomial distribution.
+        For example, p with shape `(a, b, k)` specifies `a*b` multinomial
+        distributions each with `k` possible outcomes.
+    shape : int or tuple of ints, optional
+        The number of samples to draw from each distribution. If shape is empty
+        one sample will be drawn from each distribution.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+
+    Returns
+    -------
+    Symbol
+        If `n` has shape `[s]`, `p` has shape `[s] + (k,)` and `shape` is `[t]`, the
+        Symbol resolves to shape `[s] + [t] + (k,)`; e.g. `(x, y)`, `(x, y, k)` and
+        `shape=(m, q)` give `(x, y, m, q, k)`. Each length-`k` slice holds the outcome
+        counts of one sample drawn from the matching `[n, p]` pair.
+    """
+    return _internal._sample_multinomial(n, p, shape, dtype=dtype, **kwargs)
 
 
 def shuffle(data, **kwargs):
diff --git a/src/operator/random/multisample_op.cc b/src/operator/random/multisample_op.cc
index b7f2214..7b0a02d 100644
--- a/src/operator/random/multisample_op.cc
+++ b/src/operator/random/multisample_op.cc
@@ -218,6 +218,37 @@ Examples::
 )code");
 }
 
+inline std::string binomial_desc() {
+  return std::string(R"code(Concurrent sampling from multiple
+binomial distributions with parameters *n* (number of trials) and *p* (success probability).
+
+The parameters of the distributions are provided as input arrays.
+Let *[s]* be the shape of the input arrays, *d* be the dimension of *[s]*, *[t]*
+be the shape specified as the parameter of the operator, and *m* be the dimension
+of *[t]*. Then the output will be a *(d+m)*-dimensional array with shape *[s]x[t]*.
+
+For any valid *d*-dimensional index *i* with respect to the input arrays, *output[i]*
+will be an *m*-dimensional array that holds randomly drawn samples from the distribution
+which is parameterized by the input values at index *i*. If the shape parameter of the
+operator is not set, then one sample will be drawn per distribution and the output array
+has the same shape as the input arrays.
+
+Samples will always be returned as a floating point data type.
+
+Examples::
+
+   n = [ 20, 49 ]
+   p = [ 0.4 , 0.77 ]
+
+   // Draw a single sample for each distribution
+   sample_binomial(n, p) = [ 5.,  36.]
+
+   // Draw a vector containing two samples for each distribution
+   sample_binomial(n, p, shape=(2)) = [[ 5.,  11.],
+                                       [ 40.,  35.]]
+)code");
+}
+
 inline std::string negative_binomial_desc() {
   return std::string(R"code(Concurrent sampling from multiple
 negative binomial distributions with parameters *k* (failure limit) and *p* (failure probability).
@@ -312,6 +343,13 @@ MXNET_OPERATOR_REGISTER_SAMPLING1(poisson,
                                   "Lambda (rate) parameters of the distributions.",
                                   poisson_desc)
     .add_alias("_npx_tensor_poisson");
+MXNET_OPERATOR_REGISTER_SAMPLING2(binomial,  // presumably defines op _sample_binomial
+                                  BinomialSampler<cpu>,
+                                  "n",
+                                  "p",
+                                  "Number of experiments.",
+                                  "Success probabilities in each experiment.",
+                                  binomial_desc);
 MXNET_OPERATOR_REGISTER_SAMPLING2(negative_binomial,
                                   NegativeBinomialSampler<cpu>,
                                   "k",
diff --git a/src/operator/random/multisample_op.cu b/src/operator/random/multisample_op.cu
index 4c571e7..e7756c6 100644
--- a/src/operator/random/multisample_op.cu
+++ b/src/operator/random/multisample_op.cu
@@ -42,6 +42,9 @@ NNVM_REGISTER_OP(_sample_exponential)
 NNVM_REGISTER_OP(_sample_poisson)
     .set_attr<FCompute>("FCompute<gpu>", MultiSampleOpForward<gpu, PoissonSampler<gpu>, 1>);

+NNVM_REGISTER_OP(_sample_binomial)  // the 2 below looks like the input count (n, p)
+    .set_attr<FCompute>("FCompute<gpu>", MultiSampleOpForward<gpu, BinomialSampler<gpu>, 2>);
+
 NNVM_REGISTER_OP(_sample_negative_binomial)
     .set_attr<FCompute>("FCompute<gpu>",
                         MultiSampleOpForward<gpu, NegativeBinomialSampler<gpu>, 2>);
diff --git a/src/operator/random/sample_multinomial_op.cc b/src/operator/random/sample_multinomial_op.cc
index 01d66e2..5b3edaa 100644
--- a/src/operator/random/sample_multinomial_op.cc
+++ b/src/operator/random/sample_multinomial_op.cc
@@ -26,15 +26,16 @@
 namespace mxnet {
 namespace op {
 
+DMLC_REGISTER_PARAMETER(SampleCategoricalParam);  // parsed by _sample_categorical
 DMLC_REGISTER_PARAMETER(SampleMultinomialParam);
 
-NNVM_REGISTER_OP(_sample_multinomial)
-    .add_alias("sample_multinomial")
+NNVM_REGISTER_OP(_sample_categorical)
+    .add_alias("sample_categorical")
     .add_alias("_npx__random_categorical")
-    .describe(R"code(Concurrent sampling from multiple multinomial distributions.
+    .describe(R"code(Concurrent sampling from multiple categorical distributions.
 
 *data* is an *n* dimensional array whose last dimension has length *k*, where
-*k* is the number of possible outcomes of each multinomial distribution. This
+*k* is the number of possible outcomes of each categorical distribution. This
 operator will draw *shape* samples from each distribution. If shape is empty
 one sample will be drawn from each distribution.
 
@@ -51,23 +52,23 @@ Examples::
    probs = [[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]]
 
    // Draw a single sample for each distribution
-   sample_multinomial(probs) = [3, 0]
+   sample_categorical(probs) = [3, 0]
 
    // Draw a vector containing two samples for each distribution
-   sample_multinomial(probs, shape=(2)) = [[4, 2],
+   sample_categorical(probs, shape=(2)) = [[4, 2],
                                            [0, 0]]
 
    // requests log likelihood
-   sample_multinomial(probs, get_prob=True) = [2, 1], [0.2, 0.3]
+   sample_categorical(probs, get_prob=True) = [2, 1], [0.2, 0.3]
 )code")
     .set_num_inputs(1)
     .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
-      const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+      const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
       return param.get_prob ? 2U : 1U;
     })
-    .set_attr_parser(ParamParser<SampleMultinomialParam>)
-    .set_attr<mxnet::FInferShape>("FInferShape", SampleMultinomialOpShape)
-    .set_attr<nnvm::FInferType>("FInferType", SampleMultinomialOpType)
+    .set_attr_parser(ParamParser<SampleCategoricalParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", SampleCategoricalOpShape)
+    .set_attr<nnvm::FInferType>("FInferType", SampleCategoricalOpType)
     .set_attr<FResourceRequest>("FResourceRequest",
                                 [](const nnvm::NodeAttrs& attrs) {
                                   return std::vector<ResourceRequest>{ResourceRequest::kRandom,
@@ -76,9 +77,9 @@ Examples::
     .set_attr<nnvm::FGradient>(
         "FGradient",
         [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-          const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(n->attrs.parsed);
+          const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(n->attrs.parsed);
           if (param.get_prob) {
-            return MakeGradNode("_backward_sample_multinomial",
+            return MakeGradNode("_backward_sample_categorical",
                                 n,
                                 {ograds[1], n->inputs[0], nnvm::NodeEntry{n, 0, 0}},
                                 std::unordered_map<std::string, std::string>());
@@ -86,13 +87,58 @@ Examples::
             return MakeZeroGradNodes(n, ograds);
           }
         })
-    .set_attr<FCompute>("FCompute<cpu>", SampleMultinomialForward<cpu>)
+    .set_attr<FCompute>("FCompute<cpu>", SampleCategoricalForward<cpu>)
     .add_argument("data",
                   "NDArray-or-Symbol",
                   "Distribution probabilities. Must sum to one on the last axis.")
+    .add_arguments(SampleCategoricalParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_sample_multinomial)
+    .add_alias("sample_multinomial")
+    .add_alias("_npx__random_multinomial")
+    .describe(R"code(Concurrent sampling from multiple multinomial distributions.
+
+Samples are distributed according to a multinomial distribution parametrized by
+*n* (number of experiments) and *p* (success probabilities of the k possible outcomes
+in each experiment). Samples will always be returned as a floating point data type.
+
+Note that the input distribution must be normalized, i.e. *p* must sum to
+1 along its last axis.
+
+Examples::
+
+   n = [5., 6.]
+   probs = [[0., 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.]]
+
+   multinomial(n, probs) = [[0., 0., 0., 3., 2.],
+                            [0., 3., 1., 2., 0.]]
+)code")
+    .set_num_inputs(2)
+    .set_num_outputs(1)
+    .set_attr_parser(ParamParser<SampleMultinomialParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", SampleMultinomialOpShape)
+    .set_attr<nnvm::FInferType>("FInferType", SampleMultinomialOpType)
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const nnvm::NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{
+                                      ResourceRequest::kParallelRandom,
+                                      ResourceRequest::kTempSpace};
+                                })
+    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+    .set_attr<nnvm::FListInputNames>("FListInputNames",
+                                     [](const NodeAttrs& attrs) {
+                                       // Input names, in the same order as the
+                                       // add_argument("n") / add_argument("p") calls below.
+                                       return std::vector<std::string>{"n", "p"};
+                                     })
+    .set_attr<FCompute>("FCompute<cpu>", SampleMultinomialForward<cpu>)
+    .add_argument("n", "NDArray-or-Symbol", "Number of experiments")
+    .add_argument("p",
+                  "NDArray-or-Symbol",
+                  "Probability of every outcome in each experiment. Must sum to 1 on the last axis")
     .add_arguments(SampleMultinomialParam::__FIELDS__());
 
-struct SampleMultinomialBackwardCPUKernel {
+struct SampleCategoricalBackwardCPUKernel {
   template <typename DType, typename IType>
   MSHADOW_XINLINE static void
   Map(int i, index_t K, index_t M, DType* ograd, DType* dist, IType* out, DType* igrad) {
@@ -103,12 +149,12 @@ struct SampleMultinomialBackwardCPUKernel {
   }
 };
 
-NNVM_REGISTER_OP(_backward_sample_multinomial)
+NNVM_REGISTER_OP(_backward_sample_categorical)  // target of _sample_categorical's FGradient
     .set_num_inputs(3)
     .set_num_outputs(1)
     .set_attr<nnvm::TIsBackward>("TIsBackward", true)
     .set_attr<FCompute>("FCompute<cpu>",
-                        SampleMultinomialBackward<SampleMultinomialBackwardCPUKernel, cpu>);
+                        SampleCategoricalBackward<SampleCategoricalBackwardCPUKernel, cpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/random/sample_multinomial_op.cu b/src/operator/random/sample_multinomial_op.cu
index 79b45ff..3123d7b 100644
--- a/src/operator/random/sample_multinomial_op.cu
+++ b/src/operator/random/sample_multinomial_op.cu
@@ -26,10 +26,13 @@
 namespace mxnet {
 namespace op {
 
+NNVM_REGISTER_OP(_sample_categorical)  // formerly the 1-input _sample_multinomial
+    .set_attr<FCompute>("FCompute<gpu>", SampleCategoricalForward<gpu>);
+
 NNVM_REGISTER_OP(_sample_multinomial)
     .set_attr<FCompute>("FCompute<gpu>", SampleMultinomialForward<gpu>);
 
-struct SampleMultinomialBackwardGPUKernel {
+struct SampleCategoricalBackwardGPUKernel {
   template <typename DType, typename IType>
   MSHADOW_XINLINE static void
   Map(int i, index_t K, index_t M, DType* ograd, DType* dist, IType* out, DType* igrad) {
@@ -40,9 +43,9 @@ struct SampleMultinomialBackwardGPUKernel {
   }
 };
 
-NNVM_REGISTER_OP(_backward_sample_multinomial)
+NNVM_REGISTER_OP(_backward_sample_categorical)  // GPU compute for the categorical grad op
     .set_attr<FCompute>("FCompute<gpu>",
-                        SampleMultinomialBackward<SampleMultinomialBackwardGPUKernel, gpu>);
+                        SampleCategoricalBackward<SampleCategoricalBackwardGPUKernel, gpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h
index 9cc1de6..3346ae4 100644
--- a/src/operator/random/sample_multinomial_op.h
+++ b/src/operator/random/sample_multinomial_op.h
@@ -26,19 +26,21 @@
 
 #include <mxnet/operator_util.h>
 #include <vector>
+#include <string>
 #include "../mshadow_op.h"
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
+#include "./sampler.h"
 
 namespace mxnet {
 namespace op {
 
-struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
+struct SampleCategoricalParam : public dmlc::Parameter<SampleCategoricalParam> {
   mxnet::TShape shape;
   bool get_prob;
   int dtype;
-  DMLC_DECLARE_PARAMETER(SampleMultinomialParam) {
+  DMLC_DECLARE_PARAMETER(SampleCategoricalParam) {
     DMLC_DECLARE_FIELD(shape)
         .set_default(mxnet::TShape(0, 1))
         .describe("Shape to be sampled from each random distribution.");
@@ -57,10 +59,32 @@ struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
   }
 };
 
-inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
+struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
+  mxnet::TShape shape;  // samples drawn per (n, p) distribution
+  std::string ctx;  // only used for imperative calls (see field description)
+  int dtype;  // mshadow type flag of the output (see enum below)
+  DMLC_DECLARE_PARAMETER(SampleMultinomialParam) {
+    DMLC_DECLARE_FIELD(shape)
+        .set_default(mxnet::TShape(0, 1))
+        .describe("Shape to be sampled from each random distribution.");
+    DMLC_DECLARE_FIELD(ctx).set_default("").describe(
+        "Context of output, in format [cpu|gpu|cpu_pinned](n)."
+        " Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype)
+        .add_enum("uint8", mshadow::kUint8)
+        .add_enum("int32", mshadow::kInt32)
+        .add_enum("float16", mshadow::kFloat16)
+        .add_enum("float32", mshadow::kFloat32)
+        .add_enum("float64", mshadow::kFloat64)
+        .set_default(mshadow::kInt32)  // NOTE(review): Python wrappers pass float32; confirm
+        .describe("DType of the output in case this can't be inferred.");
+  }
+};
+
+inline bool SampleCategoricalOpShape(const nnvm::NodeAttrs& attrs,
                                      mxnet::ShapeVector* in_attrs,
                                      mxnet::ShapeVector* out_attrs) {
-  const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+  const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
 
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
@@ -98,10 +122,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
+inline bool SampleCategoricalOpType(const nnvm::NodeAttrs& attrs,
                                     std::vector<int>* in_attrs,
                                     std::vector<int>* out_attrs) {
-  const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+  const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
 
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
@@ -116,7 +140,65 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-struct SampleMultinomialKernel {
+// Shape inference for the multinomial sampler.
+//   in[0] = n with batch shape B, in[1] = p with shape B + (K,)
+//   out[0] = counts with shape B + param.shape + (K,)
+// Returns false (inference not complete) while ranks are unknown, when the
+// ranks do not line up, or when a batch dimension of n and p differs.
+inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
+                                     mxnet::ShapeVector* in_attrs,
+                                     mxnet::ShapeVector* out_attrs) {
+  const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  const mxnet::TShape& n_shape = (*in_attrs)[0];
+  const mxnet::TShape& p_shape = (*in_attrs)[1];
+  const bool ranks_known       = ndim_is_known(n_shape) && ndim_is_known(p_shape);
+  if (!ranks_known || p_shape.ndim() != n_shape.ndim() + 1)
+    return false;
+
+  const int batch_ndim  = p_shape.ndim() - 1;
+  const int sample_ndim = param.shape.ndim();
+
+  mxnet::TShape oshape(batch_ndim + sample_ndim + 1, -1);
+  // Leading batch dimensions: must agree between n and p.
+  for (int d = 0; d < batch_ndim; ++d) {
+    if (p_shape[d] != n_shape[d])
+      return false;
+    oshape[d] = p_shape[d];
+  }
+  // Requested sample dimensions follow the batch dimensions.
+  for (int d = 0; d < sample_ndim; ++d)
+    oshape[batch_ndim + d] = param.shape[d];
+  // Category axis K comes last.
+  oshape[batch_ndim + sample_ndim] = p_shape[batch_ndim];
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+
+  return true;
+}
+
+// Type inference for the multinomial sampler. An already-assigned output
+// type wins (after a consistency check against param.dtype); otherwise the
+// output follows the dtype of the trial-count input `n`, and falls back to
+// the framework default when that is unknown too.
+inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
+                                    std::vector<int>* in_attrs,
+                                    std::vector<int>* out_attrs) {
+  const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  const int n_type   = (*in_attrs)[0];
+  const int out_type = (*out_attrs)[0];
+
+  int resolved;
+  if (out_type == -1) {
+    resolved = (n_type == -1) ? mxnet::common::GetDefaultDtype() : n_type;
+  } else {
+    // NOTE(review): SampleMultinomialParam::dtype defaults to int32 and has no
+    // "None" choice, so this check is always active and rejects any
+    // pre-assigned non-int32 output type. Confirm intent.
+    if (param.dtype != -1) {
+      CHECK_EQ(out_type, param.dtype)
+          << "Output type does not match requested type: " << out_type << " vs " << param.dtype;
+    }
+    resolved = out_type;
+  }
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, resolved);
+  return true;
+}
+
+struct SampleCategoricalKernel {
   template <typename DType, typename IType>
   MSHADOW_XINLINE static void Map(index_t i,
                                   index_t K,
@@ -152,15 +234,60 @@ struct SampleMultinomialKernel {
   }
 };
 
+// Draw one multinomial vector: distribute N trials over the K categories with
+// probabilities p[0..K), writing the K counts to out[0..K).
+// Uses the conditional-binomial decomposition: category j receives
+// Binomial(trials still unassigned, p[j] / probability mass not yet consumed),
+// and the last category receives every trial that is left.
+template <typename xpu, typename NType, typename PType, typename OType>
+MSHADOW_XINLINE void SampleMultinomial(NType N,
+                                       const PType* p,
+                                       OType* out,
+                                       index_t K,
+                                       typename RandGenerator<xpu, float>::Impl* gen) {
+  // Bookkeeping is done in double: a float16 PType cannot represent trial
+  // counts above 65504, and conditional probabilities lose precision quickly.
+  double remaining_n = static_cast<double>(N);
+  double remaining_p = 1.0;
+
+  index_t j = 0;
+  for (; j < K - 1 && remaining_n > 0; ++j) {
+    const double cond_p = static_cast<double>(p[j]) / remaining_p;
+    out[j]              = SampleBinomial<xpu, double, OType>(remaining_n, cond_p, gen);
+    remaining_n -= static_cast<double>(out[j]);
+    remaining_p -= static_cast<double>(p[j]);
+  }
+  // Categories never reached because all trials were already assigned.
+  for (; j < K - 1; ++j)
+    out[j] = 0;
+  // The last category takes every remaining trial. It is written
+  // unconditionally so that K == 1 and N <= 0 also produce a defined output.
+  out[K - 1] = static_cast<OType>(remaining_n > 0 ? remaining_n : 0.0);
+}
+
 template <typename xpu>
-void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
+struct SampleMultinomialKernel {
+  // Kernel launched through LaunchRNG (see multinomial_op): each logical index
+  // i produces one multinomial count vector of length K at out[i * K].
+  // The draws are grouped into nParm consecutive batches of
+  // nBatch = ceil(nSample / nParm) draws; every draw in batch i / nBatch uses
+  // trial count n[i / nBatch] and the probability row starting at
+  // p[(i / nBatch) * K].
+  // NOTE(review): RNG_KERNEL_LOOP is defined elsewhere; it presumably iterates
+  // `i` over this thread's share of [0, N) and provides `genImpl`.
+  template <typename NType, typename PType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t id,
+                                  RandGenerator<xpu, float> gen,
+                                  const index_t N,
+                                  const index_t step,
+                                  index_t nParm,
+                                  index_t nSample,
+                                  index_t K,
+                                  const NType* n,
+                                  const PType* p,
+                                  OType* out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      SampleMultinomial<xpu, NType, PType, OType>(
+          n[i / nBatch], &p[(i / nBatch) * K], &out[i * K], K, &genImpl);
+    })
+  }
+};
+
+template <typename xpu>
+void SampleCategoricalForward(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mxnet_op;
-  const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
+  const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
 
   index_t K = inputs[0].shape_[inputs[0].ndim() - 1];
   index_t N = inputs[0].Size() / K;
@@ -174,7 +301,7 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, float> uniform(workspace.dptr_, Shape1(N * M));
     prnd->SampleUniform(&uniform, 0, 1);
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
-      Kernel<SampleMultinomialKernel, xpu>::Launch(
+      Kernel<SampleCategoricalKernel, xpu>::Launch(
           s,
           N,
           K,
@@ -188,8 +315,53 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+// Shared forward implementation of the multinomial sampler.
+//   num:     trial counts n, batch shape B
+//   prob:    probabilities p, shape B + (K,)
+//   outputs: counts, shape B + sample shape + (K,)
+// `req` is not consulted: the output is always overwritten.
+// NOTE(review): all three dtypes go through MSHADOW_REAL_TYPE_SWITCH, so an
+// integer output dtype (e.g. SampleMultinomialParam's int32 default) is not
+// supported here. Confirm type inference never produces one.
+template <typename xpu>
+static inline void multinomial_op(const nnvm::NodeAttrs& attrs,
+                                  const OpContext& ctx,
+                                  const OpReqType& req,
+                                  TBlob* num,
+                                  TBlob* prob,
+                                  TBlob* outputs) {
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(
+      num[0].type_flag_,
+      NType,
+      {MSHADOW_REAL_TYPE_SWITCH(
+          prob[0].type_flag_, PType, {MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+            // The kernel draws uniforms through a float generator regardless
+            // of OType, so the parallel generator is viewed as one.
+            RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
+            RandGenerator<xpu, float>* gen  = reinterpret_cast<RandGenerator<xpu, float>*>(pgen);
+
+            Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
+            Tensor<xpu, 1, NType> n   = num->FlatTo1D<xpu, NType>(s);
+            Tensor<xpu, 1, PType> p   = prob->FlatTo1D<xpu, PType>(s);
+            // Number of categories = size of the last axis of p.
+            index_t K                 = prob->shape_[prob->ndim() - 1];
+
+            // One work item per multinomial draw (out.size(0) / K of them),
+            // spread over n.size(0) parameter sets.
+            LaunchRNG<SampleMultinomialKernel<xpu>, xpu>(s,
+                                                         gen,
+                                                         out.size(0) / K,
+                                                         n.size(0),
+                                                         out.size(0) / K,
+                                                         K,
+                                                         n.dptr_,
+                                                         p.dptr_,
+                                                         out.dptr_);
+          })})});
+}
+
+// FCompute entry point of the multinomial sampler.
+// inputs[0] = trial counts n, inputs[1] = probabilities p, outputs[0] = counts.
+template <typename xpu>
+void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  // multinomial_op takes non-const TBlob pointers, so hand it local copies.
+  TBlob trials(inputs[0]);
+  TBlob probs(inputs[1]);
+  TBlob samples(outputs[0]);
+  multinomial_op<xpu>(attrs, ctx, req[0], &trials, &probs, &samples);
+}
+
 template <typename kernel, typename xpu>
-void SampleMultinomialBackward(const nnvm::NodeAttrs& attrs,
+void SampleCategoricalBackward(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx,
                                const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc
index 02d91f2..8d9f5d8 100644
--- a/src/operator/random/sample_op.cc
+++ b/src/operator/random/sample_op.cc
@@ -33,6 +33,7 @@ DMLC_REGISTER_PARAMETER(SampleNormalParam);
 DMLC_REGISTER_PARAMETER(SampleGammaParam);
 DMLC_REGISTER_PARAMETER(SampleExponentialParam);
 DMLC_REGISTER_PARAMETER(SamplePoissonParam);
+DMLC_REGISTER_PARAMETER(SampleBinomialParam);
 DMLC_REGISTER_PARAMETER(SampleNegBinomialParam);
 DMLC_REGISTER_PARAMETER(SampleGenNegBinomialParam);
 DMLC_REGISTER_PARAMETER(SampleRandIntParam);
@@ -42,6 +43,7 @@ DMLC_REGISTER_PARAMETER(SampleNormalLikeParam);
 DMLC_REGISTER_PARAMETER(SampleGammaLikeParam);
 DMLC_REGISTER_PARAMETER(SampleExponentialLikeParam);
 DMLC_REGISTER_PARAMETER(SamplePoissonLikeParam);
+DMLC_REGISTER_PARAMETER(SampleBinomialLikeParam);
 DMLC_REGISTER_PARAMETER(SampleNegBinomialLikeParam);
 DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam);
 
@@ -149,6 +151,20 @@ Example::
                                   [ 4.,  6.]]
 )code" ADD_FILELINE);
 
+// Register the shape-parametrized binomial sampler `_random_binomial`
+// (alias `random_binomial`), configured by SampleBinomialParam.
+MXNET_OPERATOR_REGISTER_SAMPLE(_random_binomial, SampleBinomialParam)
+    .add_alias("random_binomial")
+    .describe(R"code(Draw random samples from a binomial distribution.
+
+Samples are distributed according to a binomial distribution parametrized by
+*n* (number of experiments) and *p* (success probability in each experiment).
+Samples will always be returned as a floating point data type.
+
+Example::
+
+   binomial(n=3, p=0.4, shape=(2,2)) = [[ 1.,  0.],
+                                        [ 1.,  2.]]
+)code" ADD_FILELINE);
+
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_negative_binomial, SampleNegBinomialParam)
     .add_alias("random_negative_binomial")
     .describe(R"code(Draw random samples from a negative binomial distribution.
diff --git a/src/operator/random/sample_op.cu b/src/operator/random/sample_op.cu
index 8ffc633..1fc6239 100644
--- a/src/operator/random/sample_op.cu
+++ b/src/operator/random/sample_op.cu
@@ -36,6 +36,7 @@ MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_normal, SampleNormalParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_gamma, SampleGammaParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_exponential, SampleExponentialParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_poisson, SamplePoissonParam)
+MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_binomial, SampleBinomialParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_negative_binomial, SampleNegBinomialParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_generalized_negative_binomial, SampleGenNegBinomialParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_randint, SampleRandIntParam)
@@ -44,6 +45,7 @@ MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_normal_like, SampleNormalLikeParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_gamma_like, SampleGammaLikeParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_exponential_like, SampleExponentialLikeParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_poisson_like, SamplePoissonLikeParam)
+MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_binomial_like, SampleBinomialLikeParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_negative_binomial_like, SampleNegBinomialLikeParam)
 MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_generalized_negative_binomial_like,
                                    SampleGenNegBinomialLikeParam)
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index 05d4363..cfff87c 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -67,6 +67,11 @@ struct PoissonParam {
   float lam;
 };
 
+// Scalar parameters of a binomial distribution, shared by SampleBinomialParam
+// and SampleBinomialLikeParam.
+struct BinomialParam {
+  // Number of trials (experiments).
+  int n;
+  // Success probability of each trial.
+  float p;
+};
+
 struct NegBinomialParam {
   int k;
   float p;
@@ -190,6 +195,28 @@ struct SamplePoissonParam : public dmlc::Parameter<SamplePoissonParam>,
   }
 };
 
+// Parameters of `_random_binomial`: scalar n and p (from BinomialParam) plus
+// the output shape / ctx / dtype options (presumably declared by SampleOpParam).
+// With the defaults (n = 1, p = 1.0) every draw equals 1.
+struct SampleBinomialParam : public dmlc::Parameter<SampleBinomialParam>,
+                             BinomialParam,
+                             SampleOpParam {
+  DMLC_DECLARE_PARAMETER(SampleBinomialParam) {
+    DMLC_DECLARE_FIELD(n).set_default(1).describe("number of experiments.");
+    DMLC_DECLARE_FIELD(p).set_default(1.0f).describe("success probability in each experiment.");
+    DMLC_DECLARE_FIELD(shape).set_default(mxnet::TShape()).describe("Shape of the output.");
+    DMLC_DECLARE_FIELD(ctx).set_default("").describe(
+        "Context of output, in format [cpu|gpu|cpu_pinned](n)."
+        " Only used for imperative calls.");
+    // Only floating-point output types are offered; -1 ("None") defers to inference.
+    DMLC_DECLARE_FIELD(dtype)
+        .add_enum("None", -1)
+        .add_enum("float32", mshadow::kFloat32)
+        .add_enum("float64", mshadow::kFloat64)
+        .add_enum("float16", mshadow::kFloat16)
+        .set_default(-1)
+        .describe(
+            "DType of the output in case this can't be inferred. "
+            "Defaults to float32 if not defined (dtype=None).");
+  }
+};
+
 struct SampleNegBinomialParam : public dmlc::Parameter<SampleNegBinomialParam>,
                                 NegBinomialParam,
                                 SampleOpParam {
@@ -306,6 +333,13 @@ struct SamplePoissonLikeParam : public dmlc::Parameter<SamplePoissonLikeParam>,
   }
 };
 
+// Parameters of `_random_binomial_like`: only the scalar n and p; the output
+// shape/dtype presumably come from the input array (not visible here).
+struct SampleBinomialLikeParam : public dmlc::Parameter<SampleBinomialLikeParam>, BinomialParam {
+  DMLC_DECLARE_PARAMETER(SampleBinomialLikeParam) {
+    DMLC_DECLARE_FIELD(n).set_default(1).describe("Number of experiments.");
+    DMLC_DECLARE_FIELD(p).set_default(1.0f).describe("success probability in each experiment.");
+  }
+};
+
 struct SampleNegBinomialLikeParam : public dmlc::Parameter<SampleNegBinomialLikeParam>,
                                     NegBinomialParam {
   DMLC_DECLARE_PARAMETER(SampleNegBinomialLikeParam) {
@@ -443,6 +477,25 @@ static inline void poisson_op(const nnvm::NodeAttrs& attrs,
 }
 
 template <typename xpu, typename ParamType>
+static inline void binomial_op(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const OpReqType& req,
+                               TBlob* outputs) {
+  // Forward pass shared by _random_binomial and _random_binomial_like: fills
+  // outputs[0] with independent Binomial(param.n, param.p) draws.
+  // `req` is not consulted; the output is always overwritten.
+  Stream<xpu>* s             = ctx.get_stream<xpu>();
+  const BinomialParam& param = nnvm::get<ParamType>(attrs.parsed);
+  CHECK_GE(param.n, 0) << "n parameter in binomial distribution has to be non-negative";
+  CHECK_GE(param.p, 0) << "p parameter in binomial distribution has to be non-negative";
+  // p is a probability; without this check any p > 1 silently yields n for every draw.
+  CHECK_LE(param.p, 1) << "p parameter in binomial distribution has to be at most 1";
+  // Materialize the scalar parameters as float tensors for the kernel
+  // (NOTE(review): presumably one element each; see GetSamplingTempData).
+  Tensor<xpu, 1, float> n, p;
+  GetSamplingTempData<xpu, float>(param.n, param.p, ctx, &n, &p);
+  BinomialSampler<xpu> sampler;
+  // Only floating-point output dtypes are dispatched.
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
+    Tensor<xpu, 1, OType> out       = outputs->FlatTo1D<xpu, OType>(s);
+    sampler.Sample(n, p, out, pgen, s);
+  });
+}
+
+template <typename xpu, typename ParamType>
 static inline void neg_binomial_op(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const OpReqType& req,
@@ -604,6 +657,26 @@ struct SampleMaster<xpu, SamplePoissonLikeParam> {
 };
 
 template <typename xpu>
+struct SampleMaster<xpu, SampleBinomialParam> {
+  // Forward pass of _random_binomial: delegate to the shared binomial_op.
+  static inline void op(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const OpReqType& req,
+                        TBlob* outputs) {
+    return binomial_op<xpu, SampleBinomialParam>(attrs, ctx, req, outputs);
+  }
+};
+
+template <typename xpu>
+struct SampleMaster<xpu, SampleBinomialLikeParam> {
+  // Forward pass of _random_binomial_like: same sampler, n/p read from the _like params.
+  static inline void op(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const OpReqType& req,
+                        TBlob* outputs) {
+    return binomial_op<xpu, SampleBinomialLikeParam>(attrs, ctx, req, outputs);
+  }
+};
+
+template <typename xpu>
 struct SampleMaster<xpu, SampleNegBinomialParam> {
   static inline void op(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h
index 296833c..0d896b1 100644
--- a/src/operator/random/sampler.h
+++ b/src/operator/random/sampler.h
@@ -318,6 +318,201 @@ struct PoissonSampler {
   }
 };
 
+// Stirling-series correction term fc(k) used by the BTRD final acceptance
+// test: looked up for k <= 9, otherwise evaluated from the truncated series
+// (1/12 - (1/360 - 1/1260 / (k+1)^2) / (k+1)^2) / (k+1).
+MSHADOW_XINLINE double stirling_approximation(double k) {
+  static const double table[] = {0.08106146679532726,
+                                 0.04134069595540929,
+                                 0.02767792568499834,
+                                 0.02079067210376509,
+                                 0.01664469118982119,
+                                 0.01387612882307075,
+                                 0.01189670994589177,
+                                 0.01041126526197209,
+                                 0.009255462182712733,
+                                 0.008330563433362871};
+
+  if (k <= 9) {
+    return table[static_cast<int>(k)];
+  }
+
+  const double kp1    = k + 1;
+  const double kp1_sq = kp1 * kp1;
+  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1_sq) / kp1_sq) / kp1;
+}
+
+// The algorithm is explained in https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
+// (W. Hormann, "The generation of binomial random variates", 1993): BTRD,
+// transformed rejection with decomposition. SampleBinomial calls it only with
+// p <= 0.5 and N * p >= 10.
+// All intermediate arithmetic is done in double so that a low-precision OType
+// (float16 represents integers exactly only up to 2048) cannot distort the
+// candidate k or the acceptance tests; only the accepted value is converted.
+template <typename xpu, typename IType, typename OType>
+MSHADOW_XINLINE OType _sample_binomial_btrd(IType N,
+                                            IType p,
+                                            typename RandGenerator<xpu, float>::Impl* gen) {
+  const double n  = static_cast<double>(N);
+  const double pp = static_cast<double>(p);
+
+  // Step 0: setup.
+  const double m   = floor((n + 1) * pp);  // mode
+  const double r   = pp / (1 - pp);
+  const double nr  = (n + 1) * r;
+  const double npq = n * pp * (1 - pp);
+
+  const double b     = 1.15 + 2.53 * sqrt(npq);
+  const double a     = -0.0873 + 0.0248 * b + 0.01 * pp;
+  const double c     = n * pp + 0.5;
+  const double alpha = (2.83 + 5.1 / b) * sqrt(npq);
+
+  const double v_r      = 0.92 - 4.2 / b;
+  const double u_r__v_r = 0.86 * v_r;
+
+  while (true) {
+    // Step 1: immediate acceptance in the central region.
+    double v = gen->uniform();
+    if (v <= u_r__v_r) {
+      const double u = v / v_r - 0.43;
+      return static_cast<OType>(floor((2 * a / (0.5 - fabs(u)) + b) * u + c));
+    }
+
+    // Step 2: generate (u, v) for the remaining region.
+    double u;
+    if (v >= v_r) {
+      u = gen->uniform() - 0.5;
+    } else {
+      u                  = v / v_r - 0.93;
+      const double sgn_u = ((0 < u) - (u < 0));
+      u                  = sgn_u * 0.5 - u;
+
+      v = gen->uniform() * v_r;
+    }
+
+    const double us = 0.5 - fabs(u);
+    const double k  = floor((2 * a / us + b) * u + c);
+    if (k < 0 || k > n) {
+      continue;
+    }
+
+    v = v * alpha / (a / (us * us) + b);
+
+    // Step 3.1: near the mode, evaluate f(k) / f(m) exactly with the recurrence
+    // f(i) / f(i - 1) = nr / i - r. The products run over i = m + 1 .. k
+    // (or i = k + 1 .. m). Starting at i = m (or k) would apply the wrong
+    // ratios and divide by zero when k == 0.
+    const double km = fabs(k - m);
+    if (km <= 15) {
+      double f = 1;
+      for (double i = m + 1; i <= k; ++i)
+        f = f * (nr / i - r);
+      for (double i = k + 1; i <= m; ++i)
+        v = v * (nr / i - r);
+
+      if (v <= f) {
+        return static_cast<OType>(k);
+      }
+
+      continue;
+    }
+
+    // Step 3.2: squeeze acceptance / rejection on the log scale.
+    v                = log(v);
+    const double rho = (km / npq) * (((km / 3 + 0.625) * km + 1.0 / 6) / npq + 0.5);
+    const double t   = -km * km / (2 * npq);
+    if (v < t - rho) {
+      return static_cast<OType>(k);
+    }
+
+    if (v > t + rho) {
+      continue;
+    }
+
+    // Steps 3.3-3.4: final acceptance test with Stirling corrections.
+    const double nm = n - m + 1;
+    const double h  = (m + 0.5) * log((m + 1) / (r * nm)) + stirling_approximation(m) +
+                     stirling_approximation(n - m);
+
+    const double nk  = n - k + 1;
+    const double tmp = h + (n + 1) * log(nm / nk) + (k + 0.5) * log(nk * r / (k + 1)) -
+                       stirling_approximation(k) - stirling_approximation(n - k);
+    if (v <= tmp) {
+      return static_cast<OType>(k);
+    }
+  }
+}
+
+// Sample Binomial(n, p) by sequential inversion of the CDF: starting at X = 0,
+// accumulate P(X = x) until the sum reaches a uniform draw. The number of
+// steps equals the returned X, so SampleBinomial uses this only when
+// n * min(p, 1 - p) < 10. For p > 0.5 it samples failures with q = 1 - p and
+// returns n - X.
+// Intermediate arithmetic is done in double: with a float16 OType the running
+// CDF sum and pow(s, N) would otherwise carry only about 3 significant digits.
+template <typename xpu, typename IType, typename OType>
+MSHADOW_XINLINE OType _sample_binomial_inversion(IType n,
+                                                 IType p,
+                                                 typename RandGenerator<xpu, float>::Impl* gen) {
+  const double N = static_cast<double>(n);
+  double q       = static_cast<double>(p);
+  if (q > 0.5)
+    q = 1 - q;
+
+  const double s = 1 - q;
+
+  // Probabilities are tracked relative to P(X = 0) = s^N:
+  //   D = P(X) / P(0), A = P(<= X) / P(0),
+  //   P(X) / P(X - 1) = (N - X + 1) / X * q / s = C / X - B.
+  double A       = 1;
+  const double B = q / s;
+  const double C = (N + 1) * B;
+  double D       = A;
+  double X       = 0;
+
+  const double U = gen->uniform();
+  const double V = U / pow(s, N);
+
+  do {
+    if (V <= A)
+      break;
+    X = X + 1;
+    D = D * (C / X - B);
+    A = A + D;
+  } while (X < N);
+
+  if (p > 0.5)
+    return static_cast<OType>(N - X);
+
+  return static_cast<OType>(X);
+}
+
+// Draw one Binomial(n, p) sample. p >= 1 returns n. Otherwise the work is done
+// with q = min(p, 1 - p): BTRD when n * q >= 10, CDF inversion below that.
+// For p > 0.5 the draw counts failures and is subtracted from n.
+template <typename xpu, typename IType, typename OType>
+MSHADOW_XINLINE OType SampleBinomial(IType n,
+                                     IType p,
+                                     typename RandGenerator<xpu, float>::Impl* gen) {
+  // Degenerate case: every trial succeeds.
+  if (p >= 1) {
+    return static_cast<OType>(n);
+  }
+
+  // p > 0.5 (NaN also lands here): sample failures with q = 1 - p.
+  if (!(p <= 0.5)) {
+    IType q = 1.0 - p;
+    return n - ((n * q >= 10.0) ? _sample_binomial_btrd<xpu, IType, OType>(n, q, gen)
+                                : _sample_binomial_inversion<xpu, IType, OType>(n, q, gen));
+  }
+
+  return (n * p >= 10.0) ? _sample_binomial_btrd<xpu, IType, OType>(n, p, gen)
+                         : _sample_binomial_inversion<xpu, IType, OType>(n, p, gen);
+}
+
+// Kernel launched through LaunchRNG by BinomialSampler::Sample: writes one
+// binomial draw per output element.
+template <typename xpu>
+struct SampleBinomialKernel {
+  // The draws are split into nParm consecutive groups of
+  // nBatch = ceil(nSample / nParm) draws; draw i uses n[i / nBatch] and
+  // p[i / nBatch].
+  // NOTE(review): RNG_KERNEL_LOOP is defined elsewhere; it presumably iterates
+  // `i` over this thread's share of [0, N) and provides `genImpl`.
+  template <typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t id,
+                                  RandGenerator<xpu, float> gen,
+                                  const index_t N,
+                                  const index_t step,
+                                  index_t nParm,
+                                  index_t nSample,
+                                  const IType* n,
+                                  const IType* p,
+                                  OType* out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      out[i] = SampleBinomial<xpu, IType, OType>(n[i / nBatch], p[i / nBatch], &genImpl);
+    });
+  }
+};
+
+// Fills `out` with binomial draws; parameter set j = (n[j], p[j]) is shared by
+// a contiguous group of outputs (see SampleBinomialKernel).
+template <typename xpu>
+struct BinomialSampler {
+  template <typename IType, typename OType>
+  MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& n,
+                                   const Tensor<xpu, 1, IType>& p,
+                                   const Tensor<xpu, 1, OType>& out,
+                                   RandGenerator<xpu, OType>* pgen,
+                                   Stream<xpu>* s) {
+    // The kernel expects a float generator, so the OType generator is viewed
+    // as one (NOTE(review): relies on RandGenerator's layout not depending on
+    // its DType parameter).
+    auto* float_gen      = reinterpret_cast<RandGenerator<xpu, float>*>(pgen);
+    const auto num_draws  = out.size(0);
+    const auto num_params = n.size(0);
+    LaunchRNG<SampleBinomialKernel<xpu>, xpu>(
+        s, float_gen, num_draws, num_params, num_draws, n.dptr_, p.dptr_, out.dptr_);
+  }
+};
+
 template <typename xpu>
 struct SampleNegativeBinomialKernel {
   template <typename IType, typename OType>
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index bd3624b..8008c05 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -5503,6 +5503,33 @@ def test_npx_categorical():
 
 
 @use_np
+def test_npx_multinomial():
+    """Check output shapes of npx.random.multinomial, and that every sampled
+    count vector sums to the number of trials."""
+    class TestNumpyMultinomial(HybridBlock):
+        def __init__(self, size=None):
+            super(TestNumpyMultinomial, self).__init__()
+            self.size = size
+
+        def forward(self, n, prob):
+            if self.size is None:
+                return npx.random.multinomial(n, prob)
+            return npx.random.multinomial(n, prob, shape=self.size)
+
+    batch_sizes = [(2,), (2, 3)]
+    event_shapes = [None, (10,), (10, 12)]
+    # Distinct name so the loop variable below does not shadow this list.
+    category_counts = [2, 4, 10]
+    for batch_size, num_event, event_shape in itertools.product(batch_sizes, category_counts, event_shapes):
+        for hybridize in [True, False]:
+            n = np.ones(batch_size)
+            prob = np.ones(batch_size + (num_event,)) / num_event
+            net = TestNumpyMultinomial(event_shape)
+            if hybridize:
+                net.hybridize()
+            mx_out = net(n, prob)
+            sample_shape = event_shape if event_shape is not None else ()
+            assert mx_out.shape == batch_size + sample_shape + (num_event,)
+            # Each draw distributes exactly n (= 1 here) trials over the categories.
+            assert (mx_out.sum(axis=-1).asnumpy() == 1).all()
+
+
+@use_np
 def test_random_seed():
     for seed in [234, 594, 7240, 20394]:
         ret = []
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 3fc061b..cb3eae3 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -425,16 +425,16 @@ def test_random_seed_setting():
     num_samples = 100000
     for dtype in ['float16', 'float32', 'float64']:
         seed = set_seed_variously(1, num_temp_seeds, seed_to_test)
-        samples1 = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
+        samples1 = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
                                             shape=num_samples)
         seed = set_seed_variously(seed, num_temp_seeds, seed_to_test)
-        samples2 = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
+        samples2 = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
                                             shape=num_samples)
         samples1np = samples1.asnumpy()
         set_seed_variously(seed, num_temp_seeds, seed_to_test+1)
         samples2np = samples2.asnumpy()
         assert same(samples1np, samples2np), \
-            "seed-setting test: `multinomial` should give the same result with the same seed"
+            "seed-setting test: `categorical` should give the same result with the same seed"
 
 
 # Tests that seed setting of parallel rng is synchronous w.r.t. rng use before and after.
@@ -504,13 +504,13 @@ def test_random_seed_setting_for_context():
                 ctx = mx.context.current_context()
                 seed = set_seed_variously_for_context(ctx, 1, num_temp_seeds, seed_to_test)
 
-                # Check imperative. `multinomial` uses non-parallel rng.
-                rnds = mx.nd.random.multinomial(data=mx.nd.array(probs, dtype=dtype), shape=num_samples)
+                # Check imperative. `categorical` uses non-parallel rng.
+                rnds = mx.nd.random.categorical(data=mx.nd.array(probs, dtype=dtype), shape=num_samples)
                 samples_imp.append(rnds.asnumpy())
 
-                # Check symbolic. `multinomial` uses non-parallel rng.
+                # Check symbolic. `categorical` uses non-parallel rng.
                 P = mx.sym.Variable("P")
-                X = mx.sym.random.multinomial(data=P, shape=num_samples, get_prob=False)
+                X = mx.sym.random.categorical(data=P, shape=num_samples, get_prob=False)
                 exe = X._bind(ctx, {"P": mx.nd.array(probs, dtype=dtype)})
                 set_seed_variously_for_context(ctx, seed, num_temp_seeds, seed_to_test)
                 exe.forward()
@@ -566,14 +566,14 @@ def test_parallel_random_seed_setting_for_context():
 @pytest.mark.parametrize('dtype', ['uint8', 'int32', 'float16', 'float32', 'float64'])
 @pytest.mark.parametrize('x', [[[0,1,2,3,4],[4,3,2,1,0]], [0,1,2,3,4]])
 @pytest.mark.serial
-def test_sample_multinomial(dtype, x):
+def test_sample_categorical(dtype, x):
     x = mx.nd.array(x) / 10.0
     dx = mx.nd.ones_like(x)
     mx.autograd.mark_variables([x], [dx])
     # Adding rtol and increasing samples needed to pass with seed 2951820647
     samples = 10000
     with mx.autograd.record():
-        y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True, dtype=dtype)
+        y, prob = mx.nd.random.categorical(x, shape=samples, get_prob=True, dtype=dtype)
         r = prob * 5
         r.backward()
 
@@ -685,6 +685,24 @@ def test_poisson_generator():
             verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
 
 @pytest.mark.serial
+def test_binomial_generator():
+    """Chi-square check of mx.nd.random.binomial against scipy's binomial
+    quantiles, for one large draw and for ten concatenated smaller draws."""
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        trials_num = 10000
+        success_prob = 0.25
+
+        buckets, probs = gen_buckets_probs_with_ppf(
+            lambda x: ss.binom.ppf(x, trials_num, success_prob), 10)
+
+        def draw(size, dtype=dtype):
+            return mx.nd.random.binomial(trials_num, success_prob, shape=size,
+                                         ctx=ctx, dtype=dtype).asnumpy()
+
+        nsamples = 1000
+        verify_generator(generator=draw, buckets=buckets, probs=probs, nsamples=nsamples)
+
+        def draw_in_chunks(size, dtype=dtype):
+            return np.concatenate([draw(size // 10, dtype) for _ in range(10)])
+
+        verify_generator(generator=draw_in_chunks, buckets=buckets, probs=probs, nsamples=nsamples)
+
+@pytest.mark.serial
 def test_negative_binomial_generator():
     ctx = mx.context.current_context()
     for dtype in ['float16', 'float32', 'float64']:
@@ -714,7 +732,7 @@ def test_negative_binomial_generator():
         verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
 
 @pytest.mark.serial
-def test_multinomial_generator():
+def test_categorical_generator():
     # This test fails with dtype float16 if the probabilities themselves cannot be
     # well-represented in float16.  When the float16 random picks are assigned to buckets,
     # only certain bucket-probabilities are possible.  Here we map the desired probabilites
@@ -739,7 +757,7 @@ def test_multinomial_generator():
     buckets = list(range(6))
     for dtype in ['float16', 'float32', 'float64']:
         quantized_probs = quantize_probs(probs, dtype)
-        generator_mx = lambda x: mx.nd.random.multinomial(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
+        generator_mx = lambda x: mx.nd.random.categorical(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
                                                           shape=x).asnumpy()
         # success_rate was set to 0.15 since PR #13498 and became flaky
         # both of previous issues(#14457, #14158) failed with success_rate 0.25
@@ -750,7 +768,7 @@ def test_multinomial_generator():
                          nsamples=samples, nrepeat=trials, success_rate=0.20)
         generator_mx_same_seed = \
             lambda x: np.concatenate(
-                [mx.nd.random.multinomial(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
+                [mx.nd.random.categorical(data=mx.nd.array(quantized_probs, ctx=ctx, dtype=dtype),
                                                           shape=x // 10).asnumpy()
                  for _ in range(10)])
         verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=quantized_probs,
@@ -758,6 +776,35 @@ def test_multinomial_generator():
 
 
 @pytest.mark.serial
+def test_multinomial_generator():
+    """Chi-square check of mx.nd.random.multinomial: a single draw of x trials
+    is expanded into x category labels and compared against `probs`."""
+    def repeat_i(counts):
+        """
+        Return an array containing ordered values from 0 to counts.size-1,
+        where each value i is repeated counts[i] times.
+
+        Zero counts are handled correctly: the corresponding value simply
+        does not appear.
+
+        Example:
+        >>> repeat_i([3, 1, 2, 1])
+        array([0, 0, 0, 1, 2, 2, 3])
+        """
+        counts = np.asarray(counts).astype(np.int64)
+        return np.repeat(np.arange(counts.size), counts)
+
+    ctx = mx.context.current_context()
+    probs = np.array([0.1, 0.2, 0.3, 0.05, 0.15, 0.2])
+
+    buckets = list(range(6))
+    # The output dtype follows the dtype of the trial-count input `n`.
+    # float16 is not exercised: it cannot represent trial counts above 65504.
+    for dtype in ['float32', 'float64']:
+        def generator_mx(x, dtype=dtype):
+            counts = mx.nd.random.multinomial(n=mx.nd.array([x], dtype=dtype),
+                                              p=mx.nd.array([probs]), ctx=ctx)[0]
+            return repeat_i(counts.asnumpy())
+        verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+
+        def generator_mx_same_seed(x, dtype=dtype):
+            return np.concatenate([generator_mx(x // 10, dtype) for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+
+@pytest.mark.serial
 def test_with_random_seed():
     ctx = mx.context.current_context()
     size = 100
@@ -1022,12 +1069,12 @@ def test_randint_without_dtype():
 
 
 @pytest.mark.serial
-def test_sample_multinomial_num_outputs():
+def test_sample_categorical_num_outputs():
     ctx = mx.context.current_context()
     probs = [[0.125, 0.25, 0.25], [0.0625, 0.125, 0.1875]]
-    out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=False)
+    out = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=False)
     assert isinstance(out, mx.nd.NDArray)
-    out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=True)
+    out = mx.nd.random.categorical(data=mx.nd.array(probs, ctx=ctx), shape=10000, get_prob=True)
     assert isinstance(out, list)
     assert len(out) == 2
 
@@ -1079,4 +1126,4 @@ def test_poisson_zero_size_dim():
         assertRaises(MXNetError, mx.nd.op.random_pdf_poisson, sample, lam)
 
     test_valid_zero_dim()
-    test_invalid_zero_dim()
\ No newline at end of file
+    test_invalid_zero_dim()