You are viewing a plain text version of this content. The canonical link to it was not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by ap...@apache.org on 2019/05/23 22:04:49 UTC

[incubator-mxnet] branch v1.5.x updated: Fix crash in random.shuffle operator (#15041)

This is an automated email from the ASF dual-hosted git repository.

apeforest pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.5.x by this push:
     new f9dbd0e  Fix crash in random.shuffle operator (#15041)
f9dbd0e is described below

commit f9dbd0e05fb25ff6c773e0b587a92766e60bedf4
Author: Lin Yuan <ap...@gmail.com>
AuthorDate: Thu May 23 13:02:15 2019 -0700

    Fix crash in random.shuffle operator (#15041)
    
    * fix crash in random_shuffle caused by int overflow
    
    * add unit test
    
    * add comment
    
    * remove small random test to avoid CI failure
---
 src/operator/random/shuffle_op.cc    | 9 +++++++--
 tests/python/unittest/test_random.py | 2 ++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/operator/random/shuffle_op.cc b/src/operator/random/shuffle_op.cc
index 1bd70b1..345a771 100644
--- a/src/operator/random/shuffle_op.cc
+++ b/src/operator/random/shuffle_op.cc
@@ -45,8 +45,13 @@ namespace {
 template<typename DType, typename Rand>
 void Shuffle1D(DType* const out, const index_t size, Rand* const prnd) {
   #ifdef USE_GNU_PARALLEL_SHUFFLE
-    auto rand_n = [prnd](index_t n) {
-      std::uniform_int_distribution<index_t> dist(0, n - 1);
+     /*
+      * See issue #15029: the data type of n needs to be compatible with
+      * the gcc library: https://github.com/gcc-mirror/gcc/blob/master/libstdc%2B%2B\
+      * -v3/include/parallel/random_shuffle.h#L384
+      */
+    auto rand_n = [prnd](uint32_t n) {
+      std::uniform_int_distribution<uint32_t> dist(0, n - 1);
       return dist(*prnd);
     };
     __gnu_parallel::random_shuffle(out, out + size, rand_n);
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 5e809d3..4d14719 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -867,6 +867,8 @@ def test_shuffle():
     # Test larger arrays
     testLarge(mx.nd.arange(0, 100000).reshape((10, 10000)), 10)
     testLarge(mx.nd.arange(0, 100000).reshape((10000, 10)), 10)
+    testLarge(mx.nd.arange(0, 100000), 10)
+
 
 @with_seed()
 def test_randint():