You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not included in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/28 20:27:30 UTC

[GitHub] piiswrong closed pull request #9119: fix random generator: do not gen seed each time

piiswrong closed pull request #9119: fix random generator: do not gen seed each time
URL: https://github.com/apache/incubator-mxnet/pull/9119
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it once the pull request is merged:

diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index 9419898135..f1e1e02f54 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -21,7 +21,7 @@
 
 blacklist = [
     'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
-    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h',
+    'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
     'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
     'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
     'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h
index 7d2e6caf85..773baf04c1 100644
--- a/include/mxnet/resource.h
+++ b/include/mxnet/resource.h
@@ -28,6 +28,7 @@
 #include <dmlc/logging.h>
 #include "./base.h"
 #include "./engine.h"
+#include "../../src/common/random_generator.h"
 
 namespace mxnet {
 
@@ -40,7 +41,9 @@ struct ResourceRequest {
     /*! \brief mshadow::Random<xpu> object */
     kRandom,
     /*! \brief A dynamic temp space that can be arbitrary size */
-    kTempSpace
+    kTempSpace,
+    /*! \brief common::RandGenerator<xpu> object, which can be used in GPU kernel functions */
+    kParallelRandom
   };
   /*! \brief type of resources */
   Type type;
@@ -89,6 +92,19 @@ struct Resource {
     ret->set_stream(stream);
     return ret;
   }
+
+  /*!
+   * \brief Get parallel random number generator.
+   * \tparam xpu the device type of random number generator.
+   * \tparam DType the return type.
+   * \return the native random number generator. for gpu, it is allocated on global memory.
+   */
+  template<typename xpu, typename DType>
+  inline common::random::RandGenerator<xpu, DType>* get_parallel_random() const {
+    CHECK_EQ(req.type, ResourceRequest::kParallelRandom);
+    return static_cast<common::random::RandGenerator<xpu, DType>*>(ptr_);
+  }
+
   /*!
    * \brief Get space requested as mshadow Tensor.
    *  The caller can request arbitrary size.
diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h
index d19f98b2f4..a8481c1d36 100644
--- a/include/mxnet/storage.h
+++ b/include/mxnet/storage.h
@@ -82,7 +82,7 @@ class Storage {
   virtual void SharedIncrementRefCount(Handle handle) = 0;
   /*!
    * \brief Free storage.
-   * \param handle Handle struect.
+   * \param handle Handle struct.
    */
   virtual void Free(Handle handle) = 0;
   /*!
diff --git a/perl-package/AI-MXNet/t/test_random.t b/perl-package/AI-MXNet/t/test_random.t
index c95a199f21..60cebcf34d 100644
--- a/perl-package/AI-MXNet/t/test_random.t
+++ b/perl-package/AI-MXNet/t/test_random.t
@@ -87,7 +87,7 @@ sub check_with_device
             ]
         },
     );
-    my $shape = [100, 100];
+    my $shape = [1000, 1000];
     for my $symbdic (@symbols)
     {
         my $name = $symbdic->{name};
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 7e8e7c2937..aebb52e36b 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -648,7 +648,8 @@ def update(self, index, weight, grad, state):
         if self.clip_gradient is not None:
             grad = clip(grad, -self.clip_gradient, self.clip_gradient)
         weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
-                                                            weight.shape, weight.context)
+                                                            shape=weight.shape,
+                                                            ctx=weight.context)
 
 
 @register  # pylint: disable=invalid-name
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 53814b766f..58bc8d38f6 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -34,6 +34,10 @@
 import numpy as np
 import numpy.testing as npt
 import numpy.random as rnd
+try:
+    import scipy.stats as ss
+except ImportError:
+    ss = None
 try:
     import requests
 except ImportError:
@@ -1593,3 +1597,225 @@ def next(self):
             The data of next batch.
         """
         return self.the_batch
+
+def gen_buckets_probs_with_ppf(ppf, nbuckets):
+    """Generate the buckets and probabilities for chi_square test when the ppf (Quantile function)
+     is specified.
+
+    Parameters
+    ----------
+    ppf : function
+        The Quantile function that takes a probability and maps it back to a value.
+        It's the inverse of the cdf function
+    nbuckets : int
+        size of the buckets
+
+    Returns
+    -------
+    buckets : list of tuple
+        The generated buckets
+    probs : list
+        The generate probabilities
+    """
+    assert nbuckets > 0
+    probs = [1.0 / nbuckets for _ in range(nbuckets)]
+    buckets = [(ppf(i / float(nbuckets)), ppf((i + 1) / float(nbuckets))) for i in range(nbuckets)]
+    return buckets, probs
+
+def mean_check(generator, mu, sigma, nsamples=1000000):
+    """Test the generator by matching the mean.
+
+    We test the sample mean by checking if it falls inside the range
+        (mu - 3 * sigma / sqrt(n), mu + 3 * sigma / sqrt(n))
+
+    References::
+
+        @incollection{goucher2009beautiful,
+              title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
+              author={Goucher, Adam and Riley, Tim},
+              year={2009},
+              chapter=10
+        }
+
+    Examples::
+
+        generator = lambda x: np.random.normal(0, 1.0, size=x)
+        mean_check_ret = mean_check(generator, 0, 1.0)
+
+    Parameters
+    ----------
+    generator : function
+        The generator function. It's expected to generate N i.i.d samples by calling generator(N).
+    mu : float
+    sigma : float
+    nsamples : int
+
+    Returns
+    -------
+    ret : bool
+        Whether the mean test succeeds
+    """
+    samples = np.array(generator(nsamples))
+    sample_mean = samples.mean()
+    ret = (sample_mean > mu - 3 * sigma / np.sqrt(nsamples)) and\
+          (sample_mean < mu + 3 * sigma / np.sqrt(nsamples))
+    return ret
+
+def var_check(generator, sigma, nsamples=1000000):
+    """Test the generator by matching the variance.
+    It will need a large number of samples and is not recommended to use
+
+    We test the sample variance by checking if it falls inside the range
+        (sigma^2 - 3 * sqrt(2 * sigma^4 / (n-1)), sigma^2 + 3 * sqrt(2 * sigma^4 / (n-1)))
+
+    References::
+
+        @incollection{goucher2009beautiful,
+              title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
+              author={Goucher, Adam and Riley, Tim},
+              year={2009},
+              chapter=10
+        }
+
+    Examples::
+
+        generator = lambda x: np.random.normal(0, 1.0, size=x)
+        var_check_ret = var_check(generator, 0, 1.0)
+
+    Parameters
+    ----------
+    generator : function
+        The generator function. It's expected to generate N i.i.d samples by calling generator(N).
+    sigma : float
+    nsamples : int
+
+    Returns
+    -------
+    ret : bool
+        Whether the variance test succeeds
+    """
+    samples = np.array(generator(nsamples))
+    sample_var = samples.var(ddof=1)
+    ret = (sample_var > sigma ** 2 - 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1))) and\
+          (sample_var < sigma ** 2 + 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1)))
+    return ret
+
+def chi_square_check(generator, buckets, probs, nsamples=1000000):
+    """Run the chi-square test for the generator. The generator can be both continuous and discrete.
+    If the generator is continuous, the buckets should contain tuples of (range_min, range_max) and
+     the probs should be the corresponding ideal probability within the specific ranges.
+    Otherwise, the buckets should be the possible output of the discrete distribution and the probs
+     should be groud-truth probability.
+
+    Usually the user is required to specify the probs parameter.
+
+    After obtatining the p value, we could further use the standard p > 0.05 threshold to get
+     the final result.
+
+    Examples::
+        buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, 0, 1), 5)
+        generator = lambda x: np.random.normal(0, 1.0, size=x)
+        p = chi_square_check(generator=generator, buckets=buckets, probs=probs)
+        assert(p > 0.05)
+
+    Parameters
+    ----------
+    generator: function
+        A function that is assumed to generate i.i.d samples from a specific distribution.
+        generator(N) should generate N random samples.
+    buckets: list of tuple or list of number
+        The buckets to run the chi-square the test. Make sure that the buckets cover
+         the whole range of the distribution. Also, the buckets must be in ascending order and have
+         no intersection
+    probs: list or tuple
+        The ground-truth probability of the random value fall in a specific bucket.
+    nsamples:int
+        The number of samples to generate for the testing
+
+    Returns
+    -------
+    p : float
+        p value that the generator has the expected distribution.
+        A higher value indicates a larger confidence
+    obs_freq : list
+        Observed frequency of buckets
+    expected_freq : list
+        The expected (ground-truth) frequency of the buckets
+    """
+    if not ss:
+        raise ImportError("scipy is not available."
+                          " Please check if the scipy python bindings are installed.")
+    assert isinstance(buckets, list)
+    samples = generator(nsamples)
+    assert len(probs) == len(buckets)
+    if isinstance(buckets[0], (list, tuple)):
+        # Check whether the buckets are valid and fill them into a npy array
+        continuous_dist = True
+        buckets_npy = np.zeros((len(buckets) * 2, ), dtype=np.float32)
+        for i, _ in enumerate(buckets):
+            assert(buckets[i][0] <= buckets[i][1])
+            if i < len(buckets) - 1:
+                assert(buckets[i][1] <= buckets[i + 1][0])
+            buckets_npy[i * 2] = buckets[i][0]
+            buckets_npy[i * 2 + 1] = buckets[i][1]
+    else:
+        continuous_dist = False
+        buckets_npy = np.array(buckets)
+    expected_freq = (nsamples * np.array(probs, dtype=np.float32)).astype(np.int32)
+    if continuous_dist:
+        sample_bucket_ids = np.searchsorted(buckets_npy, samples, side='right')
+    else:
+        sample_bucket_ids = samples
+    if continuous_dist:
+        sample_bucket_ids = sample_bucket_ids // 2
+    obs_freq = np.zeros(shape=len(buckets), dtype=np.int)
+    for i in range(len(buckets)):
+        obs_freq[i] = (sample_bucket_ids == i).sum()
+    _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq)
+    return p, obs_freq, expected_freq
+
+def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, success_rate=0.25):
+    """Verify whether the generator is correct using chi-square testing.
+
+    The test is repeated for "nrepeat" times and we check if the success rate is
+     above the threshold (25% by default).
+
+    Parameters
+    ----------
+    generator: function
+        A function that is assumed to generate i.i.d samples from a specific distribution.
+            generator(N) should generate N random samples.
+    buckets: list of tuple or list of number
+        The buckets to run the chi-square the test. Make sure that the buckets cover
+         the whole range of the distribution. Also, the buckets must be in ascending order and
+         have no intersection
+    probs: list or tuple
+        The ground-truth probability of the random value fall in a specific bucket.
+    nsamples: int
+        The number of samples to generate for the testing
+    nrepeat: int
+        The times to repeat the test
+    success_rate: float
+        The desired success rate
+
+    Returns
+    -------
+    cs_ret_l: list
+        The p values of the chi-square test.
+    """
+    cs_ret_l = []
+    obs_freq_l = []
+    expected_freq_l = []
+    for _ in range(nrepeat):
+        cs_ret, obs_freq, expected_freq = chi_square_check(generator=generator, buckets=buckets,
+                                                           probs=probs, nsamples=nsamples)
+        cs_ret_l.append(cs_ret)
+        obs_freq_l.append(obs_freq)
+        expected_freq_l.append(expected_freq)
+    success_num = (np.array(cs_ret_l) > 0.05).sum()
+    if success_num < nrepeat * success_rate:
+        raise AssertionError("Generator test fails, Chi-square p=%s, obs_freq=%s, expected_freq=%s."
+                             "\nbuckets=%s, probs=%s"
+                             % (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
+                                str(buckets), str(probs)))
+    return cs_ret_l
diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
new file mode 100644
index 0000000000..5f6ac4469e
--- /dev/null
+++ b/src/common/random_generator.cu
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file random_generator.cu
+ * \brief gpu implements for parallel random number generator.
+ */
+
+#include <algorithm>
+#include "./random_generator.h"
+#include "../operator/mxnet_op.h"
+
+namespace mxnet {
+namespace common {
+namespace random {
+
+__global__ void rand_generator_seed_kernel(curandStatePhilox4_32_10_t *states_,
+                                           const int size,
+                                           uint32_t seed) {
+  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  if (id < size) curand_init(seed, id, 0, states_ + id);
+}
+
+template<>
+void RandGenerator<gpu, float>::Seed(Stream<gpu> *s, uint32_t seed) {
+  using namespace mshadow::cuda;
+  int ngrid = std::min(kMaxGridNum,
+                       (RandGenerator<gpu, float>::kNumRandomStates + kBaseThreadNum - 1) /
+                         kBaseThreadNum);
+  rand_generator_seed_kernel
+      <<<ngrid, kBaseThreadNum, 0, Stream<gpu>::GetStream(s)>>>(
+          states_,
+          RandGenerator<gpu, float>::kNumRandomStates,
+          seed);
+}
+
+}  // namespace random
+}  // namespace common
+}  // namespace mxnet
diff --git a/src/common/random_generator.h b/src/common/random_generator.h
new file mode 100644
index 0000000000..21db9d7a72
--- /dev/null
+++ b/src/common/random_generator.h
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file random_generator.h
+ * \brief Parallel random number generator.
+ */
+#ifndef MXNET_COMMON_RANDOM_GENERATOR_H_
+#define MXNET_COMMON_RANDOM_GENERATOR_H_
+
+#include <mxnet/base.h>
+#include <random>
+#include <new>
+
+#if MXNET_USE_CUDA
+#include <curand_kernel.h>
+#include "../common/cuda_utils.h"
+#endif  // MXNET_USE_CUDA
+
+using namespace mshadow;
+
+namespace mxnet {
+namespace common {
+namespace random {
+
+template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
+class RandGenerator;
+
+template<typename DType>
+class RandGenerator<cpu, DType> {
+ public:
+  // at least how many random numbers should be generated by one CPU thread.
+  static const int kMinNumRandomPerThread = 64;
+  // store how many global random states for CPU.
+  static const int kNumRandomStates = 1024;
+
+  // implementation class for random number generator
+  class Impl {
+   public:
+    typedef typename std::conditional<std::is_floating_point<DType>::value,
+                                      DType, double>::type FType;
+
+    explicit Impl(RandGenerator<cpu, DType> *gen, int state_idx)
+        : engine_(gen->states_ + state_idx) {}
+
+    Impl(const Impl &) = delete;
+    Impl &operator=(const Impl &) = delete;
+
+    MSHADOW_XINLINE int rand() { return engine_->operator()(); }
+
+    MSHADOW_XINLINE FType uniform() {
+      typedef typename std::conditional<std::is_integral<DType>::value,
+      std::uniform_int_distribution<DType>,
+      std::uniform_real_distribution<FType>>::type GType;
+      GType dist_uniform;
+      return dist_uniform(*engine_);
+    }
+
+    MSHADOW_XINLINE FType normal() {
+      std::normal_distribution<FType> dist_normal;
+      return dist_normal(*engine_);
+    }
+
+   private:
+    std::mt19937 *engine_;
+  };
+
+  static void AllocState(RandGenerator<cpu, DType> *inst) {
+    inst->states_ = new std::mt19937[kNumRandomStates];
+  }
+
+  static void FreeState(RandGenerator<cpu, DType> *inst) {
+    delete[] inst->states_;
+  }
+
+  MSHADOW_XINLINE void Seed(Stream<cpu> *, uint32_t seed) {
+    for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
+  }
+
+ private:
+  std::mt19937 *states_;
+};
+
+#if MXNET_USE_CUDA
+
+template<typename DType>
+class RandGenerator<gpu, DType> {
+ public:
+  // at least how many random numbers should be generated by one GPU thread.
+  static const int kMinNumRandomPerThread = 64;
+  // store how many global random states for GPU.
+  static const int kNumRandomStates = 32768;
+
+  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
+  // by using 1.0-curand_uniform().
+  // Needed as some samplers in sampler.h won't be able to deal with
+  // one of the boundary cases.
+  class Impl {
+   public:
+    Impl &operator=(const Impl &) = delete;
+    Impl(const Impl &) = delete;
+
+    // Copy state to local memory for efficiency.
+    __device__ explicit Impl(RandGenerator<gpu, DType> *gen, int state_idx)
+        : global_gen_(gen),
+          global_state_idx_(state_idx),
+          state_(*(gen->states_ + state_idx)) {}
+
+    __device__ ~Impl() {
+      // store the curand state back into global memory
+      global_gen_->states_[global_state_idx_] = state_;
+    }
+
+    MSHADOW_FORCE_INLINE __device__ int rand() {
+      return curand(&state_);
+    }
+
+    MSHADOW_FORCE_INLINE __device__ float uniform() {
+      return static_cast<float>(1.0) - curand_uniform(&state_);
+    }
+
+    MSHADOW_FORCE_INLINE __device__ float normal() {
+      return curand_normal(&state_);
+    }
+
+   private:
+    RandGenerator<gpu, DType> *global_gen_;
+    int global_state_idx_;
+    curandStatePhilox4_32_10_t state_;
+  };  // class RandGenerator<gpu, DType>::Impl
+
+  static void AllocState(RandGenerator<gpu, DType> *inst) {
+    CUDA_CALL(cudaMalloc(&inst->states_,
+                         kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
+  }
+
+  static void FreeState(RandGenerator<gpu, DType> *inst) {
+    CUDA_CALL(cudaFree(inst->states_));
+  }
+
+  void Seed(Stream<gpu> *s, uint32_t seed);
+
+ private:
+  curandStatePhilox4_32_10_t *states_;
+};
+
+template<>
+class RandGenerator<gpu, double> {
+ public:
+  // at least how many random numbers should be generated by one GPU thread.
+  static const int kMinNumRandomPerThread = 64;
+  // store how many global random states for GPU.
+  static const int kNumRandomStates = 32768;
+
+  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
+  // by using 1.0-curand_uniform().
+  // Needed as some samplers in sampler.h won't be able to deal with
+  // one of the boundary cases.
+  class Impl {
+   public:
+    Impl &operator=(const Impl &) = delete;
+    Impl(const Impl &) = delete;
+
+    // Copy state to local memory for efficiency.
+    __device__ explicit Impl(RandGenerator<gpu, double> *gen, int state_idx)
+        : global_gen_(gen),
+          global_state_idx_(state_idx),
+          state_(*(gen->states_ + state_idx)) {}
+
+    __device__ ~Impl() {
+      // store the curand state back into global memory
+      global_gen_->states_[global_state_idx_] = state_;
+    }
+
+    MSHADOW_FORCE_INLINE __device__ int rand() {
+      return curand(&state_);
+    }
+
+    MSHADOW_FORCE_INLINE __device__ double uniform() {
+      return static_cast<float>(1.0) - curand_uniform_double(&state_);
+    }
+
+    MSHADOW_FORCE_INLINE __device__ double normal() {
+      return curand_normal_double(&state_);
+    }
+
+   private:
+    RandGenerator<gpu, double> *global_gen_;
+    int global_state_idx_;
+    curandStatePhilox4_32_10_t state_;
+  };  // class RandGenerator<gpu, double>::Impl
+
+ private:
+  curandStatePhilox4_32_10_t *states_;
+};
+
+#endif  // MXNET_USE_CUDA
+
+}  // namespace random
+}  // namespace common
+}  // namespace mxnet
+#endif  // MXNET_COMMON_RANDOM_GENERATOR_H_
diff --git a/src/common/utils.h b/src/common/utils.h
index 038ab2a047..ede218bbdc 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -49,7 +49,6 @@
 namespace mxnet {
 namespace common {
 
-
 /*!
  * \brief IndPtr should be non-negative, in non-decreasing order, start with 0
  *           and end with value equal with size of indices.
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 18feec7957..9a7ed09676 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -61,6 +61,8 @@ Graph AttachOpResources(Graph g) {
           }
         } else if (req.type == ResourceRequest::kRandom) {
           requested.push_back(ResourceManager::Get()->Request(ctx, req));
+        } else if (req.type == ResourceRequest::kParallelRandom) {
+          requested.push_back(ResourceManager::Get()->Request(ctx, req));
         } else {
           LOG(FATAL) << "resource type not yet supported";
         }
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index e265cce28e..8be1eb4e3d 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -203,6 +203,10 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
         requested.push_back(ResourceManager::Get()->Request(ctx, req));
         write_vars.push_back(requested.back().var);
         break;
+       case ResourceRequest::kParallelRandom:
+        requested.push_back(ResourceManager::Get()->Request(ctx, req));
+        write_vars.push_back(requested.back().var);
+        break;
        default:
         LOG(FATAL) << "resource type not yet supported";
       }
diff --git a/src/operator/random/multisample_op.cc b/src/operator/random/multisample_op.cc
index 5f2af61f03..a88db09442 100644
--- a/src/operator/random/multisample_op.cc
+++ b/src/operator/random/multisample_op.cc
@@ -47,7 +47,8 @@ DMLC_REGISTER_PARAMETER(MultiSampleParam);
   .set_attr<nnvm::FInferShape>("FInferShape", MultiSampleOpShape) \
   .set_attr<nnvm::FInferType>("FInferType", MultiSampleOpType) \
   .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { \
-      return std::vector<ResourceRequest>{ResourceRequest::kRandom, ResourceRequest::kTempSpace}; \
+      return std::vector<ResourceRequest>{ResourceRequest::kParallelRandom, \
+                                          ResourceRequest::kTempSpace}; \
     }) \
   .set_attr<FCompute>("FCompute<cpu>", MultiSampleOpForward<cpu, sampler, num_inputs>) \
   .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
diff --git a/src/operator/random/multisample_op.h b/src/operator/random/multisample_op.h
index 38ccbb6925..e93e453b25 100644
--- a/src/operator/random/multisample_op.h
+++ b/src/operator/random/multisample_op.h
@@ -135,6 +135,8 @@ inline bool MultiSampleOpType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+using namespace mxnet::common::random;
+
 template<typename xpu, typename IType, typename OType, typename Sampler, int inum>
 struct SamplerCaller;
 
@@ -142,12 +144,12 @@ template<typename xpu, typename IType, typename OType, typename Sampler>
 struct SamplerCaller<xpu, IType, OType, Sampler, 1> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const Tensor<xpu, 1, unsigned int>& seeds,
-                       mshadow::Stream<xpu> *s) {
+                 RandGenerator<xpu, OType> *pgen,
+                 mshadow::Stream<xpu> *s) {
     Sampler sampler;
     sampler.Sample(inputs[0].FlatTo1D<xpu, IType>(s),
                    outputs[0].FlatTo1D<xpu, OType>(s),
-                   seeds, s);
+                   pgen, s);
   }
 };
 
@@ -155,13 +157,13 @@ template<typename xpu, typename IType, typename OType, typename Sampler>
 struct SamplerCaller<xpu, IType, OType, Sampler, 2> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const Tensor<xpu, 1, unsigned int>& seeds,
-                       mshadow::Stream<xpu> *s) {
+                 RandGenerator<xpu, OType> *pgen,
+                 mshadow::Stream<xpu> *s) {
     Sampler sampler;
     sampler.Sample(inputs[0].FlatTo1D<xpu, IType>(s),
                    inputs[1].FlatTo1D<xpu, IType>(s),
                    outputs[0].FlatTo1D<xpu, OType>(s),
-                   seeds, s);
+                   pgen, s);
   }
 };
 
@@ -177,15 +179,10 @@ void MultiSampleOpForward(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 1);
   CHECK_GT(inputs[0].Size(), 0);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  // Generate multiple seeds for the different threads.
-  const int nSeeds(OptSampleSeedNum<xpu>(outputs[0].Size()));
-  Tensor<xpu, 1, unsigned> seeds
-    = ctx.requested[1].get_space_typed<xpu, 1, unsigned> (Shape1(nSeeds), ctx.get_stream<xpu>());
-  ctx.requested[0].get_random<xpu, float>(s)->GetRandInt(seeds);
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-        SamplerCaller<xpu, IType, OType, Sampler, inum>
-            ::op(inputs, outputs, seeds, s);
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
+      SamplerCaller<xpu, IType, OType, Sampler, inum>::op(inputs, outputs, pgen, s);
     });
   });
 }
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index 9fdff03886..a81b41a09a 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -241,31 +241,27 @@ using FSampleCompute = std::function<void (const nnvm::NodeAttrs& attrs,
                                            TBlob* outputs)>;
 
 using mxnet::TBlob;
+using namespace mxnet::common::random;
 
 // Allocates a single chunk of workspace memory and partitions it into three
 // workspace tensors that hold the seeds as well as the distribution parameters.
 template<typename xpu, typename DType>
-MSHADOW_FORCE_INLINE void GetSamplingTempData(index_t N, DType p1, DType p2, const OpContext& ctx,
-                                              Tensor<xpu, 1, unsigned int>* seeds,
+MSHADOW_FORCE_INLINE void GetSamplingTempData(DType p1, DType p2, const OpContext& ctx,
                                               Tensor<xpu, 1, DType>* parm1,
                                               Tensor<xpu, 1, DType>* parm2) {
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  const index_t nSeeds(OptSampleSeedNum<xpu>(N));
   // Combined memory requirement for the workspace data.
-  const index_t nInt(nSeeds + (2 * sizeof(DType) + sizeof(unsigned) - 1) / sizeof(unsigned));
+  const index_t nInt((2 * sizeof(DType) + sizeof(unsigned) - 1) / sizeof(unsigned));
   Tensor<xpu, 1, unsigned> wspace
     = ctx.requested[1].get_space_typed<xpu, 1, unsigned>(Shape1(nInt), s);
-  // Partition workspace into three chunks and initialize them.
-  *seeds = Tensor<xpu, 1, unsigned>(wspace.dptr_, Shape1(nSeeds), s);
-  ctx.requested[0].get_random<xpu, float>(s)->GetRandInt(*seeds);
-  DType *pspace = static_cast<DType*>(static_cast<void*>(wspace.dptr_+nSeeds));
+  // Partition workspace into two chunks and initialize them.
+  DType *pspace = static_cast<DType*>(static_cast<void*>(wspace.dptr_));
   *parm1 = Tensor<xpu, 1, DType>(pspace, Shape1(1), s);
   Copy(*parm1, Tensor<cpu, 1, DType>(&p1, Shape1(1)), s);
   *parm2 = Tensor<xpu, 1, DType>(pspace+1, Shape1(1), s);
   Copy(*parm2, Tensor<cpu, 1, DType>(&p2, Shape1(1)), s);
 }
 
-
 template<typename xpu, typename Sampler>
 struct SampleMaster;
 
@@ -278,14 +274,14 @@ struct SampleMaster<xpu, UniformSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
     CHECK_GE(param.high, param.low) << "low must be less or equal to high in uniform distribution";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> low, high;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.low, param.high, ctx,
-                                    &seeds, &low, &high);
+    GetSamplingTempData<xpu, float>(param.low, param.high, ctx,
+                                    &low, &high);
     UniformSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(low, high, out, seeds, s);
+      sampler.Sample(low, high, out, pgen, s);
     });
   }
 };
@@ -299,14 +295,13 @@ struct SampleMaster<xpu, NormalSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleNormalParam& param = nnvm::get<SampleNormalParam>(attrs.parsed);
     CHECK_GT(param.scale, 0) << "scale parameter in gaussian has to be positive";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> loc, scale;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.loc, param.scale, ctx,
-                                    &seeds, &loc, &scale);
+    GetSamplingTempData<xpu, float>(param.loc, param.scale, ctx, &loc, &scale);
     NormalSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(loc, scale, out, seeds, s);
+      sampler.Sample(loc, scale, out, pgen, s);
     });
   }
 };
@@ -321,14 +316,13 @@ struct SampleMaster<xpu, GammaSampler<xpu>> {
     const SampleGammaParam& param = nnvm::get<SampleGammaParam>(attrs.parsed);
     CHECK_GT(param.alpha, 0) << "alpha parameter in gamma distribution has to be positive";
     CHECK_GT(param.beta, 0) << "beta parameter in gamma distribution has to be positive";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> alpha, beta;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.alpha, param.beta, ctx,
-                                    &seeds, &alpha, &beta);
+    GetSamplingTempData<xpu, float>(param.alpha, param.beta, ctx, &alpha, &beta);
     GammaSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(alpha, beta, out, seeds, s);
+      sampler.Sample(alpha, beta, out, pgen, s);
     });
   }
 };
@@ -342,13 +336,13 @@ struct SampleMaster<xpu, ExponentialSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SampleExponentialParam& param = nnvm::get<SampleExponentialParam>(attrs.parsed);
     CHECK_GT(param.lam, 0) << "lambda parameter in exponential distribution has to be positive";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> lam, dummy;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.lam, 0, ctx, &seeds, &lam, &dummy);
+    GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
     ExponentialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(lam, out, seeds, s);
+      sampler.Sample(lam, out, pgen, s);
     });
   }
 };
@@ -362,13 +356,13 @@ struct SampleMaster<xpu, PoissonSampler<xpu>> {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const SamplePoissonParam& param = nnvm::get<SamplePoissonParam>(attrs.parsed);
     CHECK_GE(param.lam, 0) << "lambda parameter in poisson distribution has to be non-negative";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> lam, dummy;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.lam, 0, ctx, &seeds, &lam, &dummy);
+    GetSamplingTempData<xpu, float>(param.lam, 0, ctx, &lam, &dummy);
     PoissonSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(lam, out, seeds, s);
+      sampler.Sample(lam, out, pgen, s);
     });
   }
 };
@@ -383,13 +377,13 @@ struct SampleMaster<xpu, NegativeBinomialSampler<xpu>> {
     const SampleNegBinomialParam& param = nnvm::get<SampleNegBinomialParam>(attrs.parsed);
     CHECK_GE(param.k, 0) << "k parameter in negative binomial distribution has to be non-negative";
     CHECK_GE(param.p, 0) << "p parameter in negative binomial distribution has to be non-negative";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> k, p;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.k, param.p, ctx, &seeds, &k, &p);
+    GetSamplingTempData<xpu, float>(param.k, param.p, ctx, &k, &p);
     NegativeBinomialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(k, p, out, seeds, s);
+      sampler.Sample(k, p, out, pgen, s);
     });
   }
 };
@@ -406,14 +400,13 @@ struct SampleMaster<xpu, GeneralizedNegativeBinomialSampler<xpu>> {
       << "mu parameter in generalized negative binomial distribution has to be non-negative";
     CHECK_GE(param.alpha, 0)
       << "alpha parameter in generalized negative binomial distribution has to be non-negative";
-    Tensor<xpu, 1, unsigned int> seeds;
     Tensor<xpu, 1, float> mu, alpha;
-    GetSamplingTempData<xpu, float>(outputs->Size(), param.mu, param.alpha, ctx,
-                                    &seeds, &mu, &alpha);
+    GetSamplingTempData<xpu, float>(param.mu, param.alpha, ctx, &mu, &alpha);
     GeneralizedNegativeBinomialSampler<xpu> sampler;
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
       Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
-      sampler.Sample(mu, alpha, out, seeds, s);
+      sampler.Sample(mu, alpha, out, pgen, s);
     });
   }
 };
@@ -502,7 +495,7 @@ inline bool SampleOpType(const nnvm::NodeAttrs& attrs,
 }
 
 inline std::vector<ResourceRequest> SampleResource(const NodeAttrs& attrs) {
-  return { ResourceRequest::kRandom, ResourceRequest::kTempSpace };
+  return { ResourceRequest::kParallelRandom, ResourceRequest::kTempSpace };  // [0]: parallel RNG
 }
 
 }  // namespace op
diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h
index d544aec88d..8eace1e092 100644
--- a/src/operator/random/sampler.h
+++ b/src/operator/random/sampler.h
@@ -25,89 +25,52 @@
 #ifndef MXNET_OPERATOR_RANDOM_SAMPLER_H_
 #define MXNET_OPERATOR_RANDOM_SAMPLER_H_
 
-#ifdef __CUDACC__
-#include <curand.h>
-#include <curand_kernel.h>
-#endif  // __CUDACC__
+#include <algorithm>
 
 using namespace mshadow;
 using namespace mxnet::op::mxnet_op;
+using namespace mxnet::common::random;
 
 namespace mxnet {
 namespace op {
 
-// Elementary random number generation for int/uniform/gaussian in CPU and GPU.
-// Will use float data type whenever instantiated for half_t or any other non
-// standard real type.
-template<typename xpu, typename DType>
-class RandGenerator;
-
-template<typename DType>
-class RandGenerator<cpu, DType> {
- public:
-  typedef typename std::conditional<std::is_floating_point<DType>::value,
-                                    DType, float>::type FType;
-  std::mt19937 engine;
-  std::uniform_real_distribution<FType> uniformNum;
-  std::normal_distribution<FType> normalNum;
-  explicit RandGenerator(unsigned int seed): engine(seed) {}
-  MSHADOW_XINLINE int rand() { return engine(); }
-  MSHADOW_XINLINE FType uniform() { return uniformNum(engine); }
-  MSHADOW_XINLINE FType normal() { return normalNum(engine); }
-};
-
-#ifdef __CUDACC__
-
-// uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
-// by using 1.0-curand_uniform(). Needed as some samplers below won't be able to deal with
-// one of the boundary cases.
-template<typename DType>
-class RandGenerator<gpu, DType> {
- public:
-  curandState_t state;
-  __device__ RandGenerator(unsigned int seed) { curand_init(seed, 0, 0, &state); }
-  MSHADOW_FORCE_INLINE __device__ int rand() { return curand(&state); }
-  MSHADOW_FORCE_INLINE __device__ float uniform()
-                              { return static_cast<float>(1.0) - curand_uniform(&state); }
-  MSHADOW_FORCE_INLINE __device__ float normal() { return curand_normal(&state); }
-};
-
-template<>
-class RandGenerator<gpu, double> {
- public:
-  curandState_t state;
-  __device__ RandGenerator(unsigned int seed) { curand_init(seed, 0, 0, &state); }
-  MSHADOW_FORCE_INLINE __device__ int rand() { return curand(&state); }
-  MSHADOW_FORCE_INLINE __device__ double uniform()
-                            { return static_cast<double>(1.0) - curand_uniform_double(&state); }
-  MSHADOW_FORCE_INLINE __device__ double normal() { return curand_normal_double(&state); }
-};
-
-#endif  // __CUDACC__
-
-// Number of seeds/threads when sampling on cpu/gpu.
-template<typename xpu>
-MSHADOW_XINLINE index_t OptSampleSeedNum(index_t N);
-template<>
-MSHADOW_XINLINE index_t OptSampleSeedNum<cpu>(index_t N) {
-  return omp_get_num_threads();
-}
-template<>
-MSHADOW_XINLINE index_t OptSampleSeedNum<gpu>(index_t N) {
-  return N;
+/*!
+ * \brief Launch a generic kernel with parallel random generator; no-op when N <= 0.
+ * \param gen random generator.  \param N number of samples (total loop iterations).
+ * \tparam Args Varargs type to eventually pass to the OP::Map() function
+ */
+template<typename OP, typename xpu, typename GType, typename ...Args>
+inline static void LaunchRNG(mshadow::Stream<xpu> *s,
+                             common::random::RandGenerator<xpu, GType> *gen,
+                             const int N, Args... args) {
+  if (N <= 0) return;  // empty output: nothing to launch, and nthread would be 0 below
+  const int min_per_thread = RandGenerator<xpu, GType>::kMinNumRandomPerThread;
+  const int nloop = (N + min_per_thread - 1) / min_per_thread;
+  const int nthread = std::min(nloop, RandGenerator<xpu, GType>::kNumRandomStates);
+  const int step = (N + nthread - 1) / nthread;
+  Kernel<OP, xpu>::Launch(s, nthread, *gen, N, step, args...);
 }
 
+#define RNG_KERNEL_LOOP(xpu, GType, thread_id, gen, N, step, ...)        \
+  const int start = thread_id * step;  /* first index of this thread */  \
+  const int end = start + step;  /* exclusive; also clipped to N */      \
+  typename RandGenerator<xpu, GType>::Impl genImpl(&gen, thread_id);     \
+  for (int i = start; i < end && i < N; ++i) {  /* i: output index */    \
+    {__VA_ARGS__}                                                        \
+  }
+
 template<typename xpu>
 struct SampleUniformKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                     const IType *lower, const IType *upper, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, OType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(lower[j/nBatch] + (upper[j/nBatch] - lower[j/nBatch]) * gen.uniform());
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
+                                  const int N, const int step,  // step: samples per RNG thread
+                                  index_t nParm, index_t nSample,
+                                  const IType *lower, const IType *upper, OType *out) {
+    RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);  // outputs per param = ceil(nSample/nParm)
+      out[i] = OType(lower[i / nBatch] +
+                     (upper[i / nBatch] - lower[i / nBatch]) * genImpl.uniform());
+    });
   }
 };
 
@@ -117,25 +80,24 @@ struct UniformSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& lower,
                                    const Tensor<xpu, 1, IType>& upper,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleUniformKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), lower.size(0), out.size(0), seed.size(0),
