You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2019/10/20 05:55:40 UTC

[incubator-mxnet] branch numpy_1_6_prs updated: add interface for rand

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy_1_6_prs
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy_1_6_prs by this push:
     new c2cf2a3  add interface for rand
c2cf2a3 is described below

commit c2cf2a3010b36a00defb1ba5c0c6953fb8a3522e
Author: kshitij12345 <ks...@gmail.com>
AuthorDate: Thu Aug 15 17:27:08 2019 +0530

    add interface for rand
    
    add relevant tests
    
    address comments.
    
    * fix document string -> Returns description.
    
    Fix
---
 python/mxnet/ndarray/numpy/random.py   | 29 ++++++++++++++++-
 python/mxnet/numpy/random.py           | 30 ++++++++++++++++--
 python/mxnet/symbol/numpy/random.py    | 30 ++++++++++++++++--
 tests/python/unittest/test_numpy_op.py | 58 +++++++++++++++++++++++++++++++---
 4 files changed, 137 insertions(+), 10 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index 9e40169..583f56e 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -23,7 +23,7 @@ from . import _internal as _npi
 from ..ndarray import NDArray
 
 
-__all__ = ['randint', 'uniform', 'normal', "choice"]
+__all__ = ['randint', 'uniform', 'normal', "choice", "rand"]
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -317,3 +317,30 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
             return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out)
         else:
             return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)
+
+
+def rand(*size, **kwargs):
+    r"""Random values in a given shape.
+
+    Samples are drawn from a uniform distribution over [0, 1).
+
+    Parameters
+    ----------
+    d0, d1, ..., dn : int, optional
+        Dimensions of the returned array; none given requests shape ``()``.
+    **kwargs
+        Forwarded unchanged to ``uniform``.
+
+    Returns
+    -------
+    out : ndarray
+        Random values.
+
+    Examples
+    --------
+    >>> np.random.rand(3,2)
+    array([[ 0.14022471,  0.96360618],  #random
+           [ 0.37601032,  0.25528411],  #random
+           [ 0.49313049,  0.94909878]]) #random
+    """
+    return uniform(0, 1, size=tuple(size), **kwargs)
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index 746ce99..d0ae237 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -20,8 +20,7 @@
 from __future__ import absolute_import
 from ..ndarray import numpy as _mx_nd_np
 
-
-__all__ = ["randint", "uniform", "normal", "choice"]
+__all__ = ["randint", "uniform", "normal", "choice", "rand"]
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -231,3 +230,30 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
     array([2, 3, 0])
     """
     return _mx_nd_np.random.choice(a, size, replace, p, ctx, out)
+
+
+def rand(*size, **kwargs):
+    r"""Random values in a given shape.
+
+    Samples are drawn from a uniform distribution over [0, 1).
+
+    Parameters
+    ----------
+    d0, d1, ..., dn : int, optional
+        Dimensions of the returned array; none given requests shape ``()``.
+    **kwargs
+        Forwarded unchanged to ``_mx_nd_np.random.uniform``.
+
+    Returns
+    -------
+    out : ndarray
+        Random values.
+
+    Examples
+    --------
+    >>> np.random.rand(3,2)
+    array([[ 0.14022471,  0.96360618],  #random
+           [ 0.37601032,  0.25528411],  #random
+           [ 0.49313049,  0.94909878]]) #random
+    """
+    return _mx_nd_np.random.uniform(0, 1, size=tuple(size), **kwargs)
diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py
index 84cc570..d891ea0 100644
--- a/python/mxnet/symbol/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -21,8 +21,7 @@ from __future__ import absolute_import
 from ...context import current_context
 from . import _internal as _npi
 
-
-__all__ = ['randint', 'uniform', 'normal']
+__all__ = ['randint', 'uniform', 'normal', 'rand']
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -86,6 +85,33 @@ def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
     return _npi.random_randint(low, high, shape=size, dtype=dtype, ctx=ctx, out=out)
 
 
+def rand(*size, **kwargs):
+    r"""Random values in a given shape.
+
+    Samples are drawn from a uniform distribution over [0, 1).
+
+    Parameters
+    ----------
+    d0, d1, ..., dn : int, optional
+        Dimensions of the returned array; none given requests shape ``()``.
+    **kwargs
+        Forwarded unchanged to ``uniform``.
+
+    Returns
+    -------
+    out : ndarray
+        Random values (NOTE(review): presumably a Symbol here; confirm).
+
+    Examples
+    --------
+    >>> np.random.rand(3,2)
+    array([[ 0.14022471,  0.96360618],  #random
+           [ 0.37601032,  0.25528411],  #random
+           [ 0.49313049,  0.94909878]]) #random
+    """
+    return uniform(0, 1, size=tuple(size), **kwargs)
+
+
 def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
     """Draw samples from a uniform distribution.
 
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 5676ff8..99833d1 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -20,7 +20,9 @@ from __future__ import absolute_import
 import sys
 import unittest
 import numpy as _np
