Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/06/20 07:14:59 UTC

[incubator-mxnet] branch numpy updated: Numpy compatible multinomial (#15219)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy by this push:
     new 1b62426  Numpy compatible multinomial (#15219)
1b62426 is described below

commit 1b62426f7c03863d5c59a3b25606fb5bb996880e
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Thu Jun 20 00:14:36 2019 -0700

    Numpy compatible multinomial (#15219)
    
    * draft of multinomial
    
    * rename to more concise name
    
    * finish shape
    
    * complete the forward function
    
    * complete forward without handle 0 dimension & scalar
    
    * handle 0 dimension
    
    * add new line
    
    * fix lint
    
    * fix the build error
    
    * fix lint
    
    * finish unit test
    
    * change the registration
    
    * make multinomial support pvals as mx.ndarray
    
    * delete newline
    
    * fix lint error
    
    * support input as list, mx.ndarray, np.ndarray & unit test
    
    * fix lint
    
    * fix the include error
    
    * fix lint
    
    * refactor & pass the tensor instead of tuple to kernel
    
    * fix lint
    
    * update the doc
    
    * address the comment
---
 python/mxnet/_numpy_op_doc.py                  |  30 ++++
 python/mxnet/ndarray/numpy/random.py           |  41 +++++-
 python/mxnet/numpy/random.py                   |  30 ++++
 src/operator/numpy/random/np_multinomial_op.cc |  61 ++++++++
 src/operator/numpy/random/np_multinomial_op.cu |  34 +++++
 src/operator/numpy/random/np_multinomial_op.h  | 193 +++++++++++++++++++++++++
 tests/python/unittest/test_numpy_ndarray.py    |  47 +++++-
 7 files changed, 434 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 9265a98..ab81732 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -109,3 +109,33 @@ def _np_repeat(a, repeats, axis=None):
         the given axis.
     """
     pass
+
+
+def _npi_multinomial(a):
+    """Draw samples from a multinomial distribution.
+
+    The multinomial distribution is a multivariate generalisation of the binomial distribution.
+    Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a die,
+    where the outcome can be 1 through 6. Each sample drawn from the distribution represents n such experiments.
+    Its values, ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the outcome was ``i``.
+
+
+    Parameters
+    ----------
+    n : int
+        Number of experiments.
+    pvals : sequence of floats, length p
+        Probabilities of each of the p different outcomes. These should sum to 1
+        (however, the last element is always assumed to account for the remaining
+        probability, as long as ``sum(pvals[:-1]) <= 1``).
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k``
+        samples are drawn. Default is None, in which case a single value is returned.
+
+    Returns
+    -------
+    out : ndarray
+        The drawn samples, of shape ``size + (p,)`` if ``size`` was provided, and ``(p,)`` otherwise.
+        In other words, each entry ``out[i,j,...,:]`` is a ``p``-dimensional vector of counts drawn from the distribution.
+    """
+    pass
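
A small usage sketch of how n, pvals and size relate to the output shape (illustrative numbers; it assumes the mx.np.random.multinomial front end added later in this commit, called the same way as in the unit test at the end of the diff):

    import mxnet as mx

    # 20 experiments over 6 equally likely outcomes -> a vector of 6 counts summing to 20
    counts = mx.np.random.multinomial(20, [1. / 6.] * 6)                 # shape (6,)
    # with size=(3, 4) the draw is repeated 3 * 4 times -> shape (3, 4, 6)
    batch = mx.np.random.multinomial(20, [1. / 6.] * 6, size=(3, 4))
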
diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index 3d9fd6a..8607fd5 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -17,11 +17,13 @@
 
 """Namespace for operators used in Gluon dispatched by F=ndarray."""
 from __future__ import absolute_import
+import numpy as np
 from ...base import numeric_types
 from ...context import current_context
+from ..ndarray import NDArray
 from . import _internal as _npi
 
-__all__ = ['uniform', 'normal']
+__all__ = ['uniform', 'normal', 'multinomial']
 
 
 def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs):
@@ -135,3 +137,40 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
     out = kwargs.pop('out', None)
     return _random_helper(_npi.random_normal, None,
                           [loc, scale], size, dtype, ctx, out, kwargs)
+
+
+def multinomial(n, pvals, size=None):
+    """Draw samples from a multinomial distribution.
+
+    The multinomial distribution is a multivariate generalisation of the binomial distribution.
+    Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a die,
+    where the outcome can be 1 through 6. Each sample drawn from the distribution represents n such experiments.
+    Its values, ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the outcome was ``i``.
+
+
+    Parameters
+    ----------
+    n : int
+        Number of experiments.
+    pvals : sequence of floats, length p
+        Probabilities of each of the p different outcomes. These should sum to 1
+        (however, the last element is always assumed to account for the remaining
+        probability, as long as ``sum(pvals[:-1]) <= 1``).
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k``
+        samples are drawn. Default is None, in which case a single value is returned.
+
+    Returns
+    -------
+    out : ndarray
+        The drawn samples, of shape ``size + (p,)`` if ``size`` was provided, and ``(p,)`` otherwise.
+        In other words, each entry ``out[i,j,...,:]`` is a ``p``-dimensional vector of counts drawn from the distribution.
+    """
+    if isinstance(pvals, NDArray):
+        return _npi.multinomial(pvals, pvals=None, n=n, size=size)
+    else:
+        if isinstance(pvals, np.ndarray):
+            pvals = pvals.tolist()
+        if any(isinstance(i, list) for i in pvals):
+            raise ValueError('object too deep for desired array')
+        return _npi.multinomial(n=n, pvals=pvals, size=size)
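
A short sketch of the three accepted forms of pvals that the branch above handles, exercised through the mx.np front end defined in the next file (illustrative values; the nested-list call is commented out because it raises):

    import numpy as _np
    import mxnet as mx

    pvals = [0.2, 0.3, 0.5]
    mx.np.random.multinomial(100, pvals)                                # Python list -> passed as the pvals attribute
    mx.np.random.multinomial(100, _np.array(pvals))                     # numpy array -> converted to a list first
    mx.np.random.multinomial(100, mx.nd.array(pvals).as_np_ndarray())   # MXNet ndarray -> passed as the data input
    # mx.np.random.multinomial(100, [[0.2, 0.3, 0.5]])  # ValueError: object too deep for desired array
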
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index baeab8b..cda1ada 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -98,3 +98,33 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
     This function currently does not support ``loc`` and ``scale`` as ndarrays.
     """
     return _mx_nd_np.random.normal(loc, scale, size, **kwargs)
+
+
+def multinomial(n, pvals, size=None, **kwargs):
+    """Draw samples from a multinomial distribution.
+
+    The multinomial distribution is a multivariate generalisation of the binomial distribution.
+    Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a die,
+    where the outcome can be 1 through 6. Each sample drawn from the distribution represents n such experiments.
+    Its values, ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the outcome was ``i``.
+
+
+    Parameters
+    ----------
+    n : int
+        Number of experiments.
+    pvals : sequence of floats, length p
+        Probabilities of each of the p different outcomes. These should sum to 1
+        (however, the last element is always assumed to account for the remaining
+        probability, as long as ``sum(pvals[:-1]) <= 1``).
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k``
+        samples are drawn. Default is None, in which case a single value is returned.
+
+    Returns
+    -------
+    out : ndarray
+        The drawn samples, of shape ``size + (p,)`` if ``size`` was provided, and ``(p,)`` otherwise.
+        In other words, each entry ``out[i,j,...,:]`` is a ``p``-dimensional vector of counts drawn from the distribution.
+    """
+    return _mx_nd_np.random.multinomial(n, pvals, size, **kwargs)
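
A quick sanity check in the spirit of the unit test added at the end of this commit: for a large number of experiments the empirical frequencies should approach pvals (a sketch; the draw count and tolerance are illustrative):

    import numpy as _np
    import mxnet as mx

    pvals = [0.1, 0.2, 0.3, 0.4]
    counts = mx.np.random.multinomial(10000, pvals).asnumpy()
    freq = counts / _np.float32(10000)
    # by the law of large numbers the observed frequencies converge to pvals
    assert _np.allclose(freq, pvals, atol=1e-1)
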
diff --git a/src/operator/numpy/random/np_multinomial_op.cc b/src/operator/numpy/random/np_multinomial_op.cc
new file mode 100644
index 0000000..bf4f88c
--- /dev/null
+++ b/src/operator/numpy/random/np_multinomial_op.cc
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_multinomial_op.cc
+ * \brief Operator for numpy sampling from multinomial distributions
+ */
+#include "./np_multinomial_op.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyMultinomialParam);
+
+NNVM_REGISTER_OP(_npi_multinomial)
+.describe(R"code(Draw samples from a multinomial distribution.
+The multinomial distribution is a multivariate generalisation of the binomial distribution.
+Take an experiment with one of p possible outcomes. An example of such an experiment is
+throwing a die, where the outcome can be 1 through 6.
+Each sample drawn from the distribution represents n such experiments.
+Its values, X_i = [X_0, X_1, ..., X_p], represent the number of times the outcome was i.
+)code")
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+    return param.pvals.has_value() ? 0U : 1U;
+  }
+)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyMultinomialParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyMultinomialOpShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyMultinomialOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+      return std::vector<ResourceRequest>{
+        ResourceRequest::kRandom, ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", NumpyMultinomialForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("a", "NDArray-or-Symbol", "Source input")
+.add_arguments(NumpyMultinomialParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
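
The set_num_inputs lambda above mirrors the two call patterns used by the Python front end earlier in this diff (repeated here for reference; _npi is the generated internal namespace imported in python/mxnet/ndarray/numpy/random.py):

    # pvals given as a list/tuple travels as the pvals attribute, so the operator has 0 data inputs
    _npi.multinomial(n=n, pvals=pvals, size=size)
    # pvals given as an MXNet ndarray is the single data input, and the pvals attribute stays None
    _npi.multinomial(pvals, pvals=None, n=n, size=size)
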
diff --git a/src/operator/numpy/random/np_multinomial_op.cu b/src/operator/numpy/random/np_multinomial_op.cu
new file mode 100644
index 0000000..a809260
--- /dev/null
+++ b/src/operator/numpy/random/np_multinomial_op.cu
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_multinomial_op.cu
+ * \brief Operator for numpy sampling from multinomial distributions
+ */
+#include "./np_multinomial_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_multinomial)
+.set_attr<FCompute>("FCompute<gpu>", NumpyMultinomialForward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/random/np_multinomial_op.h b/src/operator/numpy/random/np_multinomial_op.h
new file mode 100644
index 0000000..39515b4
--- /dev/null
+++ b/src/operator/numpy/random/np_multinomial_op.h
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_multinomial_op.h
+ * \brief Operator for sampling from multinomial distributions
+ */
+#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
+#define MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include "../../mshadow_op.h"
+#include "../../mxnet_op.h"
+#include "../../operator_common.h"
+#include "../../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct NumpyMultinomialParam : public dmlc::Parameter<NumpyMultinomialParam> {
+  int n;
+  dmlc::optional<mxnet::Tuple<double>> pvals;
+  dmlc::optional<mxnet::Tuple<int>> size;
+  DMLC_DECLARE_PARAMETER(NumpyMultinomialParam) {
+    DMLC_DECLARE_FIELD(n)
+      .describe("Number of experiments.");
+    DMLC_DECLARE_FIELD(pvals)
+      .set_default(dmlc::optional<mxnet::Tuple<double>>())
+      .describe("Probabilities of each of the p different outcomes. "
+      "These should sum to 1 (however, the last element is always assumed to "
+      "account for the remaining probability, as long as sum(pvals[:-1]) <= 1). "
+      "Note that this is for internal usage only. "
+      "The operator takes either an input ndarray or this list of pvals, never both.");
+    DMLC_DECLARE_FIELD(size)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Output shape. If the given shape is, "
+      "e.g., (m, n, k), then m * n * k samples are drawn. "
+      "Default is None, in which case a single value is returned.");
+  }
+};
+
+inline bool NumpyMultinomialOpShape(const nnvm::NodeAttrs& attrs,
+                                     std::vector<TShape> *in_attrs,
+                                     std::vector<TShape> *out_attrs) {
+  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  std::vector<dim_t> oshape_vec;
+  dim_t pvals_length;
+  if (param.pvals.has_value()) {
+    CHECK_EQ(in_attrs->size(), 0U);
+    pvals_length = param.pvals.value().ndim();
+  } else {
+    // pvals is from input ndarray
+    CHECK_EQ(in_attrs->size(), 1U);
+    const TShape& ishape = (*in_attrs)[0];
+    // check the input shape is only one dimension
+    CHECK_EQ(ishape.ndim(), 1U)
+      << "object too deep for desired array";
+    pvals_length = ishape[0];
+  }
+  if (param.size.has_value()) {
+    const mxnet::Tuple<int>& size = param.size.value();
+    for (int i = 0; i < size.ndim(); ++i) {
+      oshape_vec.emplace_back(size[i]);
+    }
+  }
+  oshape_vec.emplace_back(pvals_length);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
+  return out_attrs->at(0).ndim() != 0U;
+}
+
+inline bool NumpyMultinomialOpType(const nnvm::NodeAttrs& attrs,
+                                    std::vector<int>* in_attrs,
+                                    std::vector<int>* out_attrs) {
+  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), (param.pvals.has_value()) ? 0U : 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  (*out_attrs)[0] = mshadow::kInt64;
+  return true;
+}
+
+struct multinomial_kernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  const int num_exp,
+                                  const int prob_length,
+                                  DType* pvals,
+                                  float* uniform,
+                                  int64_t* out) {
+    for (int j = 0; j < num_exp; ++j) {
+      DType loc = static_cast<DType>(uniform[i * num_exp + j]);
+      DType acc = 0.0;
+      bool found = false;
+      for (int k = 0; k < prob_length; ++k) {
+        acc += pvals[k];
+        if (acc > loc) {
+          found = true;
+          out[i * prob_length + k] += 1;
+          break;
+        }
+      }
+      if (!found) {
+        out[i * prob_length + (prob_length - 1)] += 1;
+      }
+    }
+  }
+};
+
+template<typename xpu>
+void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(inputs.size(), (param.pvals.has_value()) ? 0U : 1U);
+
+  int prob_length = (param.pvals.has_value())
+    ? param.pvals.value().ndim() : inputs[0].shape_[0];
+  // if intput is [] or size contains 0 dimension
+  if (prob_length == 0U || outputs[0].shape_.Size() == 0) return;
+  int num_output = outputs[0].Size() / prob_length;
+  int num_exp = param.n;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
+  Tensor<xpu, 1, float> uniform =
+      ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(num_output * param.n), s);
+  prnd->SampleUniform(&uniform, 0, 1);
+
+  // set zero for the outputs
+  Kernel<set_zero, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<int64_t>());
+
+  if (param.pvals.has_value()) {
+    // create a tensor to copy the param.pvals tuple to avoid
+    // error: calling a __host__ function from a __host__ __device__ function is not allowed
+    Tensor<xpu, 1, double> pvals =
+      ctx.requested[1].get_space_typed<xpu, 1, double>(Shape1(prob_length), s);
+    double* pvals_ = pvals.dptr_;
+    // check if sum of input(pvals) > 1.0
+    double sum = 0.0;
+    for (int i = 0; i < prob_length; ++i) {
+        sum += param.pvals.value()[i];
+        // copy the tuple to data for later kernel usage
+        pvals_[i] = param.pvals.value()[i];
+        CHECK_LE(sum, 1.0)
+          << "sum(pvals[:-1]) > 1.0";
+    }
+    Kernel<multinomial_kernel, xpu>::Launch(
+      s, num_output, num_exp, prob_length, pvals_, uniform.dptr_, outputs[0].dptr<int64_t>());
+  } else {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      // check if sum of input(pvals) > 1.0
+      DType sum = DType(0);
+      DType* input = inputs[0].dptr<DType>();
+      for (int i = 0; i < prob_length; ++i) {
+        sum += input[i];
+        CHECK_LE(sum, 1.0)
+          << "sum(pvals[:-1]) > 1.0";
+      }
+      Kernel<multinomial_kernel, xpu>::Launch(
+        s, num_output, num_exp, prob_length,
+        inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int64_t>());
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
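
The shape inference and multinomial_kernel above boil down to the following pure-NumPy sketch (illustrative only: the helper name is hypothetical, size is assumed to be None or a tuple of ints, and the degenerate cases mirror the operator's early return):

    import numpy as _np

    def _multinomial_reference(n, pvals, size=None):
        # mirrors NumpyMultinomialOpShape: the output shape is size + (len(pvals),)
        pvals = [float(v) for v in pvals]
        p = len(pvals)
        shape = tuple(size) if size is not None else ()
        num_output = int(_np.prod(shape)) if shape else 1
        out = _np.zeros((num_output, p), dtype=_np.int64)
        if p == 0 or num_output == 0:            # empty pvals or zero-sized output: nothing to draw
            return out.reshape(shape + (p,))
        # one uniform draw per experiment, like the temp-space buffer filled by SampleUniform
        uniform = _np.random.uniform(0.0, 1.0, size=(num_output, n))
        for i in range(num_output):              # each row is one sample of p counts
            for j in range(n):                   # n experiments per sample
                loc, acc = uniform[i, j], 0.0
                for k in range(p):               # walk the cumulative pvals until it passes loc
                    acc += pvals[k]
                    if acc > loc:
                        out[i, k] += 1
                        break
                else:                            # plays the role of the kernel's found flag:
                    out[i, p - 1] += 1           # leftover probability mass goes to the last outcome
        return out.reshape(shape + (p,))
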
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 0d8eacf..e6e4911 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -23,7 +23,7 @@ import numpy as _np
 import mxnet as mx
 from mxnet import np, npx, autograd
 from mxnet.gluon import HybridBlock
-from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
+from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception
 from common import with_seed, TemporaryDirectory
 
 
@@ -669,6 +669,51 @@ def test_np_save_load_ndarrays():
             assert _np.array_equal(v.asnumpy(), arr_dict[k].asnumpy())
 
 
+@retry(5)
+@with_seed()
+@npx.use_np_shape
+def test_np_multinomial():
+    pvals_list = [[0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]]
+    sizes = [None, (), (3,), (2, 5, 7), (4, 9)]
+    experiments = 10000
+    for pvals_type in [list, _np.ndarray]:
+        for have_size in [False, True]:
+            for pvals in pvals_list:
+                if have_size:
+                    for size in sizes:
+                        if pvals_type == mx.nd.NDArray:
+                            pvals = mx.nd.array(pvals).as_np_ndarray()
+                        elif pvals_type == _np.ndarray:
+                            pvals = _np.array(pvals)
+                        freq = mx.np.random.multinomial(experiments, pvals, size=size).asnumpy() / _np.float32(experiments)
+                        # for those cases that didn't need reshape
+                        if size in [None, ()]:
+                            mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
+                        else:
+                            # check the shape
+                            assert freq.shape == size + (len(pvals),), 'freq.shape={}, size + (len(pvals),)={}'.format(freq.shape, size + (len(pvals),))
+                            freq = freq.reshape((-1, len(pvals)))
+                            # check the value for each row
+                            for i in range(freq.shape[0]):
+                                mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
+                else:
+                    freq = mx.np.random.multinomial(experiments, pvals).asnumpy() / _np.float32(experiments)
+                    mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
+    # check the zero dimension
+    sizes = [(0,), (0, 2), (4, 0, 2), (3, 0, 1, 2, 0)]
+    for pvals in pvals_list:
+        for size in sizes:
+            freq = mx.np.random.multinomial(experiments, pvals, size=size).asnumpy()
+            assert freq.size == 0
+    # check [] as pvals
+    for pvals in [[], ()]:
+        freq = mx.np.random.multinomial(experiments, pvals).asnumpy()
+        assert freq.size == 0
+        for size in sizes:
+            freq = mx.np.random.multinomial(experiments, pvals, size=size).asnumpy()
+            assert freq.size == 0
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()