You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/28 20:27:30 UTC

[incubator-mxnet] branch master updated: fix random generator: do not gen seed each time (#9119)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 34a5195  fix random generator: do not gen seed each time (#9119)
34a5195 is described below

commit 34a51959bd2bc21c6cfa93f5fe0e079ef5268261
Author: Yizhi Liu <ja...@gmail.com>
AuthorDate: Thu Dec 28 12:27:25 2017 -0800

    fix random generator: do not gen seed each time (#9119)
    
    * add tests for distribution generators
    
    fix lint
    
    fix lint
    
    fix typo
    
    fix docstring
    
    fix docstring
    
    * [Bugfix] fix random generator: do not gen seed each time
    
    * gen samplers on gpu for test_softmax
    
    * fix test cases
    
    * remove unnecessary prints
    
    * refactor RandGenerator
    
    * get_native_random -> get_parallel_random
    
    * revise test cases + remove dependency of scipy
    
    * raise warning
---
 amalgamation/amalgamation.py            |   2 +-
 include/mxnet/resource.h                |  18 +-
 include/mxnet/storage.h                 |   2 +-
 perl-package/AI-MXNet/t/test_random.t   |   2 +-
 python/mxnet/optimizer.py               |   3 +-
 python/mxnet/test_utils.py              | 226 ++++++++++++++++++++++++
 src/common/random_generator.cu          |  56 ++++++
 src/common/random_generator.h           | 219 ++++++++++++++++++++++++
 src/common/utils.h                      |   1 -
 src/executor/attach_op_resource_pass.cc |   2 +
 src/imperative/imperative_utils.h       |   4 +
 src/operator/random/multisample_op.cc   |   3 +-
 src/operator/random/multisample_op.h    |  23 ++-
 src/operator/random/sample_op.h         |  63 +++----
 src/operator/random/sampler.h           | 293 ++++++++++++++------------------
 src/resource.cc                         |  96 ++++++++++-
 tests/python/unittest/test_module.py    |   2 +-
 tests/python/unittest/test_operator.py  |   4 +-
 tests/python/unittest/test_optimizer.py |   2 +-
 tests/python/unittest/test_random.py    | 122 +++++++++++++
 20 files changed, 919 insertions(+), 224 deletions(-)

diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index 9419898..f1e1e02 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -21,7 +21,7 @@ import platform
 
 blacklist = [
     'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
-    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h',
+    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
     'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
     'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
     'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h
index 7d2e6ca..773baf0 100644
--- a/include/mxnet/resource.h
+++ b/include/mxnet/resource.h
@@ -28,6 +28,7 @@
 #include <dmlc/logging.h>
 #include "./base.h"
 #include "./engine.h"
+#include "../../src/common/random_generator.h"
 
 namespace mxnet {
 
@@ -40,7 +41,9 @@ struct ResourceRequest {
     /*! \brief mshadow::Random<xpu> object */
     kRandom,
     /*! \brief A dynamic temp space that can be arbitrary size */
-    kTempSpace
+    kTempSpace,
+    /*! \brief common::RandGenerator<xpu> object, which can be used in GPU kernel functions */
+    kParallelRandom
   };
   /*! \brief type of resources */
   Type type;
@@ -89,6 +92,19 @@ struct Resource {
     ret->set_stream(stream);
     return ret;
   }
+
+  /*!
+   * \brief Get parallel random number generator.
+   * \tparam xpu the device type of random number generator.
+   * \tparam DType the data type the generator is specialized for.
+   * \return the parallel random number generator. for gpu, it is allocated on global memory.
+   */
+  template<typename xpu, typename DType>
+  inline common::random::RandGenerator<xpu, DType>* get_parallel_random() const {
+    CHECK_EQ(req.type, ResourceRequest::kParallelRandom);  // DType of the cast is unchecked
+    return static_cast<common::random::RandGenerator<xpu, DType>*>(ptr_);
+  }
+
   /*!
    * \brief Get space requested as mshadow Tensor.
    *  The caller can request arbitrary size.
diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h
index d19f98b..a8481c1 100644
--- a/include/mxnet/storage.h
+++ b/include/mxnet/storage.h
@@ -82,7 +82,7 @@ class Storage {
   virtual void SharedIncrementRefCount(Handle handle) = 0;
   /*!
    * \brief Free storage.
-   * \param handle Handle struect.
+   * \param handle Handle struct.
    */
   virtual void Free(Handle handle) = 0;
   /*!
diff --git a/perl-package/AI-MXNet/t/test_random.t b/perl-package/AI-MXNet/t/test_random.t
index c95a199..60cebcf 100644
--- a/perl-package/AI-MXNet/t/test_random.t
+++ b/perl-package/AI-MXNet/t/test_random.t
@@ -87,7 +87,7 @@ sub check_with_device
             ]
         },
     );
-    my $shape = [100, 100];
+    my $shape = [1000, 1000];
     for my $symbdic (@symbols)
     {
         my $name = $symbdic->{name};
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 7e8e7c2..aebb52e 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -648,7 +648,8 @@ class SGLD(Optimizer):
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
         weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
-                                                            weight.shape, weight.context)
+                                                            shape=weight.shape,
+                                                            ctx=weight.context)
 
 
 @register  # pylint: disable=invalid-name
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 53814b7..58bc8d3 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -35,6 +35,10 @@ import numpy as np
 import numpy.testing as npt
 import numpy.random as rnd
 try:
+    import scipy.stats as ss
+except ImportError:
+    ss = None
+try:
     import requests
 except ImportError:
     # in rare cases requests may be not installed
@@ -1593,3 +1597,225 @@ class DummyIter(mx.io.DataIter):
             The data of next batch.
         """
         return self.the_batch
+
+def gen_buckets_probs_with_ppf(ppf, nbuckets):
+    """Generate equal-probability buckets and their probabilities for chi_square_check, given
+     the ppf (quantile function) of the target distribution.
+
+    Parameters
+    ----------
+    ppf : function
+        The quantile function that takes a probability and maps it back to a value.
+        It is the inverse of the cdf function.
+    nbuckets : int
+        Number of buckets to generate; must be positive.
+
+    Returns
+    -------
+    buckets : list of tuple
+        The contiguous ranges (ppf(i / nbuckets), ppf((i + 1) / nbuckets)) for each i.
+    probs : list
+        The generated probabilities, each equal to 1 / nbuckets.
+    """
+    assert nbuckets > 0
+    probs = [1.0 / nbuckets for _ in range(nbuckets)]
+    buckets = [(ppf(i / float(nbuckets)), ppf((i + 1) / float(nbuckets))) for i in range(nbuckets)]
+    return buckets, probs
+
+def mean_check(generator, mu, sigma, nsamples=1000000):
+    """Test the generator by checking its sample mean against the true mean mu.
+
+    We test the sample mean of n = nsamples draws by checking if it falls inside the range
+        (mu - 3 * sigma / sqrt(n), mu + 3 * sigma / sqrt(n))
+
+    References::
+
+        @incollection{goucher2009beautiful,
+              title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
+              author={Goucher, Adam and Riley, Tim},
+              year={2009},
+              chapter=10
+        }
+
+    Examples::
+
+        generator = lambda x: np.random.normal(0, 1.0, size=x)
+        mean_check_ret = mean_check(generator, 0, 1.0)
+
+    Parameters
+    ----------
+    generator : function
+        The generator function. It's expected to generate N i.i.d samples by calling generator(N).
+    mu : float
+    sigma : float
+    nsamples : int
+
+    Returns
+    -------
+    ret : bool
+        Whether the mean test succeeds
+    """
+    samples = np.array(generator(nsamples))
+    sample_mean = samples.mean()  # std. error sigma / sqrt(nsamples) for i.i.d. draws
+    ret = (sample_mean > mu - 3 * sigma / np.sqrt(nsamples)) and\
+          (sample_mean < mu + 3 * sigma / np.sqrt(nsamples))
+    return ret
+
+def var_check(generator, sigma, nsamples=1000000):
+    """Test the generator by matching the variance.
+    It needs many samples (its bound assumes normal data) and is not recommended for general use.
+
+    We test the sample variance of n = nsamples draws by checking if it falls inside the range
+        (sigma^2 - 3 * sqrt(2 * sigma^4 / (n-1)), sigma^2 + 3 * sqrt(2 * sigma^4 / (n-1)))
+
+    References::
+
+        @incollection{goucher2009beautiful,
+              title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
+              author={Goucher, Adam and Riley, Tim},
+              year={2009},
+              chapter=10
+        }
+
+    Examples::
+
+        generator = lambda x: np.random.normal(0, 1.0, size=x)
+        var_check_ret = var_check(generator, 1.0)
+
+    Parameters
+    ----------
+    generator : function
+        The generator function. It's expected to generate N i.i.d samples by calling generator(N).
+    sigma : float
+    nsamples : int
+
+    Returns
+    -------
+    ret : bool
+        Whether the variance test succeeds
+    """
+    samples = np.array(generator(nsamples))
+    sample_var = samples.var(ddof=1)  # unbiased sample variance
+    ret = (sample_var > sigma ** 2 - 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1))) and\
+          (sample_var < sigma ** 2 + 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1)))
+    return ret
+
+def chi_square_check(generator, buckets, probs, nsamples=1000000):
+    """Run the chi-square test for the generator. The generator can be both continuous and discrete.
+    If the generator is continuous, the buckets should contain tuples of (range_min, range_max) and
+     the probs should be the corresponding ideal probability within the specific ranges.
+    Otherwise, the buckets should be the possible output values of the discrete distribution and
+     the probs should be the ground-truth probability of each value.
+
+    Usually the user is required to specify the probs parameter.
+
+    After obtaining the p value, we could further use the standard p > 0.05 threshold to get
+     the final result.
+
+    Examples::
+        buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, 0, 1), 5)
+        generator = lambda x: np.random.normal(0, 1.0, size=x)
+        p, _, _ = chi_square_check(generator=generator, buckets=buckets, probs=probs)
+        assert(p > 0.05)
+
+    Parameters
+    ----------
+    generator: function
+        A function that is assumed to generate i.i.d samples from a specific distribution.
+        generator(N) should generate N random samples.
+    buckets: list of tuple or list of number
+        The buckets to run the chi-square test on. Each continuous bucket is the half-open range
+         [range_min, range_max). The buckets must cover the whole range of the distribution, be
+         in ascending order and not overlap; samples outside every bucket are not counted.
+    probs: list or tuple
+        The ground-truth probability of the random value falling in each bucket.
+    nsamples: int
+        The number of samples to generate for the testing
+
+    Returns
+    -------
+    p : float
+        p value that the generator has the expected distribution.
+        A higher value indicates a larger confidence
+    obs_freq : numpy.ndarray
+        Observed frequency of buckets
+    expected_freq : numpy.ndarray
+        The expected (ground-truth) frequency of the buckets, as floats
+    """
+    if not ss:
+        raise ImportError("scipy is not available."
+                          " Please check if the scipy python bindings are installed.")
+    assert isinstance(buckets, list)
+    samples = generator(nsamples)
+    assert len(probs) == len(buckets)
+    if isinstance(buckets[0], (list, tuple)):
+        # Check whether the buckets are valid and flatten them into [lo0, hi0, lo1, hi1, ...]
+        continuous_dist = True
+        buckets_npy = np.zeros((len(buckets) * 2, ), dtype=np.float64)
+        for i, (low, high) in enumerate(buckets):
+            assert low <= high
+            if i < len(buckets) - 1:
+                assert high <= buckets[i + 1][0]
+            buckets_npy[i * 2] = low
+            buckets_npy[i * 2 + 1] = high
+    else:
+        continuous_dist = False
+    # Keep expected counts as floats: truncation would make their sum disagree with obs_freq.
+    expected_freq = nsamples * np.array(probs, dtype=np.float64)
+    if continuous_dist:
+        # An odd position 2*i+1 means lo_i <= x < hi_i; even positions lie outside every bucket.
+        pos = np.searchsorted(buckets_npy, samples, side='right')
+        sample_bucket_ids = np.where(pos % 2 == 1, (pos - 1) // 2, -1)
+    else:
+        sample_bucket_ids = np.asarray(samples)
+    obs_freq = np.zeros(shape=len(buckets), dtype=np.int64)
+    for i, bucket in enumerate(buckets):
+        obs_freq[i] = (sample_bucket_ids == (i if continuous_dist else bucket)).sum()
+    _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq)
+    return p, obs_freq, expected_freq
+
+def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, success_rate=0.25):
+    """Verify whether the generator is correct using chi-square testing.
+
+    The test is repeated "nrepeat" times; an AssertionError is raised when the fraction of
+     runs with p > 0.05 is below "success_rate" (25% by default).
+
+    Parameters
+    ----------
+    generator: function
+        A function that is assumed to generate i.i.d samples from a specific distribution.
+        generator(N) should generate N random samples.
+    buckets: list of tuple or list of number
+        The buckets to run the chi-square test on. Make sure that the buckets cover
+         the whole range of the distribution. Also, the buckets must be in ascending order and
+         have no intersection
+    probs: list or tuple
+        The ground-truth probability of the random value falling in each bucket.
+    nsamples: int
+        The number of samples to generate for each test
+    nrepeat: int
+        The number of times to repeat the test
+    success_rate: float
+        The minimum fraction of repetitions that must pass
+
+    Returns
+    -------
+    cs_ret_l: list
+        The p values of the chi-square test.
+    """
+    cs_ret_l = []
+    obs_freq_l = []
+    expected_freq_l = []
+    for _ in range(nrepeat):
+        cs_ret, obs_freq, expected_freq = chi_square_check(generator=generator, buckets=buckets,
+                                                           probs=probs, nsamples=nsamples)
+        cs_ret_l.append(cs_ret)
+        obs_freq_l.append(obs_freq)
+        expected_freq_l.append(expected_freq)
+    success_num = (np.array(cs_ret_l) > 0.05).sum()  # runs passing at the 5% level
+    if success_num < nrepeat * success_rate:
+        raise AssertionError("Generator test fails, Chi-square p=%s, obs_freq=%s, expected_freq=%s."
+                             "\nbuckets=%s, probs=%s"
+                             % (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
+                                str(buckets), str(probs)))
+    return cs_ret_l
diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
new file mode 100644
index 0000000..5f6ac44
--- /dev/null
+++ b/src/common/random_generator.cu
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file random_generator.cu
+ * \brief gpu implements for parallel random number generator.
+ */
+
+#include <algorithm>
+#include "./random_generator.h"
+#include "../operator/mxnet_op.h"
+
+namespace mxnet {
+namespace common {
+namespace random {
+
+__global__ void rand_generator_seed_kernel(curandStatePhilox4_32_10_t *states_,
+                                           const int size,
+                                           uint32_t seed) {
+  int id = blockIdx.x * blockDim.x + threadIdx.x;  // one thread per state, no stride loop
+  if (id < size) curand_init(seed, id, 0, states_ + id);  // subsequence id, offset 0
+}  // seeds states_[0, size) when launched with at least size threads
+
+template<>
+void RandGenerator<gpu, float>::Seed(Stream<gpu> *s, uint32_t seed) {
+  using namespace mshadow::cuda;
+  int ngrid = std::min(kMaxGridNum,
+                       (RandGenerator<gpu, float>::kNumRandomStates + kBaseThreadNum - 1) /
+                         kBaseThreadNum);  // if capped, trailing states go unseeded
+  rand_generator_seed_kernel  // one thread per state, enqueued on stream s
+      <<<ngrid, kBaseThreadNum, 0, Stream<gpu>::GetStream(s)>>>(
+          states_,
+          RandGenerator<gpu, float>::kNumRandomStates,
+          seed);
+}
+
+}  // namespace random
+}  // namespace common
+}  // namespace mxnet
diff --git a/src/common/random_generator.h b/src/common/random_generator.h
new file mode 100644
index 0000000..21db9d7
--- /dev/null
+++ b/src/common/random_generator.h
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file random_generator.h
+ * \brief Parallel random number generator.
+ */
+#ifndef MXNET_COMMON_RANDOM_GENERATOR_H_
+#define MXNET_COMMON_RANDOM_GENERATOR_H_
+
+#include <mxnet/base.h>
+#include <random>
+#include <new>
+
+#if MXNET_USE_CUDA
+#include <curand_kernel.h>
+#include "../common/cuda_utils.h"
+#endif  // MXNET_USE_CUDA
+
+using namespace mshadow;  // NOTE(review): leaks into every includer, incl. mxnet/resource.h
+
+namespace mxnet {
+namespace common {
+namespace random {
+
+template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
+class RandGenerator;  // specialized below for cpu and, under MXNET_USE_CUDA, for gpu
+
+template<typename DType>
+class RandGenerator<cpu, DType> {
+ public:
+  // at least how many random numbers should be generated by one CPU thread.
+  static const int kMinNumRandomPerThread = 64;
+  // store how many global random states for CPU.
+  static const int kNumRandomStates = 1024;
+
+  // non-copyable view of one mt19937 state, gen->states_[state_idx]
+  class Impl {
+   public:
+    typedef typename std::conditional<std::is_floating_point<DType>::value,
+                                      DType, double>::type FType;  // DType if float, else double
+
+    explicit Impl(RandGenerator<cpu, DType> *gen, int state_idx)
+        : engine_(gen->states_ + state_idx) {}  // shares (does not copy) the state
+
+    Impl(const Impl &) = delete;
+    Impl &operator=(const Impl &) = delete;
+
+    MSHADOW_XINLINE int rand() { return engine_->operator()(); }  // raw mt19937 draw as int
+
+    MSHADOW_XINLINE FType uniform() {
+      typedef typename std::conditional<std::is_integral<DType>::value,
+      std::uniform_int_distribution<DType>,
+      std::uniform_real_distribution<FType>>::type GType;
+      GType dist_uniform;  // default-constructed: [0, 1) real / [0, max] integral
+      return dist_uniform(*engine_);
+    }
+
+    MSHADOW_XINLINE FType normal() {
+      std::normal_distribution<FType> dist_normal;  // mean 0, stddev 1
+      return dist_normal(*engine_);
+    }
+
+   private:
+    std::mt19937 *engine_;
+  };
+
+  static void AllocState(RandGenerator<cpu, DType> *inst) {
+    inst->states_ = new std::mt19937[kNumRandomStates];  // released by FreeState
+  }
+
+  static void FreeState(RandGenerator<cpu, DType> *inst) {
+    delete[] inst->states_;
+  }
+
+  MSHADOW_XINLINE void Seed(Stream<cpu> *, uint32_t seed) {  // states_[i] <- seed + i
+    for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
+  }
+
+ private:
+  std::mt19937 *states_;
+};
+
+#if MXNET_USE_CUDA
+
+template<typename DType>
+class RandGenerator<gpu, DType> {
+ public:
+  // at least how many random numbers should be generated by one GPU thread.
+  static const int kMinNumRandomPerThread = 64;
+  // store how many global random states for GPU.
+  static const int kNumRandomStates = 32768;
+
+  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
+  // by using 1.0-curand_uniform(), since curand_uniform() returns values in (0, 1].
+  // Needed as some samplers in sampler.h won't be able to deal with
+  // one of the boundary cases.
+  class Impl {
+   public:
+    Impl &operator=(const Impl &) = delete;
+    Impl(const Impl &) = delete;
+
+    // Copy state to local memory for efficiency.
+    __device__ explicit Impl(RandGenerator<gpu, DType> *gen, int state_idx)
+        : global_gen_(gen),
+          global_state_idx_(state_idx),
+          state_(*(gen->states_ + state_idx)) {}
+
+    __device__ ~Impl() {
+      // store the curand state back into global memory
+      global_gen_->states_[global_state_idx_] = state_;
+    }
+
+    MSHADOW_FORCE_INLINE __device__ int rand() {
+      return curand(&state_);  // 32-bit unsigned draw, converted to int
+    }
+
+    MSHADOW_FORCE_INLINE __device__ float uniform() {
+      return static_cast<float>(1.0) - curand_uniform(&state_);
+    }
+
+    MSHADOW_FORCE_INLINE __device__ float normal() {
+      return curand_normal(&state_);  // standard normal
+    }
+
+   private:
+    RandGenerator<gpu, DType> *global_gen_;  // destructor writes state_ back here
+    int global_state_idx_;
+    curandStatePhilox4_32_10_t state_;
+  };  // class RandGenerator<gpu, DType>::Impl
+
+  static void AllocState(RandGenerator<gpu, DType> *inst) {  // cudaMalloc'd states_
+    CUDA_CALL(cudaMalloc(&inst->states_,
+                         kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
+  }
+
+  static void FreeState(RandGenerator<gpu, DType> *inst) {
+    CUDA_CALL(cudaFree(inst->states_));
+  }
+
+  void Seed(Stream<gpu> *s, uint32_t seed);  // defined only for float, in random_generator.cu
+
+ private:
+  curandStatePhilox4_32_10_t *states_;
+};
+
+template<>
+class RandGenerator<gpu, double> {  // NOTE(review): lacks AllocState/FreeState/Seed
+ public:
+  // at least how many random numbers should be generated by one GPU thread.
+  static const int kMinNumRandomPerThread = 64;
+  // store how many global random states for GPU.
+  static const int kNumRandomStates = 32768;
+
+  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
+  // by using 1.0-curand_uniform_double(), since it returns values in (0, 1].
+  // Needed as some samplers in sampler.h won't be able to deal with
+  // one of the boundary cases.
+  class Impl {
+   public:
+    Impl &operator=(const Impl &) = delete;
+    Impl(const Impl &) = delete;
+
+    // Copy state to local memory for efficiency.
+    __device__ explicit Impl(RandGenerator<gpu, double> *gen, int state_idx)
+        : global_gen_(gen),
+          global_state_idx_(state_idx),
+          state_(*(gen->states_ + state_idx)) {}
+
+    __device__ ~Impl() {
+      // store the curand state back into global memory
+      global_gen_->states_[global_state_idx_] = state_;
+    }
+
+    MSHADOW_FORCE_INLINE __device__ int rand() {
+      return curand(&state_);
+    }
+
+    MSHADOW_FORCE_INLINE __device__ double uniform() {
+      return 1.0 - curand_uniform_double(&state_);
+    }
+
+    MSHADOW_FORCE_INLINE __device__ double normal() {
+      return curand_normal_double(&state_);  // standard normal
+    }
+
+   private:
+    RandGenerator<gpu, double> *global_gen_;
+    int global_state_idx_;
+    curandStatePhilox4_32_10_t state_;
+  };  // class RandGenerator<gpu, double>::Impl
+
+ private:
+  curandStatePhilox4_32_10_t *states_;
+};
+
+#endif  // MXNET_USE_CUDA
+
+}  // namespace random
+}  // namespace common
+}  // namespace mxnet
+#endif  // MXNET_COMMON_RANDOM_GENERATOR_H_
diff --git a/src/common/utils.h b/src/common/utils.h
index 038ab2a..ede218b 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -49,7 +49,6 @@
 namespace mxnet {
 namespace common {
 
-
 /*!
  * \brief IndPtr should be non-negative, in non-decreasing order, start with 0
  *           and end with value equal with size of indices.
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 18feec7..9a7ed09 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -61,6 +61,8 @@ Graph AttachOpResources(Graph g) {
           }
         } else if (req.type == ResourceRequest::kRandom) {
           requested.push_back(ResourceManager::Get()->Request(ctx, req));
+        } else if (req.type == ResourceRequest::kParallelRandom) {
+          requested.push_back(ResourceManager::Get()->Request(ctx, req));
         } else {
           LOG(FATAL) << "resource type not yet supported";
         }
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index e265cce..8be1eb4 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -203,6 +203,10 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
         requested.push_back(ResourceManager::Get()->Request(ctx, req));
         write_vars.push_back(requested.back().var);
         break;
+       case ResourceRequest::kParallelRandom:
+        requested.push_back(ResourceManager::Get()->Request(ctx, req));
+        write_vars.push_back(requested.back().var);
+        break;
        default:
         LOG(FATAL) << "resource type not yet supported";
       }
diff --git a/src/operator/random/multisample_op.cc b/src/operator/random/multisample_op.cc
index 5f2af61..a88db09 100644
--- a/src/operator/random/multisample_op.cc
+++ b/src/operator/random/multisample_op.cc
@@ -47,7 +47,8 @@ DMLC_REGISTER_PARAMETER(MultiSampleParam);
   .set_attr<nnvm::FInferShape>("FInferShape", MultiSampleOpShape) \
   .set_attr<nnvm::FInferType>("FInferType", MultiSampleOpType) \
   .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { \
-      return std::vector<ResourceRequest>{ResourceRequest::kRandom, ResourceRequest::kTempSpace}; \
+      return std::vector<ResourceRequest>{ResourceRequest::kParallelRandom, \
+                                          ResourceRequest::kTempSpace}; \
     }) \
   .set_attr<FCompute>("FCompute<cpu>", MultiSampleOpForward<cpu, sampler, num_inputs>) \
   .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
diff --git a/src/operator/random/multisample_op.h b/src/operator/random/multisample_op.h
index 38ccbb6..e93e453 100644
--- a/src/operator/random/multisample_op.h
+++ b/src/operator/random/multisample_op.h
@@ -135,6 +135,8 @@ inline bool MultiSampleOpType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+using namespace mxnet::common::random;
+
 template<typename xpu, typename IType, typename OType, typename Sampler, int inum>
 struct SamplerCaller;
 
@@ -142,12 +144,12 @@ template<typename xpu, typename IType, typename OType, typename Sampler>
 struct SamplerCaller<xpu, IType, OType, Sampler, 1> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const Tensor<xpu, 1, unsigned int>& seeds,
-                       mshadow::Stream<xpu> *s) {
+                 RandGenerator<xpu, OType> *pgen,
+                 mshadow::Stream<xpu> *s) {
     Sampler sampler;
     sampler.Sample(inputs[0].FlatTo1D<xpu, IType>(s),
                    outputs[0].FlatTo1D<xpu, OType>(s),
-                   seeds, s);
+                   pgen, s);
   }
 };
 
@@ -155,13 +157,13 @@ template<typename xpu, typename IType, typename OType, typename Sampler>
 struct SamplerCaller<xpu, IType, OType, Sampler, 2> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const Tensor<xpu, 1, unsigned int>& seeds,
-                       mshadow::Stream<xpu> *s) {
+                 RandGenerator<xpu, OType> *pgen,
+                 mshadow::Stream<xpu> *s) {
     Sampler sampler;
     sampler.Sample(inputs[0].FlatTo1D<xpu, IType>(s),
                    inputs[1].FlatTo1D<xpu, IType>(s),
                    outputs[0].FlatTo1D<xpu, OType>(s),
-                   seeds, s);
+                   pgen, s);
   }
 };
 
@@ -177,15 +179,10 @@ void MultiSampleOpForward(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 1);
   CHECK_GT(inputs[0].Size(), 0);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  // Generate multiple seeds for the different threads.
-  const int nSeeds(OptSampleSeedNum<xpu>(outputs[0].Size()));
-  Tensor<xpu, 1, unsigned> seeds
-    = ctx.requested[1].get_space_typed<xpu, 1, unsigned> (Shape1(nSeeds), ctx.get_stream<xpu>());
-  ctx.requested[0].get_random<xpu, float>(s)->GetRandInt(seeds);
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-        SamplerCaller<xpu, IType, OType, Sampler, inum>
-            ::op(inputs, outputs, seeds, s);
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
+      SamplerCaller<xpu, IType, OType, Sampler, inum>::op(inputs, outputs, pgen, s);
     });
   });
 }
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index 9fdff03..a81b41a 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -241,31 +241,27 @@ using FSampleCompute = std::function<void (const nnvm::NodeAttrs& attrs,
                                            TBlob* outputs)>;
 
 using mxnet::TBlob;
+using namespace mxnet::common::random;
 
 // Allocates a single chunk of workspace memory and partitions it into three
 // workspace tensors that hold the seeds as well as the distribution parameters.
 template<typename xpu, typename DType>
-MSHADOW_FORCE_INLINE void GetSamplingTempData(index_t N, DType p1, DType p2, const OpContext& ctx,
-                                              Tensor<xpu, 1, unsigned int>* seeds,
+MSHADOW_FORCE_INLINE void GetSamplingTempData(DType p1, DType p2, const OpContext& ctx,
                                               Tensor<xpu, 1, DType>* parm1,
                                               Tensor<xpu, 1, DType>* parm2) {
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  const index_t nSeeds(OptSampleSeedNum<xpu>(N));
   // Combined memory requirement for the workspace data.
-  const index_t nInt(nSeeds + (2 * sizeof(DType) + sizeof(unsigned) - 1) / sizeof(unsigned));
+  const index_t nInt((2 * sizeof(DType) + sizeof(unsigned) - 1) / sizeof(unsigned));
   Tensor<xpu, 1, unsigned> wspace
     = ctx.requested[1].get_space_typed<xpu, 1, unsigned>(Shape1(nInt), s);
-  // Partition workspace into three chunks and initialize them.
-  *seeds = Tensor<xpu, 1, unsigned>(wspace.dptr_, Shape1(nSeeds), s);
-  ctx.requested[0].get_random<xpu, float>(s)->GetRandInt(*seeds);
-  DType *pspace = static_cast<DType*>(static_cast<void*>(wspace.dptr_+nSeeds));
+  // Place the two distribution parameters p1 and p2 side by side in the workspace.
+  DType *pspace = static_cast<DType*>(static_cast<void*>(wspace.dptr_));
   *parm1 = Tensor<xpu, 1, DType>(pspace, Shape1(1), s);
   Copy(*parm1, Tensor<cpu, 1, DType>(&p1, Shape1(1)), s);
   *parm2 = Tensor<xpu, 1, DType>(pspace+1, Shape1(1), s);
   Copy(*parm2, Tensor<cpu, 1, DType>(&p2, Shape1(1)), s);
 }
 
-
 template<typename xpu, typename Sampler>
 struct SampleMaster;
 
@@ -278,14 +274,14 @@ struct SampleMaster<xpu, UniformSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
     CHECK_GE(param.high, param.low) << "low must be less or equal to high in uniform distribution";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> low, high;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.low, param.high, ctx,
-                                    &seeds, &low, &high);
+    GetSamplingTempData<xpu, float>(param.low, param.high, ctx,
+                                    &low, &high);
     UniformSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(low, high, out, seeds, s);
+      sampler.Sample(low, high, out, pgen, s);
     });
   }
 };
@@ -299,14 +295,13 @@ struct SampleMaster<xpu, NormalSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleNormalParam& param = nnvm::get<SampleNormalParam>(attrs.parsed);
     CHECK_GT(param.scale, 0) << "scale parameter in gaussian has to be positive";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> loc, scale;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.loc, param.scale, ctx,
-                                    &seeds, &loc, &scale);
+    GetSamplingTempData<xpu, float>(param.loc, param.scale, ctx, &loc, &scale);
     NormalSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(loc, scale, out, seeds, s);
+      sampler.Sample(loc, scale, out, pgen, s);
     });
   }
 };
@@ -321,14 +316,13 @@ struct SampleMaster<xpu, GammaSampler<xpu>> {
     const SampleGammaParam& param = nnvm::get<SampleGammaParam>(attrs.parsed);
     CHECK_GT(param.alpha, 0) << "alpha parameter in gamma distribution has to be positive";
     CHECK_GT(param.beta, 0) << "beta parameter in gamma distribution has to be positive";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> alpha, beta;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.alpha, param.beta, ctx,
-                                    &seeds, &alpha, &beta);
+    GetSamplingTempData<xpu, float>(param.alpha, param.beta, ctx, &alpha, &beta);
     GammaSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(alpha, beta, out, seeds, s);
+      sampler.Sample(alpha, beta, out, pgen, s);
     });
   }
 };
@@ -342,13 +336,13 @@ struct SampleMaster<xpu, ExponentialSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleExponentialParam& param = nnvm::get<SampleExponentialParam>(attrs.parsed);
     CHECK_GT(param.lam, 0) << "lambda parameter in exponential distribution has to be positive";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> lam, dummy;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.lam, 0, ctx, &seeds, &lam, &dummy);
+    GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
     ExponentialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(lam, out, seeds, s);
+      sampler.Sample(lam, out, pgen, s);
     });
   }
 };
@@ -362,13 +356,13 @@ struct SampleMaster<xpu, PoissonSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SamplePoissonParam& param = nnvm::get<SamplePoissonParam>(attrs.parsed);
     CHECK_GE(param.lam, 0) << "lambda parameter in poisson distribution has to be non-negative";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> lam, dummy;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.lam, 0, ctx, &seeds, &lam, &dummy);
+    GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
     PoissonSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(lam, out, seeds, s);
+      sampler.Sample(lam, out, pgen, s);
     });
   }
 };
@@ -383,13 +377,13 @@ struct SampleMaster<xpu, NegativeBinomialSampler<xpu>> {
     const SampleNegBinomialParam& param = nnvm::get<SampleNegBinomialParam>(attrs.parsed);
     CHECK_GE(param.k, 0) << "k parameter in negative binomial distribution has to be non-negative";
     CHECK_GE(param.p, 0) << "p parameter in negative binomial distribution has to be non-negative";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> k, p;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.k, param.p, ctx, &seeds, &k, &p);
+    GetSamplingTempData<xpu, float>(param.k, param.p, ctx, &k, &p);
     NegativeBinomialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(k, p, out, seeds, s);
+      sampler.Sample(k, p, out, pgen, s);
     });
   }
 };
@@ -406,14 +400,13 @@ struct SampleMaster<xpu, GeneralizedNegativeBinomialSampler<xpu>> {
       << "mu parameter in generalized negative binomial distribution has to be non-negative";
     CHECK_GE(param.alpha, 0)
       << "alpha parameter in generalized negative binomial distribution has to be non-negative";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> mu, alpha;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.mu, param.alpha, ctx,
-                                    &seeds, &mu, &alpha);
+    GetSamplingTempData<xpu, float>(param.mu, param.alpha, ctx, &mu, &alpha);
     GeneralizedNegativeBinomialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(mu, alpha, out, seeds, s);
+      sampler.Sample(mu, alpha, out, pgen, s);
     });
   }
 };
@@ -502,7 +495,7 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs,
 }
 
 inline std::vector<ResourceRequest> SampleResource(const NodeAttrs& attrs) {
-  return { ResourceRequest::kRandom, ResourceRequest::kTempSpace };
+  return { ResourceRequest::kParallelRandom, ResourceRequest::kTempSpace };
 }
 
 }  // namespace op
diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h
index d544aec..8eace1e 100644
--- a/src/operator/random/sampler.h
+++ b/src/operator/random/sampler.h
@@ -25,89 +25,52 @@
 #ifndef MXNET_OPERATOR_RANDOM_SAMPLER_H_
 #define MXNET_OPERATOR_RANDOM_SAMPLER_H_
 
-#ifdef __CUDACC__
-#include <curand.h>
-#include <curand_kernel.h>
-#endif  // __CUDACC__
+#include <algorithm>
 
 using namespace mshadow;
 using namespace mxnet::op::mxnet_op;
+using namespace mxnet::common::random;
 
 namespace mxnet {
 namespace op {
 
-// Elementary random number generation for int/uniform/gaussian in CPU and GPU.
-// Will use float data type whenever instantiated for half_t or any other non
-// standard real type.
-template<typename xpu, typename DType>
-class RandGenerator;
-
-template<typename DType>
-class RandGenerator<cpu, DType> {
- public:
-  typedef typename std::conditional<std::is_floating_point<DType>::value,
-                                    DType, float>::type FType;
-  std::mt19937 engine;
-  std::uniform_real_distribution<FType> uniformNum;
-  std::normal_distribution<FType> normalNum;
-  explicit RandGenerator(unsigned int seed): engine(seed) {}
-  MSHADOW_XINLINE int rand() { return engine(); }
-  MSHADOW_XINLINE FType uniform() { return uniformNum(engine); }
-  MSHADOW_XINLINE FType normal() { return normalNum(engine); }
-};
-
-#ifdef __CUDACC__
-
-// uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
-// by using 1.0-curand_uniform(). Needed as some samplers below won't be able to deal with
-// one of the boundary cases.
-template<typename DType>
-class RandGenerator<gpu, DType> {
- public:
-  curandState_t state;
-  __device__ RandGenerator(unsigned int seed) { curand_init(seed, 0, 0, &state); }
-  MSHADOW_FORCE_INLINE __device__ int rand() { return curand(&state); }
-  MSHADOW_FORCE_INLINE __device__ float uniform()
-                              { return static_cast<float>(1.0) - curand_uniform(&state); }
-  MSHADOW_FORCE_INLINE __device__ float normal() { return curand_normal(&state); }
-};
-
-template<>
-class RandGenerator<gpu, double> {
- public:
-  curandState_t state;
-  __device__ RandGenerator(unsigned int seed) { curand_init(seed, 0, 0, &state); }
-  MSHADOW_FORCE_INLINE __device__ int rand() { return curand(&state); }
-  MSHADOW_FORCE_INLINE __device__ double uniform()
-                            { return static_cast<double>(1.0) - curand_uniform_double(&state); }
-  MSHADOW_FORCE_INLINE __device__ double normal() { return curand_normal_double(&state); }
-};
-
-#endif  // __CUDACC__
-
-// Number of seeds/threads when sampling on cpu/gpu.
-template<typename xpu>
-MSHADOW_XINLINE index_t OptSampleSeedNum(index_t N);
-template<>
-MSHADOW_XINLINE index_t OptSampleSeedNum<cpu>(index_t N) {
-  return omp_get_num_threads();
-}
-template<>
-MSHADOW_XINLINE index_t OptSampleSeedNum<gpu>(index_t N) {
-  return N;
+/*!
+ * \brief Launch OP::Map() with a parallel random generator, splitting N items across threads.
+ * \param gen parallel random generator, passed by value to every launched thread
+ * \param N Number of items to generate; nothing is launched when N <= 0
+ * \tparam Args Varargs type to eventually pass to the OP::Map() function
+ */
+template<typename OP, typename xpu, typename GType, typename ...Args>
+inline static void LaunchRNG(mshadow::Stream<xpu> *s,
+                             common::random::RandGenerator<xpu, GType> *gen,
+                             const int N, Args... args) {
+  if (N <= 0) return;  // N == 0 would give nthread == 0 and a division by zero below
+  const int nloop = (N + RandGenerator<xpu, GType>::kMinNumRandomPerThread - 1) /
+                    RandGenerator<xpu, GType>::kMinNumRandomPerThread;
+  const int nthread = std::min(nloop, RandGenerator<xpu, GType>::kNumRandomStates);
+  Kernel<OP, xpu>::Launch(s, nthread, *gen, N, (N + nthread - 1) / nthread, args...);
 }
 
+#define RNG_KERNEL_LOOP(xpu, GType, thread_id, gen, N, step, ...)        \
+  const int start = thread_id * step;                                    \
+  const int end = start + step;                                          \
+  typename RandGenerator<xpu, GType>::Impl genImpl(&gen, thread_id);     \
+  for (int i = start; i < end && i < N; ++i) {                           \
+    {__VA_ARGS__}                                                        \
+  }
+
 template<typename xpu>
 struct SampleUniformKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                     const IType *lower, const IType *upper, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, OType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(lower[j/nBatch] + (upper[j/nBatch] - lower[j/nBatch]) * gen.uniform());
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *lower, const IType *upper, OType *out) {
+    RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      out[i] = OType(lower[i / nBatch] +
+                     (upper[i / nBatch] - lower[i / nBatch]) * genImpl.uniform());
+    });
   }
 };
 
@@ -117,25 +80,24 @@ struct UniformSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& lower,
                                    const Tensor<xpu, 1, IType>& upper,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleUniformKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), lower.size(0), out.size(0), seed.size(0),
-               lower.dptr_, upper.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    LaunchRNG<SampleUniformKernel<xpu>, xpu>(s, pgen, out.size(0), lower.size(0), out.size(0),
+                                             lower.dptr_, upper.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleNormalKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                            const IType *mean, const IType *std, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, OType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(gen.normal() * std[j/nBatch] + mean[j/nBatch]);
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *mean, const IType *std, OType *out) {
+    RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      out[i] = OType(genImpl.normal() * std[i / nBatch] + mean[i / nBatch]);
+    });
   }
 };
 