+import platform
 import mxnet as mx
+import scipy.stats as ss
 from mxnet import np, npx
 from mxnet.gluon import HybridBlock
 from mxnet.base import MXNetError
@@ -28,13 +30,9 @@ from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndar
 from mxnet.test_utils import check_numeric_gradient, use_np, collapse_sum_like
 from common import assertRaises, with_seed
 import random
-import scipy.stats as ss
-from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
-from mxnet.runtime import Features
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
 from mxnet.numpy_op_signature import _get_builtin_op
-from mxnet.test_utils import current_context, verify_generator, gen_buckets_probs_with_ppf
 from mxnet.test_utils import is_op_runnable, has_tvm_ops
-import platform
 
 
 @with_seed()
@@ -3439,6 +3437,56 @@ def test_np_einsum():
                     assert_almost_equal(grad[0][iop], grad[1][iop], rtol=rtol, atol=atol)
 
 
+@with_seed()
+@use_np
+def test_np_rand():
+    """Check ``np.random.rand`` output shapes and its sample distribution."""
+    # Shapes to request; several contain zero-size dimensions.
+    shapes = [
+        (3, 3),
+        (3, 4),
+        (0, 0),
+        (3, 3, 3),
+        (0, 0, 0),
+        (2, 2, 4, 3),
+        (2, 0, 3, 0),
+        (2, 0, 2, 3),
+    ]
+    dtypes = ['float16', 'float32', 'float64']
+    for dtype in dtypes:
+        for shape in shapes:
+            data_mx = np.random.rand(*shape, dtype=dtype)
+            assert data_mx.shape == shape
+
+    # Compare generated samples against scipy's uniform(lower, upper).
+    ctx = mx.context.current_context()
+    samples = 1000000
+    trials = 8
+    num_buckets = 10
+    lower, upper = 0.0, 1.0
+    for dtype in dtypes:
+        buckets, probs = gen_buckets_probs_with_ppf(
+            lambda x: ss.uniform.ppf(x, lower, upper), num_buckets)
+        # Quantize bucket boundaries to reflect the actual dtype
+        # and adjust probs accordingly
+        buckets = np.array(buckets, dtype=dtype).tolist()
+        probs = [(ss.uniform.cdf(buckets[i][1], lower, upper) -
+                  ss.uniform.cdf(buckets[i][0], lower, upper))
+                 for i in range(num_buckets)]
+        # Draws `samples` values in a single call; the argument is unused.
+        def generator_mx(x):
+            return np.random.rand(samples, ctx=ctx, dtype=dtype).asnumpy()
+        verify_generator(generator=generator_mx, buckets=buckets,
+                         probs=probs, nsamples=samples, nrepeat=trials)
+        # Same check with the samples drawn as ten concatenated chunks.
+        def generator_mx_same_seed(x):
+            return _np.concatenate(
+                [np.random.rand(x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                 for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets,
+                         probs=probs, nsamples=samples, nrepeat=trials)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()