You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ap...@apache.org on 2019/05/23 22:04:49 UTC
[incubator-mxnet] branch v1.5.x updated: Fix crash in random.shuffle operator (#15041)
This is an automated email from the ASF dual-hosted git repository.
apeforest pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.5.x by this push:
new f9dbd0e Fix crash in random.shuffle operator (#15041)
f9dbd0e is described below
commit f9dbd0e05fb25ff6c773e0b587a92766e60bedf4
Author: Lin Yuan <ap...@gmail.com>
AuthorDate: Thu May 23 13:02:15 2019 -0700
Fix crash in random.shuffle operator (#15041)
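* fix crash in random_shuffle caused by integer overflow of the random-number functor's argument (its type is changed from index_t to uint32_t to match libstdc++ parallel mode)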
* add unit test
* add comment
* remove small random test to avoid CI failure
---
src/operator/random/shuffle_op.cc | 9 +++++++--
tests/python/unittest/test_random.py | 2 ++
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/src/operator/random/shuffle_op.cc b/src/operator/random/shuffle_op.cc
index 1bd70b1..345a771 100644
--- a/src/operator/random/shuffle_op.cc
+++ b/src/operator/random/shuffle_op.cc
@@ -45,8 +45,13 @@ namespace {
template<typename DType, typename Rand>
void Shuffle1D(DType* const out, const index_t size, Rand* const prnd) {
#ifdef USE_GNU_PARALLEL_SHUFFLE
- auto rand_n = [prnd](index_t n) {
- std::uniform_int_distribution<index_t> dist(0, n - 1);
+ /*
+ * See issue #15029: the data type of n needs to be compatible with
+ * the gcc library: https://github.com/gcc-mirror/gcc/blob/master/libstdc%2B%2B\
+ * -v3/include/parallel/random_shuffle.h#L384
+ */
+ auto rand_n = [prnd](uint32_t n) {
+ std::uniform_int_distribution<uint32_t> dist(0, n - 1);
return dist(*prnd);
};
__gnu_parallel::random_shuffle(out, out + size, rand_n);
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 5e809d3..4d14719 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -867,6 +867,8 @@ def test_shuffle():
# Test larger arrays
testLarge(mx.nd.arange(0, 100000).reshape((10, 10000)), 10)
testLarge(mx.nd.arange(0, 100000).reshape((10000, 10)), 10)
+ testLarge(mx.nd.arange(0, 100000), 10)
+
@with_seed()
def test_randint():