@@ -145,25 +107,24 @@ struct NormalSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& mean,
                                    const Tensor<xpu, 1, IType>& std,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleNormalKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), mean.size(0), out.size(0), seed.size(0),
-               mean.dptr_, std.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    LaunchRNG<SampleNormalKernel<xpu>, xpu>(s, pgen, out.size(0), mean.size(0), out.size(0),
+                                            mean.dptr_, std.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleExponentialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                                  const IType *lambda, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, OType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(-log(1.0-gen.uniform()) / lambda[j/nBatch]);
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *lambda, OType *out) {
+    RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      out[i] = OType(-log(1.0 - genImpl.uniform()) / lambda[i / nBatch]);
+    });
   }
 };
 
@@ -172,16 +133,16 @@ struct ExponentialSampler {
   template<typename IType, typename OType>
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& lambda,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleExponentialKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), lambda.size(0), out.size(0), seed.size(0),
-               lambda.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    LaunchRNG<SampleExponentialKernel<xpu>, xpu>(s, pgen, out.size(0),
+                                                 lambda.size(0), out.size(0),
+                                                 lambda.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu, typename IType, typename OType>
-MSHADOW_XINLINE OType SampleGamma(IType a, IType b, RandGenerator<xpu, OType> *gen) {
+MSHADOW_XINLINE OType SampleGamma(IType a, IType b, typename RandGenerator<xpu, OType>::Impl *gen) {
   // Generate one sample of the gamma distribution
   OType sample;
   OType d = a < 1 ? a + 2.0 / 3.0 : a - 1.0 / 3.0;
@@ -203,17 +164,16 @@ MSHADOW_XINLINE OType SampleGamma(IType a, IType b, RandGenerator<xpu, OType> *g
 
 template<typename xpu>
 struct SampleGammaKernel {
-  template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                      const IType *alpha, const IType *beta, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    typedef typename std::conditional<std::is_floating_point<OType>::value,
-                                     OType, float>::type FType;
-    RandGenerator<xpu, FType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(SampleGamma(alpha[j/nBatch], beta[j/nBatch], &gen));
-    }
+  template<typename IType, typename OType, typename FType>
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, FType> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *alpha, const IType *beta, OType *out) {
+    RNG_KERNEL_LOOP(xpu, FType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      out[i] = OType(SampleGamma<xpu, IType, FType>(alpha[i / nBatch],
+                                                    beta[i / nBatch], &genImpl));
+    });
   }
 };
 
@@ -223,16 +183,18 @@ struct GammaSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& alpha,
                                    const Tensor<xpu, 1, IType>& beta,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleGammaKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), alpha.size(0), out.size(0), seed.size(0),