-               lower.dptr_, upper.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    LaunchRNG<SampleUniformKernel<xpu>, xpu>(s, pgen, out.size(0), lower.size(0), out.size(0),
+                                             lower.dptr_, upper.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleNormalKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                            const IType *mean, const IType *std, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, OType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(gen.normal() * std[j/nBatch] + mean[j/nBatch]);
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
+                                  const int N, const int step,  // step: samples per RNG thread
+                                  index_t nParm, index_t nSample,
+                                  const IType *mean, const IType *std, OType *out) {
+    RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);  // outputs per param = ceil(nSample/nParm)
+      out[i] = OType(genImpl.normal() * std[i / nBatch] + mean[i / nBatch]);
+    });
   }
 };
 
@@ -145,25 +107,24 @@ struct NormalSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& mean,
                                    const Tensor<xpu, 1, IType>& std,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleNormalKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), mean.size(0), out.size(0), seed.size(0),
-               mean.dptr_, std.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    LaunchRNG<SampleNormalKernel<xpu>, xpu>(s, pgen, out.size(0), mean.size(0), out.size(0),
+                                            mean.dptr_, std.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleExponentialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                                  const IType *lambda, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, OType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(-log(1.0-gen.uniform()) / lambda[j/nBatch]);
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
+                                  const int N, const int step,  // step: samples per RNG thread
+                                  index_t nParm, index_t nSample,
+                                  const IType *lambda, OType *out) {
+    RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);  // outputs per param = ceil(nSample/nParm)
+      out[i] = OType(-log(1.0 - genImpl.uniform()) / lambda[i / nBatch]);  // inverse-CDF method
+    });
   }
 };
 
