You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by wk...@apache.org on 2019/03/05 03:57:13 UTC

[incubator-mxnet] branch master updated: support long for mx.random.seed (#14314)

This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 99bb06c  support long for mx.random.seed (#14314)
99bb06c is described below

commit 99bb06cd0bba3a4aeed8d39cca4bdfee285e1cd6
Author: JackieWu <wk...@live.cn>
AuthorDate: Tue Mar 5 11:56:11 2019 +0800

    support long for mx.random.seed (#14314)
    
    * support long for mx.random.seed
    
    * update test_random
    
    * reorder
    
    * use mx.random.uniform
    
    * trigger CI
    
    * retrigger CI
---
 python/mxnet/random.py               |  6 +++---
 tests/python/unittest/test_random.py | 29 ++++++++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/random.py b/python/mxnet/random.py
index 6394727..1d6fe27 100644
--- a/python/mxnet/random.py
+++ b/python/mxnet/random.py
@@ -22,7 +22,7 @@
 from __future__ import absolute_import
 
 import ctypes
-from .base import _LIB, check_call
+from .base import _LIB, check_call, integer_types
 from .ndarray.random import *
 from .context import Context
 
@@ -90,9 +90,9 @@ def seed(seed_state, ctx="all"):
     [[ 2.5020072 -1.6884501]
      [-0.7931333 -1.4218881]]
     """
-    if not isinstance(seed_state, int):
+    if not isinstance(seed_state, integer_types):
         raise ValueError('seed_state must be int')
-    seed_state = ctypes.c_int(seed_state)
+    seed_state = ctypes.c_int(int(seed_state))
     if ctx == "all":
         check_call(_LIB.MXRandomSeed(seed_state))
     else:
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index b786c5b..d1340c4 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -336,6 +336,7 @@ def test_parallel_random_seed_setting():
         # Avoid excessive test cpu runtimes
         num_temp_seeds = 25 if ctx.device_type == 'gpu' else 1
         # To flush out a possible race condition, run multiple times
+
         for _ in range(20):
             # Create enough samples such that we get a meaningful distribution.
             shape = (200, 200)
@@ -670,7 +671,7 @@ def test_with_random_seed():
         with random_seed(seed):
             python_data = [rnd.random() for _ in range(size)]
             np_data = np.random.rand(size)
-            mx_data = mx.nd.random_uniform(shape=shape, ctx=ctx).asnumpy()
+            mx_data = mx.random.uniform(shape=shape, ctx=ctx).asnumpy()
         return (seed, python_data, np_data, mx_data)
 
     # check data, expecting them to be the same or different based on the seeds
@@ -713,6 +714,32 @@ def test_with_random_seed():
             check_data(data[i],data[j])
 
 @with_seed()
+def test_random_seed():
+    shape = (5, 5)
+    seed = rnd.randint(-(1 << 31), (1 << 31))
+
+    def _assert_same_mx_arrays(a, b):
+        assert len(a) == len(b)
+        for a_i, b_i in zip(a, b):
+            assert (a_i.asnumpy() == b_i.asnumpy()).all()
+
+    N = 100
+    mx.random.seed(seed)
+    v1 = [mx.random.uniform(shape=shape) for _ in range(N)]
+
+    mx.random.seed(seed)
+    v2 = [mx.random.uniform(shape=shape) for _ in range(N)]
+    _assert_same_mx_arrays(v1, v2)
+
+    try:
+        long
+        mx.random.seed(long(seed))
+        v3 = [mx.random.uniform(shape=shape) for _ in range(N)]
+        _assert_same_mx_arrays(v1, v3)
+    except NameError:
+        pass
+
+@with_seed()
 def test_unique_zipfian_generator():
     ctx = mx.context.current_context()
     if ctx.device_type == 'cpu':