-               alpha.dptr_, beta.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    typedef typename std::conditional<std::is_floating_point<OType>::value,
+                                      OType, float>::type FType;
+    RandGenerator<xpu, FType> *gen = reinterpret_cast<RandGenerator<xpu, FType> *>(pgen);
+    LaunchRNG<SampleGammaKernel<xpu>, xpu>(s, gen, out.size(0), alpha.size(0), out.size(0),
+                                           alpha.dptr_, beta.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
-MSHADOW_XINLINE int SamplePoisson(float lambda, RandGenerator<xpu, float> *gen) {
+MSHADOW_XINLINE int SamplePoisson(float lambda, typename RandGenerator<xpu, float>::Impl *gen) {
   // Generate one sample of the poisson distribution. Intentionally written
   // towards a specific type (float) for internal computation which is sufficient
   // for accurate enough computation.
@@ -265,14 +227,14 @@ MSHADOW_XINLINE int SamplePoisson(float lambda, RandGenerator<xpu, float> *gen)
 template<typename xpu>
 struct SamplePoissonKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                                  const IType *lambda, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, float> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(SamplePoisson(lambda[j/nBatch], &gen));
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *lambda, OType *out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      out[i] = OType(SamplePoisson<xpu>(lambda[i / nBatch], &genImpl));
+    });
   }
 };
 