@@ -172,16 +133,16 @@ struct ExponentialSampler {
   template<typename IType, typename OType>
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& lambda,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleExponentialKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), lambda.size(0), out.size(0), seed.size(0),
-               lambda.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    LaunchRNG<SampleExponentialKernel<xpu>, xpu>(s, pgen, out.size(0),  // N = out.size(0)
+                                                 lambda.size(0), out.size(0),  // nParm, nSample
+                                                 lambda.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu, typename IType, typename OType>
-MSHADOW_XINLINE OType SampleGamma(IType a, IType b, RandGenerator<xpu, OType> *gen) {
+MSHADOW_XINLINE OType SampleGamma(IType a, IType b, typename RandGenerator<xpu, OType>::Impl *gen) {
   // Generate one sample of the gamma distribution
   OType sample;
   OType d = a < 1 ? a + 2.0 / 3.0 : a - 1.0 / 3.0;
@@ -203,17 +164,16 @@ MSHADOW_XINLINE OType SampleGamma(IType a, IType b, RandGenerator<xpu, OType> *g
 
 template<typename xpu>
 struct SampleGammaKernel {
-  template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                      const IType *alpha, const IType *beta, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    typedef typename std::conditional<std::is_floating_point<OType>::value,
-                                     OType, float>::type FType;
-    RandGenerator<xpu, FType> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(SampleGamma(alpha[j/nBatch], beta[j/nBatch], &gen));
-    }
+  template<typename IType, typename OType, typename FType>  // FType deduced from gen
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, FType> gen,
+                                  const int N, const int step,  // step: samples per RNG thread
+                                  index_t nParm, index_t nSample,
+                                  const IType *alpha, const IType *beta, OType *out) {
+    RNG_KERNEL_LOOP(xpu, FType, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);  // outputs per param = ceil(nSample/nParm)
+      out[i] = OType(SampleGamma<xpu, IType, FType>(alpha[i / nBatch],
+                                                    beta[i / nBatch], &genImpl));
+    });
   }
 };
 
