You are viewing a plain text version of this content. The link to its canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/15 19:57:50 UTC
[incubator-mxnet] branch master updated: Random seed setting (#9409)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a4a09e5 Random seed setting (#9409)
a4a09e5 is described below
commit a4a09e5622c295bd6d7910e6c83f4eb2c5e0404e
Author: Dick Carter <di...@comcast.net>
AuthorDate: Mon Jan 15 11:57:47 2018 -0800
Random seed setting (#9409)
* Added tests of random generator seed setting to expose potential races.
* Syncing parallel-rng seeding to fix failure of test_operator_gpu.py:test_parallel_random_seed_setting
* Fix test_random.py:test_uniform_generator float16 failures.
---
src/common/random_generator.cu | 1 +
tests/python/unittest/test_random.py | 82 ++++++++++++++++++++++++++++++++++--
2 files changed, 79 insertions(+), 4 deletions(-)
diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
index 66969fe..f6f31cf 100644
--- a/src/common/random_generator.cu
+++ b/src/common/random_generator.cu
@@ -55,6 +55,7 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) {
states_,
RandGenerator<gpu, float>::kNumRandomStates,
seed);
+ s->Wait();
}
} // namespace random
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 0efe8e6..d0c1798 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -122,7 +122,7 @@ def check_with_device(device, dtype):
ret1 = ndop(**params).asnumpy()
mx.random.seed(128)
ret2 = ndop(**params).asnumpy()
- assert device.device_type == 'gpu' or same(ret1, ret2), \
+ assert same(ret1, ret2), \
"ndarray test: `%s` should give the same result with the same seed" % name
for check_name, check_func, tol in symbdic['checks']:
@@ -135,7 +135,7 @@ def check_with_device(device, dtype):
ret1 = ndop(**params).asnumpy()
mx.random.seed(128)
ret2 = ndop(**params).asnumpy()
- assert device.device_type == 'gpu' or same(ret1, ret2), \
+ assert same(ret1, ret2), \
"ndarray test: `%s` should give the same result with the same seed" % name
for i in range(2):
for j in range(2):
@@ -161,7 +161,7 @@ def check_with_device(device, dtype):
mx.random.seed(128)
yexec.forward()
un2 = (yexec.outputs[0] - x).copyto(device)
- assert device.device_type == 'gpu' or same(un1.asnumpy(), un2.asnumpy()), \
+ assert same(un1.asnumpy(), un2.asnumpy()), \
"symbolic test: `%s` should give the same result with the same seed" % name
ret1 = un1.asnumpy()
@@ -197,6 +197,76 @@ def test_random():
check_with_device(mx.context.current_context(), 'float64')
+# Seed the rng with each of `init_seed` .. `init_seed + num_init_seeds - 1` in turn, then finally with `final_seed`
+def set_seed_variously(init_seed, num_init_seeds, final_seed):
+ end_seed = init_seed + num_init_seeds
+ for seed in range(init_seed, end_seed):
+ mx.random.seed(seed)
+ mx.random.seed(final_seed)
+ return end_seed  # first unused intermediate seed; callers pass it back as `init_seed` of their next call
+
+# Tests that seed setting of std (non-parallel) rng is synchronous w.r.t. rng use before and after.
+def test_random_seed_setting():
+ ctx = mx.context.current_context()
+ seed_to_test = 1234
+ num_temp_seeds = 25  # extra seeds set around each draw to expose potential seeding races
+ probs = [0.125, 0.25, 0.25, 0.0625, 0.125, 0.1875]  # 6 categories; sums to 1.0
+ num_samples = 100000
+ for dtype in ['float16', 'float32', 'float64']:
+ seed = set_seed_variously(1, num_temp_seeds, seed_to_test)
+ samples1 = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
+ shape=num_samples)
+ seed = set_seed_variously(seed, num_temp_seeds, seed_to_test)
+ samples2 = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx, dtype=dtype),
+ shape=num_samples)
+ samples1np = samples1.asnumpy()
+ set_seed_variously(seed, num_temp_seeds, seed_to_test+1)  # different final seed, set before samples2 is read back: must not affect it
+ samples2np = samples2.asnumpy()
+ assert same(samples1np, samples2np), \
+ "seed-setting test: `multinomial` should give the same result with the same seed"
+
+
+# Tests that seed setting of parallel rng is synchronous w.r.t. rng use before and after, via ndarray and symbolic `uniform`.
+def test_parallel_random_seed_setting():
+ ctx = mx.context.current_context()
+ seed_to_test = 1234
+ for dtype in ['float16', 'float32', 'float64']:
+ # Avoid excessive test cpu runtimes
+ num_temp_seeds = 25 if ctx.device_type == 'gpu' else 1
+ # To flush out a possible race condition, run multiple times
+ for _ in range(20):
+ # Create enough samples such that we get a meaningful distribution.
+ shape = (200, 200)
+ params = { 'low': -1.5, 'high': 3.0 }
+ params.update(shape=shape, dtype=dtype, ctx=ctx)
+
+ # check directly
+ seed = set_seed_variously(1, num_temp_seeds, seed_to_test)
+ ret1 = mx.nd.random.uniform(**params)
+ seed = set_seed_variously(seed, num_temp_seeds, seed_to_test)
+ ret2 = mx.nd.random.uniform(**params)
+ seed = set_seed_variously(seed, num_temp_seeds, seed_to_test)  # reseed after issuing both draws; must not change ret1/ret2
+ assert same(ret1.asnumpy(), ret2.asnumpy()), \
+ "ndarray seed-setting test: `uniform` should give the same result with the same seed"
+
+ # check symbolic
+ X = mx.sym.Variable("X")
+ Y = mx.sym.random.uniform(**params) + X
+ x = mx.nd.zeros(shape, dtype=dtype, ctx=ctx)
+ xgrad = mx.nd.zeros(shape, dtype=dtype, ctx=ctx)
+ yexec = Y.bind(ctx, {'X' : x}, {'X': xgrad})
+ seed = set_seed_variously(seed, num_temp_seeds, seed_to_test)
+ yexec.forward(is_train=True)
+ yexec.backward(yexec.outputs[0])
+ un1 = (yexec.outputs[0] - x).copyto(ctx)
+ seed = set_seed_variously(seed, num_temp_seeds, seed_to_test)
+ yexec.forward()
+ set_seed_variously(seed, num_temp_seeds, seed_to_test)  # reseed before reading back the second forward's output; must not change it
+ un2 = (yexec.outputs[0] - x).copyto(ctx)
+ assert same(un1.asnumpy(), un2.asnumpy()), \
+ "symbolic seed-setting test: `uniform` should give the same result with the same seed"
+
+
def test_sample_multinomial():
x = mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0
dx = mx.nd.ones_like(x)
@@ -240,7 +310,11 @@ def test_uniform_generator():
for dtype in ['float16', 'float32', 'float64']:
for low, high in [(-1.0, 1.0), (1.0, 3.0)]:
print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high))
- buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=high - low), 5)
+ scale = high - low
+ buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5)
+ # Quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly
+ buckets = np.array(buckets, dtype=dtype).tolist()
+ probs = [(buckets[i][1] - buckets[i][0])/scale for i in range(5)]
generator_mx = lambda x: mx.nd.random.uniform(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy()
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
generator_mx_same_seed = \
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.