@@ -281,29 +243,29 @@ struct PoissonSampler {
   template<typename IType, typename OType>
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& lambda,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SamplePoissonKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), lambda.size(0), out.size(0), seed.size(0),
-               lambda.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    RandGenerator<xpu, float> *gen = reinterpret_cast<RandGenerator<xpu, float> *>(pgen);
+    LaunchRNG<SamplePoissonKernel<xpu>, xpu>(s, gen, out.size(0), lambda.size(0), out.size(0),
+                                             lambda.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleNegativeBinomialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                             const IType *k, const IType *p, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, float> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      float alpha = k[j/nBatch];
-      float prob = p[j/nBatch];
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *k, const IType *p, OType *out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      float alpha = k[i / nBatch];
+      float prob = p[i / nBatch];
       float beta = (1.0 - prob) / prob;
-      float lambda = SampleGamma(alpha, beta, &gen);
-      out[j] = OType(SamplePoisson(lambda, &gen));
-    }
+      float lambda = SampleGamma<xpu, IType, float>(alpha, beta, &genImpl);
+      out[i] = OType(SamplePoisson<xpu>(lambda, &genImpl));
+    });
   }
 };
 
@@ -313,27 +275,29 @@ struct NegativeBinomialSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& k,
                                    const Tensor<xpu, 1, IType>& p,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleNegativeBinomialKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), k.size(0), out.size(0), seed.size(0),