@@ -223,16 +183,18 @@ struct GammaSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& alpha,
                                    const Tensor<xpu, 1, IType>& beta,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleGammaKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), alpha.size(0), out.size(0), seed.size(0),
-               alpha.dptr_, beta.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    typedef typename std::conditional<std::is_floating_point<OType>::value,
+                                      OType, float>::type FType;
+    RandGenerator<xpu, FType> *gen = reinterpret_cast<RandGenerator<xpu, FType> *>(pgen);
+    LaunchRNG<SampleGammaKernel<xpu>, xpu>(s, gen, out.size(0), alpha.size(0), out.size(0),
+                                           alpha.dptr_, beta.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
-MSHADOW_XINLINE int SamplePoisson(float lambda, RandGenerator<xpu, float> *gen) {
+MSHADOW_XINLINE int SamplePoisson(float lambda, typename RandGenerator<xpu, float>::Impl *gen) {
   // Generate one sample of the poisson distribution. Intentionally written
   // towards a specific type (float) for internal computation which is sufficient
   // for accurate enough computation.
@@ -265,14 +227,14 @@ MSHADOW_XINLINE int SamplePoisson(float lambda, RandGenerator<xpu, float> *gen)
 template<typename xpu>
 struct SamplePoissonKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                                  const IType *lambda, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, float> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      out[j] = OType(SamplePoisson(lambda[j/nBatch], &gen));
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
+                                  const int N, const int step,  // step: samples per RNG thread
+                                  index_t nParm, index_t nSample,
+                                  const IType *lambda, OType *out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);  // outputs per param = ceil(nSample/nParm)
+      out[i] = OType(SamplePoisson<xpu>(lambda[i / nBatch], &genImpl));
+    });
   }
 };
 
