You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by wk...@apache.org on 2019/03/05 03:57:13 UTC
[incubator-mxnet] branch master updated: support long for mx.random.seed (#14314)
This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 99bb06c support long for mx.random.seed (#14314)
99bb06c is described below
commit 99bb06cd0bba3a4aeed8d39cca4bdfee285e1cd6
Author: JackieWu <wk...@live.cn>
AuthorDate: Tue Mar 5 11:56:11 2019 +0800
support long for mx.random.seed (#14314)
* support long for mx.random.seed
* update test_random
* reorder
* use mx.random.uniform
* trigger CI
* retrigger CI
---
python/mxnet/random.py | 6 +++---
tests/python/unittest/test_random.py | 29 ++++++++++++++++++++++++++++-
2 files changed, 31 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/random.py b/python/mxnet/random.py
index 6394727..1d6fe27 100644
--- a/python/mxnet/random.py
+++ b/python/mxnet/random.py
@@ -22,7 +22,7 @@
from __future__ import absolute_import
import ctypes
-from .base import _LIB, check_call
+from .base import _LIB, check_call, integer_types
from .ndarray.random import *
from .context import Context
@@ -90,9 +90,9 @@ def seed(seed_state, ctx="all"):
[[ 2.5020072 -1.6884501]
[-0.7931333 -1.4218881]]
"""
- if not isinstance(seed_state, int):
+ if not isinstance(seed_state, integer_types):
raise ValueError('seed_state must be int')
- seed_state = ctypes.c_int(seed_state)
+ seed_state = ctypes.c_int(int(seed_state))
if ctx == "all":
check_call(_LIB.MXRandomSeed(seed_state))
else:
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index b786c5b..d1340c4 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -336,6 +336,7 @@ def test_parallel_random_seed_setting():
# Avoid excessive test cpu runtimes
num_temp_seeds = 25 if ctx.device_type == 'gpu' else 1
# To flush out a possible race condition, run multiple times
+
for _ in range(20):
# Create enough samples such that we get a meaningful distribution.
shape = (200, 200)
@@ -670,7 +671,7 @@ def test_with_random_seed():
with random_seed(seed):
python_data = [rnd.random() for _ in range(size)]
np_data = np.random.rand(size)
- mx_data = mx.nd.random_uniform(shape=shape, ctx=ctx).asnumpy()
+ mx_data = mx.random.uniform(shape=shape, ctx=ctx).asnumpy()
return (seed, python_data, np_data, mx_data)
# check data, expecting them to be the same or different based on the seeds
@@ -713,6 +714,32 @@ def test_with_random_seed():
check_data(data[i],data[j])
@with_seed()
+def test_random_seed():
+ shape = (5, 5)
+ seed = rnd.randint(-(1 << 31), (1 << 31))
+
+ def _assert_same_mx_arrays(a, b):
+ assert len(a) == len(b)
+ for a_i, b_i in zip(a, b):
+ assert (a_i.asnumpy() == b_i.asnumpy()).all()
+
+ N = 100
+ mx.random.seed(seed)
+ v1 = [mx.random.uniform(shape=shape) for _ in range(N)]
+
+ mx.random.seed(seed)
+ v2 = [mx.random.uniform(shape=shape) for _ in range(N)]
+ _assert_same_mx_arrays(v1, v2)
+
+ try:
+ long
+ mx.random.seed(long(seed))
+ v3 = [mx.random.uniform(shape=shape) for _ in range(N)]
+ _assert_same_mx_arrays(v1, v3)
+ except NameError:
+ pass
+
+@with_seed()
def test_unique_zipfian_generator():
ctx = mx.context.current_context()
if ctx.device_type == 'cpu':