-               k.dptr_, p.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    RandGenerator<xpu, float> *gen = reinterpret_cast<RandGenerator<xpu, float> *>(pgen);
+    LaunchRNG<SampleNegativeBinomialKernel<xpu>, xpu>(s, gen, out.size(0), k.size(0), out.size(0),
+                                                      k.dptr_, p.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleGeneralizedNegativeBinomialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                        const IType *mu, const IType *alpha, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, float> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      float lambda = alpha[j/nBatch] == 0 ? static_cast<float>(mu[j/nBatch])
-              : SampleGamma(IType(1) / alpha[j/nBatch], alpha[j/nBatch] * mu[j/nBatch], &gen);
-      out[j] = OType(SamplePoisson(lambda, &gen));
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
+                                  const int N, const int step,
+                                  index_t nParm, index_t nSample,
+                                  const IType *mu, const IType *alpha, OType *out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);
+      float lambda = alpha[i / nBatch] == 0 ?
+                     static_cast<float>(mu[i / nBatch]) :
+                     SampleGamma<xpu, IType, float>(IType(1) / alpha[i / nBatch],
+                                                    alpha[i / nBatch] * mu[i / nBatch], &genImpl);
+      out[i] = OType(SamplePoisson<xpu>(lambda, &genImpl));
+    });
   }
 };
 