@@ -281,29 +243,29 @@ struct PoissonSampler {
   template<typename IType, typename OType>
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& lambda,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SamplePoissonKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), lambda.size(0), out.size(0), seed.size(0),
-               lambda.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {  // NOTE(review): cast below layout-safe?
+    RandGenerator<xpu, float> *gen = reinterpret_cast<RandGenerator<xpu, float> *>(pgen);
+    LaunchRNG<SamplePoissonKernel<xpu>, xpu>(s, gen, out.size(0), lambda.size(0), out.size(0),
+                                             lambda.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleNegativeBinomialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                             const IType *k, const IType *p, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, float> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      float alpha = k[j/nBatch];
-      float prob = p[j/nBatch];
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
+                                  const int N, const int step,  // step: samples per RNG thread
+                                  index_t nParm, index_t nSample,
+                                  const IType *k, const IType *p, OType *out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);  // outputs per param = ceil(nSample/nParm)
+      float alpha = k[i / nBatch];
+      float prob = p[i / nBatch];  // gamma params stay float: (1-p)/p can overflow half_t
       float beta = (1.0 - prob) / prob;
-      float lambda = SampleGamma(alpha, beta, &gen);
-      out[j] = OType(SamplePoisson(lambda, &gen));
-    }
+      float lambda = SampleGamma<xpu, float, float>(alpha, beta, &genImpl);
+      out[i] = OType(SamplePoisson<xpu>(lambda, &genImpl));
+    });
   }
 };
 