@@ -343,11 +307,12 @@ struct GeneralizedNegativeBinomialSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& mu,
                                    const Tensor<xpu, 1, IType>& alpha,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleGeneralizedNegativeBinomialKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), mu.size(0), out.size(0), seed.size(0),
-               mu.dptr_, alpha.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    RandGenerator<xpu, float> *gen = reinterpret_cast<RandGenerator<xpu, float> *>(pgen);
+    LaunchRNG<SampleGeneralizedNegativeBinomialKernel<xpu>, xpu>(s, gen, out.size(0),
+                                                                 mu.size(0), out.size(0),
+                                                                 mu.dptr_, alpha.dptr_, out.dptr_);
   }
 };
 
diff --git a/src/resource.cc b/src/resource.cc
index d1038dc..e195006 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -32,6 +32,8 @@
 #include <limits>
 #include <atomic>
 #include "./common/lazy_alloc_array.h"
+#include "./common/random_generator.h"
+#include "./common/utils.h"
 
 namespace mxnet {
 namespace resource {
@@ -88,20 +90,26 @@ class ResourceManagerImpl : public ResourceManager {
       : global_seed_(0) {
     cpu_temp_space_copy_ = dmlc::GetEnv("MXNET_CPU_TEMP_COPY", 4);
     gpu_temp_space_copy_ = dmlc::GetEnv("MXNET_GPU_TEMP_COPY", 1);
+    cpu_native_rand_copy_ = dmlc::GetEnv("MXNET_CPU_NATIVE_RAND_COPY", 1);
+    gpu_native_rand_copy_ = dmlc::GetEnv("MXNET_GPU_NATIVE_RAND_COPY", 4);
     engine_ref_ = Engine::_GetSharedRef();
     storage_ref_ = Storage::_GetSharedRef();
     cpu_rand_.reset(new ResourceRandom<cpu>(
         Context::CPU(), global_seed_));
     cpu_space_.reset(new ResourceTempSpace(
         Context::CPU(), cpu_temp_space_copy_));
+    cpu_native_rand_.reset(new ResourceNativeRandom<cpu>(
+        Context::CPU(), cpu_native_rand_copy_, global_seed_));
   }
   ~ResourceManagerImpl() {
     // need explicit delete, before engine get killed
     cpu_rand_.reset(nullptr);
     cpu_space_.reset(nullptr);
+    cpu_native_rand_.reset(nullptr);
 #if MXNET_USE_CUDA
     gpu_rand_.Clear();
     gpu_space_.Clear();
+    gpu_native_rand_.Clear();
 #endif
     if (engine_ref_ != nullptr) {
       engine_ref_ = nullptr;
@@ -117,6 +125,7 @@ class ResourceManagerImpl : public ResourceManager {
       switch (req.type) {
         case ResourceRequest::kRandom: return cpu_rand_->resource;
         case ResourceRequest::kTempSpace: return cpu_space_->GetNext();
+        case ResourceRequest::kParallelRandom: return cpu_native_rand_->GetNext();
         default: LOG(FATAL) << "Unknown supported type " << req.type;
       }
     } else {
@@ -133,6 +142,11 @@ class ResourceManagerImpl : public ResourceManager {
               return new ResourceTempSpace(ctx, gpu_temp_space_copy_);
             })->GetNext();
         }
+        case ResourceRequest::kParallelRandom: {
+          return gpu_native_rand_.Get(ctx.dev_id, [ctx, this]() {
+            return new ResourceNativeRandom<gpu>(ctx, gpu_native_rand_copy_, global_seed_);
+          })->GetNext();
+        }
         default: LOG(FATAL) << "Unknown supported type " << req.type;
       }
 #else
@@ -146,10 +160,14 @@ class ResourceManagerImpl : public ResourceManager {
   void SeedRandom(uint32_t seed) override {
     global_seed_ = seed;
     cpu_rand_->Seed(global_seed_);
+    cpu_native_rand_->Seed(global_seed_);
 #if MXNET_USE_CUDA
     gpu_rand_.ForEach([seed](size_t i, ResourceRandom<gpu> *p) {
         p->Seed(seed);
       });
+    gpu_native_rand_.ForEach([seed](size_t i, ResourceNativeRandom<gpu> *p) {
+      p->Seed(seed);
+    });
 #endif
   }
 
@@ -205,7 +223,7 @@ class ResourceManagerImpl : public ResourceManager {
     std::vector<SpaceAllocator> space;
     /*! \brief resource representation */
     std::vector<Resource> resource;
-    /*! \brief current pointer to the round roubin alloator */
+    /*! \brief current pointer to the round robin allocator */
     std::atomic<size_t> curr_ptr;
     /*! \brief constructor */
     explicit ResourceTempSpace(Context ctx, size_t ncopy)
@@ -241,10 +259,82 @@ class ResourceManagerImpl : public ResourceManager {
       return resource[ptr % space.size()];
     }
   };
+
+  // Parallel ("native") random sampler resources: a pool of RandGenerator<xpu> states handed out round-robin.
+  template<typename xpu>
+  struct ResourceNativeRandom {
+    /*! \brief context the samplers belong to; seeding and freeing ops are pushed on it */
+    Context ctx;
+    /*! \brief owned generator states, one per copy (allocated in ctor, freed in dtor) */
+    std::vector<common::random::RandGenerator<xpu> *> sampler;
+    /*! \brief Resource handles wrapping each sampler, each with its own engine variable */
+    std::vector<Resource> resource;
+    /*! \brief round-robin cursor used by GetNext() */
+    std::atomic<size_t> curr_ptr;
+    /*! \brief allocate ncopy samplers and schedule seeding of each one on the engine */
+    explicit ResourceNativeRandom(Context ctx, size_t ncopy, uint32_t global_seed)
+        : ctx(ctx), sampler(ncopy), resource(ncopy), curr_ptr(0) {
+      for (size_t i = 0; i < sampler.size(); ++i) {
+        const uint32_t seed = ctx.dev_id + i * kMaxNumGPUs + global_seed * kRandMagic;  // device/copy offsets are intended to give each sampler its own seed
+        resource[i].var = Engine::Get()->NewVariable();
+        common::random::RandGenerator<xpu> *r = new common::random::RandGenerator<xpu>();  // owned; released in the destructor
+        common::random::RandGenerator<xpu>::AllocState(r);
+        Engine::Get()->PushSync(
+        [r, seed](RunContext rctx) {
+          r->Seed(rctx.get_stream<xpu>(), seed);
+        }, ctx, {}, {resource[i].var},
+        FnProperty::kNormal, 0, PROFILER_MESSAGE("ResourceNativeRandomSetSeed"));
+        sampler[i] = r;
+        resource[i].ptr_ = sampler[i];
+        resource[i].req = ResourceRequest(ResourceRequest::kParallelRandom);
+      }
+    }
+    ~ResourceNativeRandom() {  // state release is deferred to the engine via DeleteVariable
+      for (size_t i = 0; i < sampler.size(); ++i) {
+        common::random::RandGenerator<xpu> *r = sampler[i];
+        Engine::Get()->DeleteVariable(
+        [r](RunContext rctx) {
+          MSHADOW_CATCH_ERROR(common::random::RandGenerator<xpu>::FreeState(r));
+          MSHADOW_CATCH_ERROR(delete r);
+        }, ctx, resource[i].var);
+      }
+    }
+    // re-seed every sampler from global_seed and reset the round-robin cursor
+    inline void Seed(uint32_t global_seed) {
+      for (size_t i = 0; i < sampler.size(); ++i) {
+        const uint32_t seed = ctx.dev_id + i * kMaxNumGPUs + global_seed * kRandMagic;  // same formula as the constructor
+        common::random::RandGenerator<xpu> *r = sampler[i];
+        Engine::Get()->PushAsync(
+        [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          r->Seed(rctx.get_stream<xpu>(), seed);
+          on_complete();
+        }, ctx, {}, {resource[i].var},
+        FnProperty::kNormal, 0, PROFILER_MESSAGE("ResourceNativeRandomSetSeed"));
+      }
+      // reset the cursor so the same seed yields the same sequence of sampler picks.
+      curr_ptr.store(0);
+    }
+    // get the next resource in round-robin order
+    inline Resource GetNext() {
+      const size_t kMaxDigit = std::numeric_limits<size_t>::max() / 2;
+      size_t ptr = ++curr_ptr;
+      // pull the cursor back to a small value once it passes SIZE_MAX / 2, well
+      // before size_t would wrap around; in practice this branch is never taken
+      if (ptr > kMaxDigit) {
+        curr_ptr.store((ptr + 1) % sampler.size());
+      }
+      return resource[ptr % sampler.size()];
+    }
+  };
+
   /*! \brief number of copies in CPU temp space */
   int cpu_temp_space_copy_;
   /*! \brief number of copies in GPU temp space */
   int gpu_temp_space_copy_;
+  /*! \brief number of copies in CPU native random sampler */
+  int cpu_native_rand_copy_;
+  /*! \brief number of copies in GPU native random sampler */
+  int gpu_native_rand_copy_;
   /*! \brief Reference to the engine */
   std::shared_ptr<Engine> engine_ref_;
   /*! \brief Reference to the storage */
@@ -255,11 +345,15 @@ class ResourceManagerImpl : public ResourceManager {
   std::unique_ptr<ResourceRandom<cpu> > cpu_rand_;
   /*! \brief CPU temp space resources */
   std::unique_ptr<ResourceTempSpace> cpu_space_;
+  /*! \brief CPU native random number resources */
+  std::unique_ptr<ResourceNativeRandom<cpu> > cpu_native_rand_;
 #if MXNET_USE_CUDA
   /*! \brief random number generator for GPU */
   common::LazyAllocArray<ResourceRandom<gpu> > gpu_rand_;
   /*! \brief temp space for GPU */
   common::LazyAllocArray<ResourceTempSpace> gpu_space_;
+  /*! \brief GPU native (on device) random number resources */
+  common::LazyAllocArray<ResourceNativeRandom<gpu> > gpu_native_rand_;
 #endif
 };
 }  // namespace resource
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 08302b8..db5508d 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -613,7 +613,7 @@ def test_factorization_machine_module(verbose=False):
             expected_accuracy = 0.02
         elif optimizer == 'adam':
             # use Sparse Adam to train
-            adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.001,
+            adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.0005,
                                      rescale_grad=1.0/batch_size)
             mod.init_optimizer(optimizer=adam)
             if num_epochs is None:
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index d05e325..0230d5f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -330,8 +330,8 @@ def check_softmax_with_shape(shape, xpu, preserve_shape=False):
     X = mx.symbol.Variable('X')
     L = mx.symbol.Variable('L')
     Y = mx.symbol.SoftmaxOutput(data=X, label=L, preserve_shape=preserve_shape)
-    x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
-    l = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
+    x = mx.random.uniform(-1, 1, shape, ctx=xpu)
+    l = mx.random.uniform(-1, 1, shape, ctx=xpu)
     l[:] = np_softmax(l.asnumpy())
     grad = mx.nd.empty(shape, ctx = xpu)
     exec1 = Y.bind(xpu, args = [x, l], args_grad = {'X': grad})
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 6178cbe..1aa2e22 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -666,7 +666,7 @@ def test_nadam():
     loss = Loss(output, l)
     loss = mx.sym.make_loss(loss)
     mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
-    mod.fit(data_iter, num_epoch=30, optimizer_params={'learning_rate': 0.005, 'wd': 0.0005},
+    mod.fit(data_iter, num_epoch=60, optimizer_params={'learning_rate': 0.0005, 'wd': 0.0005},
             initializer=mx.init.Xavier(magnitude=2), eval_metric=mx.metric.Loss(),
             optimizer='nadam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index a67e2d1..0efe8e6 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -17,12 +17,16 @@
 
 import os
 import mxnet as mx
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
 import numpy as np
+import scipy.stats as ss
 
 def same(a, b):
     return np.sum(a != b) == 0
 
 def check_with_device(device, dtype):
+    # The thresholds chosen for the tests are too loose. We will rely on the other tests to test the samples from the
+    #  generators.
     tol = 0.1
     symbols = [
         {
@@ -216,6 +220,124 @@ def test_sample_multinomial():
             real_dx[y[i][j]] += 5.0 / rprob[j]
         mx.test_utils.assert_almost_equal(real_dx, dx.asnumpy()[i])
 
+# Test the generators with chi-square goodness-of-fit tests
+def test_normal_generator():  # chi-square goodness-of-fit check of mx.nd.random.normal
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for mu, sigma in [(0.0, 1.0), (1.0, 5.0)]:
+            print("ctx=%s, dtype=%s, Mu=%g, Sigma=%g:" % (ctx, dtype, mu, sigma))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, mu, sigma), 5)  # reference buckets from scipy's inverse CDF
+            generator_mx = lambda x: mx.nd.random.normal(mu, sigma, shape=x, ctx=ctx, dtype=dtype).asnumpy()  # all samples from one call
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed =\
+                lambda x: np.concatenate(
+                    [mx.nd.random.normal(mu, sigma, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])  # samples from ten consecutive calls without reseeding
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_uniform_generator():  # chi-square goodness-of-fit check of mx.nd.random.uniform
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for low, high in [(-1.0, 1.0), (1.0, 3.0)]:
+            print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=high - low), 5)  # scipy's uniform spans [loc, loc + scale]
+            generator_mx = lambda x: mx.nd.random.uniform(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy()  # all samples from one call
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.uniform(low, high, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])  # samples from ten consecutive calls without reseeding
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_gamma_generator():  # chi-square goodness-of-fit check of mx.nd.random.gamma
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for kappa, theta in [(0.5, 1.0), (1.0, 5.0)]:
+            print("ctx=%s, dtype=%s, Shape=%g, Scale=%g:" % (ctx, dtype, kappa, theta))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.gamma.ppf(x, a=kappa, loc=0, scale=theta), 5)  # scipy gamma: a = shape (kappa), scale = theta
+            generator_mx = lambda x: mx.nd.random.gamma(kappa, theta, shape=x, ctx=ctx, dtype=dtype).asnumpy()  # all samples from one call
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.gamma(kappa, theta, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])  # samples from ten consecutive calls without reseeding
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_exponential_generator():  # chi-square goodness-of-fit check of mx.nd.random.exponential
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for scale in [0.1, 1.0]:
+            print("ctx=%s, dtype=%s, Scale=%g:" % (ctx, dtype, scale))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.expon.ppf(x, loc=0, scale=scale), 5)  # reference: scipy exponential with the same scale
+            generator_mx = lambda x: mx.nd.random.exponential(scale, shape=x, ctx=ctx, dtype=dtype).asnumpy()  # all samples from one call
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.exponential(scale, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])  # samples from ten consecutive calls without reseeding
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_poisson_generator():  # chi-square goodness-of-fit check of mx.nd.random.poisson
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for lam in [1, 10]:
+            print("ctx=%s, dtype=%s, Lambda=%d:" % (ctx, dtype, lam))
+            buckets = [(-1.0, lam - 0.5), (lam - 0.5, 2 * lam + 0.5), (2 * lam + 0.5, np.inf)]  # split integer outcomes at half-integers near lam and 2*lam
+            probs = [ss.poisson.cdf(bucket[1], lam) - ss.poisson.cdf(bucket[0], lam) for bucket in buckets]  # exact bucket mass from the scipy CDF
+            generator_mx = lambda x: mx.nd.random.poisson(lam, shape=x, ctx=ctx, dtype=dtype).asnumpy()  # all samples from one call
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.poisson(lam, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])  # samples from ten consecutive calls without reseeding
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_negative_binomial_generator():
+    ctx = mx.context.current_context()
+    success_num = 2
+    success_prob = 0.2
+    buckets = [(-1.0, 2.5), (2.5, 5.5), (5.5, 8.5), (8.5, np.inf)]
+    probs = [ss.nbinom.cdf(bucket[1], success_num, success_prob) -
+             ss.nbinom.cdf(bucket[0], success_num, success_prob) for bucket in buckets]
+    # Gamma-Poisson mixture with the same law: alpha = 1 / r, mu = r * (1 - p) / p
+    alpha = 1.0 / success_num
+    mu = (1.0 - success_prob) / success_prob / alpha
+    for dtype in ['float16', 'float32', 'float64']:
+        print("ctx=%s, dtype=%s, Success Num=%d, Success Prob=%g" % (ctx, dtype, success_num, success_prob))
+        generator_mx = lambda x: mx.nd.random.negative_binomial(success_num, success_prob,
+                                                                shape=x, ctx=ctx, dtype=dtype).asnumpy()
+        verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+        generator_mx_same_seed = \
+            lambda x: np.concatenate(
+                [mx.nd.random.negative_binomial(success_num, success_prob, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                 for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+        print('Gamma-Poisson Mixture Test:')
+        generator_mx = lambda x: mx.nd.random.generalized_negative_binomial(mu, alpha,
+                                                                            shape=x, ctx=ctx, dtype=dtype).asnumpy()
+        verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+        generator_mx_same_seed = \
+            lambda x: np.concatenate(
+                [mx.nd.random.generalized_negative_binomial(mu, alpha, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                 for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_multinomial_generator():  # chi-square goodness-of-fit check of mx.nd.random.multinomial
+    ctx = mx.context.current_context()
+    probs = [0.1, 0.2, 0.3, 0.05, 0.15, 0.2]
+    buckets = list(range(6))  # one bucket per category index
+    for dtype in ['float16', 'float32', 'float64']:
+        print("ctx=%s, dtype=%s" % (ctx, dtype))
+        prob_nd = mx.nd.array(np.array(probs), ctx=ctx, dtype=dtype)  # built once per dtype, reused by every draw
+        generator_mx = lambda x: mx.nd.random.multinomial(data=prob_nd, shape=x).asnumpy()
+        verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+        # same check, with samples drawn by ten consecutive calls without reseeding
+        generator_mx_same_seed = \
+            lambda x: np.concatenate(
+                [mx.nd.random.multinomial(data=prob_nd, shape=x // 10).asnumpy()
+                 for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
 
 if __name__ == '__main__':
     import nose

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.