@@ -313,27 +275,29 @@ struct NegativeBinomialSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& k,
                                    const Tensor<xpu, 1, IType>& p,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleNegativeBinomialKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), k.size(0), out.size(0), seed.size(0),
-               k.dptr_, p.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    RandGenerator<xpu, float> *gen = reinterpret_cast<RandGenerator<xpu, float> *>(pgen);
+    LaunchRNG<SampleNegativeBinomialKernel<xpu>, xpu>(s, gen, out.size(0), k.size(0), out.size(0),
+                                                      k.dptr_, p.dptr_, out.dptr_);
   }
 };
 
 template<typename xpu>
 struct SampleGeneralizedNegativeBinomialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i, index_t nParm, index_t nSample, index_t nSeed,
-                        const IType *mu, const IType *alpha, OType *out, const unsigned *seed) {
-    index_t nBatch(nSample/nParm), nChunk((nSample+nSeed-1)/nSeed),
-            start(i*nChunk), end((i+1)*nChunk < nSample ? (i+1)*nChunk : nSample);
-    RandGenerator<xpu, float> gen(seed[i]);
-    for ( index_t j = start; j < end; ++j ) {
-      float lambda = alpha[j/nBatch] == 0 ? static_cast<float>(mu[j/nBatch])
-              : SampleGamma(IType(1) / alpha[j/nBatch], alpha[j/nBatch] * mu[j/nBatch], &gen);
-      out[j] = OType(SamplePoisson(lambda, &gen));
-    }
+  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
+                                  const int N, const int step,  // step: samples per RNG thread
+                                  index_t nParm, index_t nSample,
+                                  const IType *mu, const IType *alpha, OType *out) {
+    RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
+      index_t nBatch(1 + (nSample - 1) / nParm);  // outputs per param = ceil(nSample/nParm)
+      float lambda = alpha[i / nBatch] == 0 ?  // alpha == 0: plain Poisson(mu)
+                     static_cast<float>(mu[i / nBatch]) :
+                     SampleGamma<xpu, IType, float>(IType(1) / alpha[i / nBatch],
+                                                    alpha[i / nBatch] * mu[i / nBatch], &genImpl);
+      out[i] = OType(SamplePoisson<xpu>(lambda, &genImpl));
+    });
   }
 };
 
@@ -343,11 +307,12 @@ struct GeneralizedNegativeBinomialSampler {
   MSHADOW_FORCE_INLINE void Sample(const Tensor<xpu, 1, IType>& mu,
                                    const Tensor<xpu, 1, IType>& alpha,
                                    const Tensor<xpu, 1, OType>& out,
-                                   const Tensor<xpu, 1, unsigned>& seed,
-                                         Stream<xpu> *s) {
-    Kernel<SampleGeneralizedNegativeBinomialKernel<xpu>, xpu>
-      ::Launch(s, seed.size(0), mu.size(0), out.size(0), seed.size(0),
-               mu.dptr_, alpha.dptr_, out.dptr_, seed.dptr_);
+                                   RandGenerator<xpu, OType> *pgen,
+                                   Stream<xpu> *s) {
+    RandGenerator<xpu, float> *gen = reinterpret_cast<RandGenerator<xpu, float> *>(pgen);
+    LaunchRNG<SampleGeneralizedNegativeBinomialKernel<xpu>, xpu>(s, gen, out.size(0),
+                                                                 mu.size(0), out.size(0),
+                                                                 mu.dptr_, alpha.dptr_, out.dptr_);
   }
 };
 
diff --git a/src/resource.cc b/src/resource.cc
index d1038dc57c..e195006c36 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -32,6 +32,8 @@
 #include <limits>
 #include <atomic>
 #include "./common/lazy_alloc_array.h"
+#include "./common/random_generator.h"
+#include "./common/utils.h"
 
 namespace mxnet {
 namespace resource {
@@ -88,20 +90,26 @@ class ResourceManagerImpl : public ResourceManager {
       : global_seed_(0) {
     cpu_temp_space_copy_ = dmlc::GetEnv("MXNET_CPU_TEMP_COPY", 4);
     gpu_temp_space_copy_ = dmlc::GetEnv("MXNET_GPU_TEMP_COPY", 1);
+    cpu_native_rand_copy_ = dmlc::GetEnv("MXNET_CPU_NATIVE_RAND_COPY", 1);
+    gpu_native_rand_copy_ = dmlc::GetEnv("MXNET_GPU_NATIVE_RAND_COPY", 4);
     engine_ref_ = Engine::_GetSharedRef();
     storage_ref_ = Storage::_GetSharedRef();
     cpu_rand_.reset(new ResourceRandom<cpu>(
         Context::CPU(), global_seed_));
     cpu_space_.reset(new ResourceTempSpace(
         Context::CPU(), cpu_temp_space_copy_));
+    cpu_native_rand_.reset(new ResourceNativeRandom<cpu>(
+        Context::CPU(), cpu_native_rand_copy_, global_seed_));
   }
   ~ResourceManagerImpl() {
     // need explicit delete, before engine get killed
     cpu_rand_.reset(nullptr);
     cpu_space_.reset(nullptr);
+    cpu_native_rand_.reset(nullptr);
 #if MXNET_USE_CUDA
     gpu_rand_.Clear();
     gpu_space_.Clear();
+    gpu_native_rand_.Clear();
 #endif
     if (engine_ref_ != nullptr) {
       engine_ref_ = nullptr;
@@ -117,6 +125,7 @@ class ResourceManagerImpl : public ResourceManager {
       switch (req.type) {
         case ResourceRequest::kRandom: return cpu_rand_->resource;
         case ResourceRequest::kTempSpace: return cpu_space_->GetNext();
+        case ResourceRequest::kParallelRandom: return cpu_native_rand_->GetNext();
         default: LOG(FATAL) << "Unknown supported type " << req.type;
       }
     } else {
@@ -133,6 +142,11 @@ class ResourceManagerImpl : public ResourceManager {
               return new ResourceTempSpace(ctx, gpu_temp_space_copy_);
             })->GetNext();
         }
+        case ResourceRequest::kParallelRandom: {
+          return gpu_native_rand_.Get(ctx.dev_id, [ctx, this]() {
+            return new ResourceNativeRandom<gpu>(ctx, gpu_native_rand_copy_, global_seed_);
+          })->GetNext();
+        }
         default: LOG(FATAL) << "Unknown supported type " << req.type;
       }
 #else
@@ -146,10 +160,14 @@ class ResourceManagerImpl : public ResourceManager {
   void SeedRandom(uint32_t seed) override {
     global_seed_ = seed;
     cpu_rand_->Seed(global_seed_);
+    cpu_native_rand_->Seed(global_seed_);
 #if MXNET_USE_CUDA
     gpu_rand_.ForEach([seed](size_t i, ResourceRandom<gpu> *p) {
         p->Seed(seed);
       });
+    gpu_native_rand_.ForEach([seed](size_t i, ResourceNativeRandom<gpu> *p) {
+      p->Seed(seed);
+    });
 #endif
   }
 
@@ -205,7 +223,7 @@ class ResourceManagerImpl : public ResourceManager {
     std::vector<SpaceAllocator> space;
     /*! \brief resource representation */
     std::vector<Resource> resource;
-    /*! \brief current pointer to the round roubin alloator */
+    /*! \brief current pointer to the round robin allocator */
     std::atomic<size_t> curr_ptr;
     /*! \brief constructor */
     explicit ResourceTempSpace(Context ctx, size_t ncopy)
@@ -241,10 +259,82 @@ class ResourceManagerImpl : public ResourceManager {
       return resource[ptr % space.size()];
     }
   };
+
+  // pool of ncopy parallel random generators for one context, handed out round-robin
+  template<typename xpu>
+  struct ResourceNativeRandom {
+    /*! \brief context on which seeding and freeing of the generators are scheduled */
+    Context ctx;
+    /*! \brief owned generator copies; released on the engine in the destructor */
+    std::vector<common::random::RandGenerator<xpu> *> sampler;
+    /*! \brief one Resource per generator (ptr_ = sampler[i], own engine variable) */
+    std::vector<Resource> resource;
+    /*! \brief counter driving the round-robin choice in GetNext() */
+    std::atomic<size_t> curr_ptr;
+    /*! \brief create ncopy generators (ncopy must be > 0; GetNext() uses % size) */
+    explicit ResourceNativeRandom(Context ctx, size_t ncopy, uint32_t global_seed)
+        : ctx(ctx), sampler(ncopy), resource(ncopy), curr_ptr(0) {
+      for (size_t i = 0; i < sampler.size(); ++i) {
+        const uint32_t seed = ctx.dev_id + i * kMaxNumGPUs + global_seed * kRandMagic;
+        resource[i].var = Engine::Get()->NewVariable();
+        common::random::RandGenerator<xpu> *r = new common::random::RandGenerator<xpu>();
+        common::random::RandGenerator<xpu>::AllocState(r);
+        Engine::Get()->PushSync(
+        [r, seed](RunContext rctx) {
+          r->Seed(rctx.get_stream<xpu>(), seed);
+        }, ctx, {}, {resource[i].var},
+        FnProperty::kNormal, 0, PROFILER_MESSAGE("ResourceNativeRandomSetSeed"));
+        sampler[i] = r;
+        resource[i].ptr_ = sampler[i];
+        resource[i].req = ResourceRequest(ResourceRequest::kParallelRandom);
+      }
+    }
+    ~ResourceNativeRandom() {
+      for (size_t i = 0; i < sampler.size(); ++i) {
+        common::random::RandGenerator<xpu> *r = sampler[i];
+        Engine::Get()->DeleteVariable(
+        [r](RunContext rctx) {
+          MSHADOW_CATCH_ERROR(common::random::RandGenerator<xpu>::FreeState(r));
+          MSHADOW_CATCH_ERROR(delete r);
+        }, ctx, resource[i].var);
+      }
+    }
+    // re-seed every generator copy (pushed asynchronously to the engine)
+    inline void Seed(uint32_t global_seed) {
+      for (size_t i = 0; i < sampler.size(); ++i) {
+        const uint32_t seed = ctx.dev_id + i * kMaxNumGPUs + global_seed * kRandMagic;
+        common::random::RandGenerator<xpu> *r = sampler[i];
+        Engine::Get()->PushAsync(
+        [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          r->Seed(rctx.get_stream<xpu>(), seed);
+          on_complete();
+        }, ctx, {}, {resource[i].var},
+        FnProperty::kNormal, 0, PROFILER_MESSAGE("ResourceNativeRandomSetSeed"));
+      }
+      // reset pointer so the same seed reproduces the same results.
+      curr_ptr.store(0);
+    }
+    // get next resource in round-robin manner
+    inline Resource GetNext() {
+      const size_t kMaxDigit = std::numeric_limits<size_t>::max() / 2;
+      size_t ptr = ++curr_ptr;
+      // wrap the counter well before size_t overflow;
+      // in practice this is never reached
+      if (ptr > kMaxDigit) {
+        curr_ptr.store((ptr + 1) % sampler.size());
+      }
+      return resource[ptr % sampler.size()];
+    }
+  };
+
   /*! \brief number of copies in CPU temp space */
   int cpu_temp_space_copy_;
   /*! \brief number of copies in GPU temp space */
   int gpu_temp_space_copy_;
+  /*! \brief number of copies in CPU native random sampler */
+  int cpu_native_rand_copy_;
+  /*! \brief number of copies in GPU native random sampler */
+  int gpu_native_rand_copy_;
   /*! \brief Reference to the engine */
   std::shared_ptr<Engine> engine_ref_;
   /*! \brief Reference to the storage */
@@ -255,11 +345,15 @@ class ResourceManagerImpl : public ResourceManager {
   std::unique_ptr<ResourceRandom<cpu> > cpu_rand_;
   /*! \brief CPU temp space resources */
   std::unique_ptr<ResourceTempSpace> cpu_space_;
+  /*! \brief CPU native random number resources */
+  std::unique_ptr<ResourceNativeRandom<cpu> > cpu_native_rand_;
 #if MXNET_USE_CUDA
   /*! \brief random number generator for GPU */
   common::LazyAllocArray<ResourceRandom<gpu> > gpu_rand_;
   /*! \brief temp space for GPU */
   common::LazyAllocArray<ResourceTempSpace> gpu_space_;
+  /*! \brief GPU native (on device) random number resources */
+  common::LazyAllocArray<ResourceNativeRandom<gpu> > gpu_native_rand_;
 #endif
 };
 }  // namespace resource
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 08302b8f99..db5508deef 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -613,7 +613,7 @@ def fm(factor_size, feature_dim, init):
             expected_accuracy = 0.02
         elif optimizer == 'adam':
             # use Sparse Adam to train
-            adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.001,
+            adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.0005,
                                      rescale_grad=1.0/batch_size)
             mod.init_optimizer(optimizer=adam)
             if num_epochs is None:
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index d05e3256a1..0230d5f064 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -330,8 +330,8 @@ def check_softmax_with_shape(shape, xpu, preserve_shape=False):
     X = mx.symbol.Variable('X')
     L = mx.symbol.Variable('L')
     Y = mx.symbol.SoftmaxOutput(data=X, label=L, preserve_shape=preserve_shape)
-    x = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
-    l = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(xpu)
+    x = mx.random.uniform(-1, 1, shape, ctx=xpu)
+    l = mx.random.uniform(-1, 1, shape, ctx=xpu)
     l[:] = np_softmax(l.asnumpy())
     grad = mx.nd.empty(shape, ctx = xpu)
     exec1 = Y.bind(xpu, args = [x, l], args_grad = {'X': grad})
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index 6178cbe838..1aa2e22760 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -666,7 +666,7 @@ def get_net(num_hidden, flatten=True):
     loss = Loss(output, l)
     loss = mx.sym.make_loss(loss)
     mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
-    mod.fit(data_iter, num_epoch=30, optimizer_params={'learning_rate': 0.005, 'wd': 0.0005},
+    mod.fit(data_iter, num_epoch=60, optimizer_params={'learning_rate': 0.0005, 'wd': 0.0005},
             initializer=mx.init.Xavier(magnitude=2), eval_metric=mx.metric.Loss(),
             optimizer='nadam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index a67e2d1113..0efe8e6834 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -17,12 +17,16 @@
 
 import os
 import mxnet as mx
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
 import numpy as np
+import scipy.stats as ss
 
 def same(a, b):
     return np.sum(a != b) == 0
 
 def check_with_device(device, dtype):
+    # The thresholds chosen for the tests are too loose. We will rely on the other tests to test the samples from the
+    #  generators.
     tol = 0.1
     symbols = [
         {
@@ -216,6 +220,124 @@ def test_sample_multinomial():
             real_dx[y[i][j]] += 5.0 / rprob[j]
         mx.test_utils.assert_almost_equal(real_dx, dx.asnumpy()[i])
 
+# Chi-square goodness-of-fit tests for the random-number generators.
+def test_normal_generator():
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for mu, sigma in [(0.0, 1.0), (1.0, 5.0)]:
+            print("ctx=%s, dtype=%s, Mu=%g, Sigma=%g:" % (ctx, dtype, mu, sigma))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, mu, sigma), 5)
+            generator_mx = lambda x: mx.nd.random.normal(mu, sigma, shape=x, ctx=ctx, dtype=dtype).asnumpy()
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed =\
+                lambda x: np.concatenate(
+                    [mx.nd.random.normal(mu, sigma, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_uniform_generator():
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for low, high in [(-1.0, 1.0), (1.0, 3.0)]:
+            print("ctx=%s, dtype=%s, Low=%g, High=%g:" % (ctx, dtype, low, high))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=high - low), 5)
+            generator_mx = lambda x: mx.nd.random.uniform(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy()
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.uniform(low, high, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_gamma_generator():
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for kappa, theta in [(0.5, 1.0), (1.0, 5.0)]:
+            print("ctx=%s, dtype=%s, Shape=%g, Scale=%g:" % (ctx, dtype, kappa, theta))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.gamma.ppf(x, a=kappa, loc=0, scale=theta), 5)
+            generator_mx = lambda x: mx.nd.random.gamma(kappa, theta, shape=x, ctx=ctx, dtype=dtype).asnumpy()
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.gamma(kappa, theta, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_exponential_generator():
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for scale in [0.1, 1.0]:
+            print("ctx=%s, dtype=%s, Scale=%g:" % (ctx, dtype, scale))
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.expon.ppf(x, loc=0, scale=scale), 5)
+            generator_mx = lambda x: mx.nd.random.exponential(scale, shape=x, ctx=ctx, dtype=dtype).asnumpy()
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.exponential(scale, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_poisson_generator():
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        for lam in [1, 10]:
+            print("ctx=%s, dtype=%s, Lambda=%d:" % (ctx, dtype, lam))
+            buckets = [(-1.0, lam - 0.5), (lam - 0.5, 2 * lam + 0.5), (2 * lam + 0.5, np.inf)]
+            probs = [ss.poisson.cdf(bucket[1], lam) - ss.poisson.cdf(bucket[0], lam) for bucket in buckets]
+            generator_mx = lambda x: mx.nd.random.poisson(lam, shape=x, ctx=ctx, dtype=dtype).asnumpy()
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+            generator_mx_same_seed = \
+                lambda x: np.concatenate(
+                    [mx.nd.random.poisson(lam, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                     for _ in range(10)])
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_negative_binomial_generator():
+    ctx = mx.context.current_context()
+    for dtype in ['float16', 'float32', 'float64']:
+        success_num = 2
+        success_prob = 0.2
+        print("ctx=%s, dtype=%s, Success Num=%d:, Success Prob=%g" % (ctx, dtype, success_num, success_prob))
+        buckets = [(-1.0, 2.5), (2.5, 5.5), (5.5, 8.5), (8.5, np.inf)]
+        probs = [ss.nbinom.cdf(bucket[1], success_num, success_prob) -
+                 ss.nbinom.cdf(bucket[0], success_num, success_prob) for bucket in buckets]
+        generator_mx = lambda x: mx.nd.random.negative_binomial(success_num, success_prob,
+                                                                shape=x, ctx=ctx, dtype=dtype).asnumpy()
+        verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+        generator_mx_same_seed = \
+            lambda x: np.concatenate(
+                [mx.nd.random.negative_binomial(success_num, success_prob, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                 for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+        # Also test the Gamm-Poisson Mixture
+        print('Gamm-Poisson Mixture Test:')
+        alpha = 1.0 / success_num
+        mu = (1.0 - success_prob) / success_prob / alpha
+        generator_mx = lambda x: mx.nd.random.generalized_negative_binomial(mu, alpha,
+                                                                            shape=x, ctx=ctx, dtype=dtype).asnumpy()
+        verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
+        generator_mx_same_seed = \
+            lambda x: np.concatenate(
+                [mx.nd.random.generalized_negative_binomial(mu, alpha, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
+                 for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
+def test_multinomial_generator():
+    ctx = mx.context.current_context()
+    probs = [0.1, 0.2, 0.3, 0.05, 0.15, 0.2]
+    buckets = list(range(6))
+    for dtype in ['float16', 'float32', 'float64']:
+        print("ctx=%s, dtype=%s" %(ctx, dtype))
+        generator_mx = lambda x: mx.nd.random.multinomial(data=mx.nd.array(np.array(probs), ctx=ctx, dtype=dtype),
+                                                          shape=x).asnumpy()
+        verify_generator(generator_mx, buckets, probs)
+        generator_mx_same_seed = \
+            lambda x: np.concatenate(
+                [mx.nd.random.multinomial(data=mx.nd.array(np.array(probs), ctx=ctx, dtype=dtype),
+                                                          shape=x // 10).asnumpy()
+                 for _ in range(10)])
+        verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
+
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services