Posted to commits@mxnet.apache.org by di...@apache.org on 2021/05/25 01:40:31 UTC

[incubator-mxnet] branch master updated: [FEATURE] Use RTC for reduction ops (#19426)

This is an automated email from the ASF dual-hosted git repository.

dickjc123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 57d0ace  [FEATURE] Use RTC for reduction ops (#19426)
57d0ace is described below

commit 57d0ace3a9b89032896822dd27af870a9b446d83
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Mon May 24 18:38:16 2021 -0700

    [FEATURE] Use RTC for reduction ops (#19426)
    
    * Initial rebase
    
    * Fixes after merge
    
    * Fixes
    
    * Fix lint
    
    * Fix lint for real
    
    * Cleaning and code reuse
    
    * Fix lint
    
    * Try to WAR the maybe-uninitialized warning
    
    * Second try
    
    * Fix Windows compilation
    
    * More fixes for Windows compilation
    
    * Breaking the strings to please Windows compiler
    
    * Do not use the default stream in kron
    
    * Fix argmin/argmax
    
    * Fix layernorm
---
 src/common/cuda/rtc.cc                             |   5 +-
 src/common/cuda/rtc/backward_functions-inl.h       | 240 ++++++++----
 src/common/cuda/rtc/forward_functions-inl.h        |   9 +
 src/common/cuda/rtc/reducer-inl.h                  | 316 +++++++++++-----
 src/common/cuda/rtc/special_functions-inl.h        |   5 -
 src/common/cuda/rtc/util-inl.h                     | 138 +++++++
 src/operator/mshadow_op.h                          |  10 +-
 src/operator/nn/group_norm-inl.h                   | 153 +++++---
 src/operator/nn/layer_norm-inl.h                   | 226 ++++-------
 src/operator/nn/layer_norm.cc                      | 116 ++++++
 src/operator/nn/layer_norm.cu                      |  83 +++++
 src/operator/nn/moments-inl.h                      |  15 +-
 .../linalg/broadcast_reduce_customized-inl.cuh     | 415 ---------------------
 .../numpy/linalg/broadcast_reduce_customized-inl.h |   7 -
 .../numpy/linalg/broadcast_reduce_op_customized.h  |  24 +-
 src/operator/numpy/linalg/np_norm-inl.h            | 123 +++---
 src/operator/numpy/np_broadcast_reduce_op.cc       |  85 +++++
 src/operator/numpy/np_broadcast_reduce_op.cuh      |  44 ---
 src/operator/numpy/np_broadcast_reduce_op.h        | 262 +++++++------
 .../numpy/np_broadcast_reduce_op_boolean.cu        |   8 +-
 src/operator/numpy/np_broadcast_reduce_op_index.cu |   4 +-
 src/operator/numpy/np_broadcast_reduce_op_value.cu |  16 +-
 src/operator/numpy/np_constraint_check.h           |   5 +
 src/operator/numpy/np_cross-inl.h                  |  10 +-
 src/operator/numpy/np_elemwise_broadcast_op.h      |   4 +-
 src/operator/numpy/np_kron-inl.h                   |  33 +-
 src/operator/numpy/np_tensordot_op-inl.h           |  18 +-
 src/operator/numpy/np_where_op-inl.h               |  27 +-
 src/operator/numpy/random/dist_common.h            |  80 ++++
 src/operator/numpy/random/np_exponential_op.h      |   7 +-
 src/operator/numpy/random/np_gamma_op.h            |   7 +-
 src/operator/numpy/random/np_location_scale_op.h   |  72 +---
 src/operator/numpy/random/np_normal_op.h           |  74 +---
 src/operator/numpy/random/np_pareto_op.h           |  28 +-
 src/operator/numpy/random/np_rayleigh_op.h         |  28 +-
 src/operator/numpy/random/np_weibull_op.h          |  28 +-
 src/operator/quantization/quantization_utils.h     |   2 +-
 src/operator/quantization/quantize_v2-inl.h        |   7 +
 src/operator/quantization/requantize-inl.h         |  13 +
 src/operator/random/pdf_op.h                       |  34 +-
 src/operator/tensor/broadcast_reduce-inl.cuh       | 414 --------------------
 src/operator/tensor/broadcast_reduce-inl.h         |  31 +-
 .../tensor/broadcast_reduce_minmax_value.cu        |   6 +-
 src/operator/tensor/broadcast_reduce_op.cc         | 187 ++++++++++
 src/operator/tensor/broadcast_reduce_op.h          |  62 ++-
 src/operator/tensor/broadcast_reduce_op_value.cu   |   3 +-
 src/operator/tensor/broadcast_reduce_prod_value.cu |   6 +-
 src/operator/tensor/broadcast_reduce_sum_value.cu  |   9 +-
 .../tensor/elemwise_binary_broadcast_op.cc         |  12 +-
 src/operator/tensor/elemwise_binary_broadcast_op.h |   4 +-
 src/operator/tensor/matrix_op-inl.h                |  10 +
 src/operator/tensor/reduce_rtc.cc                  | 316 ++++++++++------
 tests/python/unittest/test_numpy_op.py             |  10 +-
 53 files changed, 1944 insertions(+), 1907 deletions(-)

diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc
index 2284bee..af4abbe 100644
--- a/src/common/cuda/rtc.cc
+++ b/src/common/cuda/rtc.cc
@@ -150,13 +150,16 @@ CUfunction get_function(const std::string &parameters,
         std::string(fp16_support_string) + "\n" +
         type_support_string + "\n" +
         util_string + "\n" +
+        limits + "\n" +
         special_functions_definitions + '\n' +
         vectorization_support_string + "\n" +
         function_definitions_util + "\n" +
         function_definitions_binary + "\n" +
         function_definitions_unary + "\n" +
         backward_function_definitions + "\n" +
-        reducer + "\n";
+        grad_function_definitions + "\n" +
+        reducer + "\n" +
+        logic_reducer + "\n";
     std::string code_with_header = common_header + parameters + code;
     // If verbose mode, output kernel source, though not including the common header
     if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
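For orientation, the `code` string assembled above (common header, shared source fragments such as `limits` and the reducers, plus the per-kernel `parameters`) is what ultimately gets compiled at runtime. Below is a minimal sketch of the generic NVRTC flow this wrapper builds on; the function name, compile options, and omitted error handling are illustrative and are not MXNet's actual `get_function` (which also caches compiled kernels):

    // Sketch only: compile a runtime-assembled CUDA source string to PTX with
    // NVRTC, then load the named kernel with the driver API. Assumes a CUDA
    // context already exists; all error checking is omitted for brevity.
    #include <nvrtc.h>
    #include <cuda.h>
    #include <string>
    #include <vector>

    CUfunction compile_kernel(const std::string& common_header,
                              const std::string& code,
                              const char* kernel_name) {
      const std::string source = common_header + code;
      nvrtcProgram prog;
      nvrtcCreateProgram(&prog, source.c_str(), "mxnet_rtc.cu", 0, nullptr, nullptr);
      const char* opts[] = {"--std=c++14"};
      nvrtcCompileProgram(prog, 1, opts);
      size_t ptx_size = 0;
      nvrtcGetPTXSize(prog, &ptx_size);
      std::vector<char> ptx(ptx_size);
      nvrtcGetPTX(prog, ptx.data());
      nvrtcDestroyProgram(&prog);
      CUmodule module;
      CUfunction func;
      cuModuleLoadData(&module, ptx.data());
      cuModuleGetFunction(&func, module, kernel_name);
      return func;
    }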
diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index 64ec251..cb1bae8 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -238,6 +238,98 @@ backward_square(const DTypeGrad grad, const DType val) {
 }
 
 template <typename DType, typename DType2>
+__device__ inline DType div_rgrad(const DType val,
+                                  const DType2 val2) {
+  return -val / (val2 * val2);
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_clip(const DTypeGrad grad, const DType val,
+              const float a_min, const float a_max) {
+  if (val > a_max || val < a_min) {
+    return 0;
+  } else {
+    return grad;
+  }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_reciprocal(const DTypeGrad grad, const DType val) {
+  return -grad / (val * val);
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_erf(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  constexpr type my_pi = pi;
+  return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_erfinv(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  constexpr type my_pi = pi;
+  const type g = grad;
+  const type v = val;
+  return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gamma(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  if (type_util::is_same<DTypeGrad, double>::value) {
+    return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
+  } else {
+    return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
+  }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gammaln(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  if (type_util::is_same<DTypeGrad, double>::value) {
+    return grad * op::special_functions::cephes::psi<double>(v);
+  } else {
+    return grad * op::special_functions::cephes::psi<float>(v);
+  }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_digamma(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  if (type_util::is_same<DTypeGrad, double>::value) {
+    return grad * op::special_functions::trigamma<double>(v);
+  } else {
+    return grad * op::special_functions::trigamma<float>(v);
+  }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gelu(const DTypeGrad grad, const DType val) {
+  return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
+                 val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
+}
+
+}  // namespace op
+
+)code";
+
+const char grad_function_definitions[] = R"code(
+namespace op {
+
+template <typename DType, typename DType2>
 __device__ inline mixed_type<DType, DType2>
 rdiv_grad(const DType val,
           const DType2 val2) {
@@ -253,12 +345,6 @@ div_grad(const DType val,
 }
 
 template <typename DType, typename DType2>
-__device__ inline DType div_rgrad(const DType val,
-                                  const DType2 val2) {
-  return -val / (val2 * val2);
-}
-
-template <typename DType, typename DType2>
 __device__ inline DType mod_grad(const DType val,
                                  const DType2 val2) {
   if (type_util::is_integral<DType>::value) {
@@ -368,80 +454,6 @@ rldexp_grad(const DType val,
   return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
 }
 
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_clip(const DTypeGrad grad, const DType val,
-              const float a_min, const float a_max) {
-  if (val > a_max || val < a_min) {
-    return 0;
-  } else {
-    return grad;
-  }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_reciprocal(const DTypeGrad grad, const DType val) {
-  return -grad / (val * val);
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_erf(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
-  return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_erfinv(const DTypeGrad grad, const DType val) {
-  constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
-  const mixed_type<DTypeGrad, DType> g = grad;
-  const mixed_type<DTypeGrad, DType> v = val;
-  return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gamma(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  if (type_util::is_same<DTypeGrad, double>::value) {
-    return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
-  } else {
-    return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
-  }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gammaln(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  if (type_util::is_same<DTypeGrad, double>::value) {
-    return grad * op::special_functions::cephes::psi<double>(v);
-  } else {
-    return grad * op::special_functions::cephes::psi<float>(v);
-  }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_digamma(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  if (type_util::is_same<DTypeGrad, double>::value) {
-    return grad * op::special_functions::trigamma<double>(v);
-  } else {
-    return grad * op::special_functions::trigamma<float>(v);
-  }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gelu(const DTypeGrad grad, const DType val) {
-  return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
-                 val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
-}
-
 template <typename DType, typename DType2>
 __device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
   auto bsq = scalar * scalar;
@@ -467,8 +479,74 @@ __device__ inline DType prelu_grad(const DType val,
   return (val > 0) ? 0 : val;
 }
 
-}  // namespace op
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType2, DType>
+gamma_implicit_grad(const DType a_in, const DType2 x_in) {
+    using OType = mixed_type<DType2, DType>;
+    const OType a = a_in;
+    const OType x = x_in;
+    if (x < 0.8f) {
+      OType numer = 1;
+      OType denom = a;
+      OType series1 = numer / denom;
+      OType series2 = numer / (denom * denom);
+#pragma unroll
+      for (int i = 1; i <= 5; i++) {
+        numer *= -x / static_cast<DType>(i);
+        denom += 1;
+        series1 += numer / denom;
+        series2 += numer / (denom * denom);
+      }
+      OType pow_x_alpha = op::power(x, a);
+      OType gamma_pdf = op::power(x, a - 1) * op::exp(-x);
+      OType gamma_cdf = pow_x_alpha * series1;
+      OType gamma_cdf_alpha =
+          (op::log(x) - OType(special_functions::cephes::psi<float>(a))) *
+              gamma_cdf -
+          pow_x_alpha * series2;
+      OType result = -gamma_cdf_alpha / gamma_pdf;
+      return op::isnan(result) ? 0.f : result;
+    }
+    if (a > 8.0f) {
+      if (0.9f * a <= x && x <= 1.1f * a) {
+        OType numer_1 = 1 + 24 * a * (1 + 12 * a);
+        OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
+                        65 * x * x / a + a * (107 + 3600 * x);
+        OType denom = 1244160 * (a * a) * (a * a);
+        return numer_1 * numer_2 / denom;
+      }
+      OType denom = op::sqrt(8 * a);
+      OType term2 = denom / (a - x);
+      OType term3 =
+          op::power(x - a - a * op::log(x / a), static_cast<OType>(-1.5));
+      OType term23 = (x < a) ? term2 - term3 : term2 + term3;
+      OType term1 = op::log(x / a) * term23 -
+                    op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
+      OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a));
+      OType numer = x * term1;
+      return -stirling * numer / denom;
+    }
+    OType u = op::log(x / a);
+    OType v = op::log(a);
+    OType coef_uv[3][8] = {
+        {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
+         0.10406089, 0.0014179084},
+        {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
+         0.020070113, -0.0035938915, -0.00058392623},
+        {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
+         -0.0021309326, 0.00085092367, -1.5247877e-07},
+    };
+    OType coef_v[8];
+#pragma unroll
+    for (int i = 0; i < 8; i++) {
+      coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
+    }
+    OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
+    OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
+    return op::exp(p / q);
+}
 
+}  // namespace op
 )code";
 
 }  // namespace rtc
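The definitions moved into grad_function_definitions transcribe standard identities; for the two least obvious ones, backward_erf and backward_gelu above correspond to (in LaTeX, taking gelu(x) = 0.5 x (1 + erf(x / sqrt(2))), consistent with backward_gelu's formula):

    \frac{d}{dx}\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}}\, e^{-x^{2}}
    \qquad
    \frac{d}{dx}\operatorname{gelu}(x)
      = \frac{1}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
        + \frac{x}{\sqrt{2\pi}}\, e^{-x^{2}/2}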
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index 9018a5d..4f87db6 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -696,6 +696,10 @@ __device__ inline DType log_sigmoid(const DType val) {
 
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
+  // Avoid overflow of exp for large inputs.
+  // The threshold 20 is chosen such that softrelu(a) = a
+  // for a > 20 using floating precision.
+  if (val > 20) return val;
   if (type_util::has_double_or_integral<DType>::value) {
     return ::log(1 + ::exp(val));
   } else {
@@ -936,6 +940,11 @@ __device__ inline bool_t np_logical_not(const DType val) {
   return !static_cast<bool>(val);
 }
 
+template <typename DType>
+__device__ inline bool_t NonZero(const DType val) {
+  return val != 0;
+}
+
 #undef DEFINE_UNARY_MATH_FUNC
 
 template <typename DType>
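The softrelu early return above is a standard overflow guard: exp(x) overflows float32 well before x reaches 100, while log(1 + exp(x)) is indistinguishable from x in single precision once x exceeds roughly 20. A small host-side illustration (hypothetical helper, not MXNet code):

    // Numerically stable softplus: for x > 20, log(1 + exp(x)) == x to within
    // float precision, so return x directly and avoid overflowing exp().
    #include <cmath>
    #include <cstdio>

    float softplus_stable(float x) {
      if (x > 20.0f) return x;
      return std::log1p(std::exp(x));  // log1p(y) = log(1 + y), accurate for small y
    }

    int main() {
      std::printf("%f %f\n", softplus_stable(1.0f), softplus_stable(100.0f));
      // prints ~1.313262 and 100.000000; the naive form would overflow at 100
      return 0;
    }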
diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h
index 259d0e0..f5b70d8 100644
--- a/src/common/cuda/rtc/reducer-inl.h
+++ b/src/common/cuda/rtc/reducer-inl.h
@@ -27,11 +27,10 @@ namespace common {
 namespace cuda {
 namespace rtc {
 
-const char reducer[] = R"code(
 
+const char reducer[] = R"code(
 namespace red {
 
-/*! \brief sum reducer */
 struct sum {
   /*! \brief do reduction into dst */
   template<typename DType, typename DType2>
@@ -95,103 +94,6 @@ struct sum {
   }
 };
 
-/*! \brief maximum reducer */
-struct maximum {
-  /*! \brief do reduction into dst */
-  template<typename DType, typename DType2>
-  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src) { // NOLINT(*)
-    if (!util::isnan(dst)) {
-      if (!(dst >= src)) dst = src;
-    }
-  }
-  /*! \brief do reduction into dst */
-  template<typename DType, typename DType2>
-  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src,
-                                       volatile DType& none) {
-    Reduce(dst, src);
-  }
-  /*! \brief combine the results of two reducers */
-  template<typename DType>
-  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
-    Reduce(dst_val, src_val);
-  }
-  /*! \brief combine the results of two reducers */
-  template<typename DType>
-  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
-                                      volatile DType& src_val, volatile DType& src_residual) {
-    Reduce(dst_val, src_val);
-  }
-  /*! \brief finalize reduction result */
-  template<typename DType>
-  __device__ inline static void Finalize(volatile DType& dst) {}
-  /*! \brief finalize reduction result */
-  template<typename DType>
-  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
-  /*!
-   *\brief set the initial value during reduction
-   */
-  template<typename DType>
-  __device__ inline static void SetInitValue(DType &initv) {
-    initv = -2*DBL_MAX;
-  }
-  /*!
-   *\brief set the initial value during reduction
-   */
-  template<typename DType>
-  __device__ inline static void SetInitValue(DType &initv, DType &none) {
-    SetInitValue(initv);
-  }
-};
-
-/*! \brief minimum reducer */
-struct minimum {
-  /*! \brief do reduction into dst */
-  template<typename DType, typename DType2>
-  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src) {
-    if (!util::isnan(dst)) {
-      if (!(dst <= src)) dst = src;
-    }
-  }
-  /*! \brief do reduction into dst */
-  template<typename DType, typename DType2>
-  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src,
-                                       volatile DType& none) {
-    Reduce(dst, src);
-  }
-  /*! \brief combine the results of two reducers */
-  template<typename DType>
-  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
-    Reduce(dst_val, src_val);
-  }
-  /*! \brief combine the results of two reducers */
-  template<typename DType>
-  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
-                                      volatile DType& src_val, volatile DType& src_residual) {
-    Reduce(dst_val, src_val);
-  }
-  /*! \brief finalize reduction result */
-  template<typename DType>
-  __device__ inline static void Finalize(volatile DType& dst) {}
-  /*! \brief finalize reduction result */
-  template<typename DType>
-  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
-  /*!
-   *\brief set the initial value during reduction
-   */
-  template<typename DType>
-  __device__ inline static void SetInitValue(DType &initv) {
-    initv = 2*DBL_MAX;
-  }
-  /*!
-   *\brief set the initial value during reduction
-   */
-  template<typename DType>
-  __device__ inline static void SetInitValue(DType &initv, DType &none) {
-    SetInitValue(initv);
-  }
-};
-
-/*! \brief product reducer */
 struct product {
   /*! \brief do reduction into dst */
   template<typename DType, typename DType2>
@@ -237,7 +139,6 @@ struct product {
   }
 };
 
-/*! \brief sum reducer that ignores NaN values in the input */
 struct nansum {
   /*! \brief do reduction into dst */
   template<typename DType, typename DType2>
@@ -293,7 +194,6 @@ struct nansum {
   }
 };
 
-/*! \brief product reducer that ignores NaN values in the input */
 struct nanprod {
   /*! \brief do reduction into dst */
   template<typename DType, typename DType2>
@@ -493,10 +393,222 @@ struct nrmlp {
     scale = 0;
   }
 };
-}  // namespace red
 
+}  // namespace red
 )code";
 
+const char logic_reducer[] = R"code(
+namespace red {
+
+struct maximum {
+  /*! \brief do reduction into dst */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src) { // NOLINT(*)
+    if (!util::isnan(dst)) {
+      if (!(dst >= src)) dst = src;
+    }
+  }
+  /*! \brief do reduction into dst */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src,
+                                       volatile DType& none) {
+    Reduce(dst, src);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+                                      volatile DType& src_val, volatile DType& src_residual) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief finalize reduction result */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction result */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv) {
+    initv = limits::NegInfValue<DType>();
+  }
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &none) {
+    SetInitValue(initv);
+  }
+};
+
+struct minimum {
+  /*! \brief do reduction into dst */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src) {
+    if (!util::isnan(dst)) {
+      if (!(dst <= src)) dst = src;
+    }
+  }
+  /*! \brief do reduction into dst */
+  template<typename DType, typename DType2>
+  __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src,
+                                       volatile DType& none) {
+    Reduce(dst, src);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+                                      volatile DType& src_val, volatile DType& src_residual) {
+    Reduce(dst_val, src_val);
+  }
+  /*! \brief finalize reduction result */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction result */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv) {
+    initv = limits::PosInfValue<DType>();
+  }
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &none) {
+    SetInitValue(initv);
+  }
+};
+
+struct argmax {
+  /*! \brief do reduction into dst */
+  template<typename AType, typename DType>
+  __device__ inline static void Reduce(volatile AType& dst,  volatile DType src) {
+    if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief do stable reduction into dst */
+  template<typename AType, typename DType>
+  __device__ inline static void Reduce(volatile AType& dst,  volatile DType src,
+                                       volatile DType&) {
+    if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
+    if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst, volatile DType&,
+                                      volatile DType& src, volatile DType&) {
+    if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief finalize reduction */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType&) {}
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv) {
+    initv.num = limits::NegInfValue<decltype(initv.num)>();
+  }
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &) {
+    initv.num = limits::NegInfValue<decltype(initv.num)>();
+  }
+};
+
+struct argmin {
+  /*! \brief do reduction into dst */
+  template<typename AType, typename DType>
+  __device__ inline static void Reduce(volatile AType& dst,  volatile DType src) {
+    if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief do stable reduction into dst */
+  template<typename AType, typename DType>
+  __device__ inline static void Reduce(volatile AType& dst,  volatile DType src,
+                                       volatile DType& residual) {
+    if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
+    if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief combine the results of two reducers */
+  template<typename DType>
+  __device__ inline static void Merge(volatile DType& dst, volatile DType&,
+                                      volatile DType& src, volatile DType&) {
+    if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief finalize reduction */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst) {}
+  /*! \brief finalize reduction */
+  template<typename DType>
+  __device__ inline static void Finalize(volatile DType& dst, volatile DType& residual) {}
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv) {
+    initv.num = limits::PosInfValue<decltype(initv.num)>();
+  }
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  __device__ inline static void SetInitValue(DType &initv, DType &residual) {
+    initv.num = limits::PosInfValue<decltype(initv.num)>();
+  }
+};
+}  // namespace red
+)code";
 }  // namespace rtc
 }  // namespace cuda
 }  // namespace common
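All of the reducers above implement the same four-call contract (SetInitValue, Reduce, Merge, Finalize) that the reduction kernels drive. A hedged host-side sketch of that contract for a maximum-style reducer, not the actual GPU loop:

    // Simplified stand-in for red::maximum; the real reducer also leaves dst
    // untouched when dst is already NaN and initializes from
    // limits::NegInfValue<DType>() in the generated code.
    #include <cstdio>
    #include <limits>

    struct maximum_like {
      template <typename DType, typename DType2>
      static void Reduce(DType& dst, DType2 src) {
        if (!(dst >= src)) dst = src;  // a NaN src also replaces dst, as on the GPU
      }
      template <typename DType>
      static void SetInitValue(DType& initv) {
        initv = -std::numeric_limits<DType>::infinity();
      }
      template <typename DType>
      static void Finalize(DType&) {}  // max needs no post-processing
    };

    int main() {
      const float data[] = {1.0f, 7.5f, -2.0f, 3.0f};
      float acc;
      maximum_like::SetInitValue(acc);
      for (float v : data) maximum_like::Reduce(acc, v);
      maximum_like::Finalize(acc);
      std::printf("max = %f\n", acc);  // prints 7.500000
      return 0;
    }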
diff --git a/src/common/cuda/rtc/special_functions-inl.h b/src/common/cuda/rtc/special_functions-inl.h
index d64afb5..7110e47 100644
--- a/src/common/cuda/rtc/special_functions-inl.h
+++ b/src/common/cuda/rtc/special_functions-inl.h
@@ -50,11 +50,6 @@ namespace rtc {
 // Direct inquiries to 30 Frost Street, Cambridge, MA 02140
 //
 const char special_functions_definitions[] = R"code(
-constexpr double DBL_MAX = 1.7976931348623157081e+308;
-constexpr float FLT_MAX = 3.4028234663852885981e+38;
-#define inf ((float)1e50)
-#define nan (inf - inf)
-
 namespace op {
 
 namespace special_functions {
diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h
index b426603..bafa8cf 100644
--- a/src/common/cuda/rtc/util-inl.h
+++ b/src/common/cuda/rtc/util-inl.h
@@ -446,6 +446,144 @@ __device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const int
 
 }  // namespace util
 )code";
+
+const char limits[] = R"code(
+constexpr double DBL_MAX = 1.7976931348623157081e+308;
+constexpr float FLT_MAX = 3.4028234663852885981e+38;
+#define inf ((float)1e50)
+#define nan (inf - inf)
+
+namespace limits {
+
+template<typename DType>
+__device__ inline DType MinValue(void);
+
+template<>
+__device__ inline float MinValue<float>(void) {
+  return -FLT_MAX;
+}
+/*! \brief minimum value of double */
+template<>
+__device__ inline double MinValue<double>(void) {
+  return -DBL_MAX;
+}
+/*! \brief minimum value of uint8 */
+template<>
+__device__ inline uint8 MinValue<uint8>(void) {
+  return 0;
+}
+/*! \brief minimum value of int8_t */
+template<>
+__device__ inline int8 MinValue<int8>(void) {
+  return -128;
+}
+/*! \brief minimum value of int32 */
+template<>
+__device__ inline int32 MinValue<int32>(void) {
+  return -2147483648;
+}
+/*! \brief minimum value of int64_t */
+template<>
+__device__ inline int64 MinValue<int64>(void) {
+  return -9223372036854775808LL;
+}
+/*! \brief minimum value of bool */
+template<>
+__device__ inline bool MinValue<bool>(void) {
+  return false;
+}
+/*! \brief minimum value of bool_t */
+template<>
+__device__ inline bool_t MinValue<bool_t>(void) {
+  return MinValue<index_t>();
+}
+
+/*!
+ * \brief negative infinity of certain types
+ * \tparam DType data type
+ */
+template<typename DType>
+__device__ inline DType NegInfValue(void) {
+  return MinValue<DType>();
+}
+/*! \brief negative infinity value of float */
+template<>
+__device__ inline float NegInfValue<float>(void) {
+  return -inf;
+}
+/*! \brief negative infinity value of double */
+template<>
+__device__ inline double NegInfValue<double>(void) {
+  return -inf;
+}
+
+/*!
+ * \brief maximum value of certain types
+ * \tparam DType data type
+ */
+template<typename DType>
+__device__ inline DType MaxValue(void);
+/*! \brief maximum value of float */
+template<>
+__device__ inline float MaxValue<float>(void) {
+  return FLT_MAX;
+}
+/*! \brief maximum value of double */
+template<>
+__device__ inline double MaxValue<double>(void) {
+  return DBL_MAX;
+}
+/*! \brief maximum value of uint8 */
+template<>
+__device__ inline uint8 MaxValue<uint8>(void) {
+  return 255;
+}
+/*! \brief maximum value of int8 */
+template<>
+__device__ inline int8 MaxValue<int8>(void) {
+  return 127;
+}
+/*! \brief maximum value of int32 */
+template<>
+__device__ inline int32 MaxValue<int32>(void) {
+  return 2147483647;
+}
+/*! \brief maximum value of int64 */
+template<>
+__device__ inline int64 MaxValue<int64>(void) {
+  return 9223372036854775807LL;
+}
+/*! \brief maximum value of bool */
+template<>
+__device__ inline bool MaxValue<bool>(void) {
+  return true;
+}
+/*! \brief maximum value of bool_t */
+template<>
+__device__ inline bool_t MaxValue<bool_t>(void) {
+  return MaxValue<index_t>();
+}
+/*!
+ * \brief positive infinity of certain types
+ * \tparam DType data type
+ */
+template<typename DType>
+__device__ inline DType PosInfValue(void) {
+  return MaxValue<DType>();
+}
+/*! \brief positive infinity value of float */
+template<>
+__device__ inline float PosInfValue<float>(void) {
+  return inf;
+}
+/*! \brief positive infinity value of double */
+template<>
+__device__ inline double PosInfValue<double>(void) {
+  return inf;
+}
+
+}  // namespace limits
+)code";
 }  // namespace rtc
 }  // namespace cuda
 }  // namespace common
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index c33dad4..be49f0f 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -940,7 +940,7 @@ template<>
 MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map<mshadow::half::half_t>
                                                     (mshadow::half::half_t a,
                                                      mshadow::half::half_t b) {
-  return mshadow::half::half_t(-::floorf(static_cast<float>(a/b)));
+  return mshadow::half::half_t(-::floorf(static_cast<float>(a)/static_cast<float>(b)));
 }
 
 struct rmod : public mxnet_op::tunable {
@@ -1573,7 +1573,7 @@ struct argmax {
   /*! \brief do reduction into dst */
   template<typename AType, typename DType>
   MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src) { // NOLINT(*)
-    if (dst.num < src.num) {
+    if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
       dst.num = src.num;
       dst.idx = src.idx;
     }
@@ -1581,7 +1581,7 @@ struct argmax {
   /*! \brief do stable reduction into dst */
   template<typename AType, typename DType>
   MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src, volatile DType& residual) { // NOLINT(*)
-    if (dst.num < src.num) {
+    if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
       dst.num = src.num;
       dst.idx = src.idx;
     }
@@ -1589,7 +1589,7 @@ struct argmax {
   /*! \brief combine the results of two reducers */
   template<typename DType>
   MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
-    if (dst_val.num < src_val.num) {
+    if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
       dst_val.num = src_val.num;
       dst_val.idx = src_val.idx;
     }
@@ -1597,7 +1597,7 @@ struct argmax {
   /*! \brief combine the results of two reducers */
   template<typename DType>
   MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
-    if (dst_val.num < src_val.num) {
+    if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
       dst_val.num = src_val.num;
       dst_val.idx = src_val.idx;
     }
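The argmax/argmin change above adds a tie-break on the index: when two values compare equal, the reducer keeps the smaller index, so a parallel reduction returns the same answer (the first occurrence of the extremum) regardless of the order in which partial results are merged. A small host-side illustration with a hypothetical Entry struct:

    // Demonstrates the tie-breaking rule: equal values resolve to the lower index.
    #include <cstdio>
    #include <limits>

    struct Entry { float num; int idx; };

    void reduce_argmax(Entry& dst, const Entry& src) {
      if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
        dst = src;
      }
    }

    int main() {
      const Entry data[] = {{3.f, 0}, {5.f, 1}, {5.f, 2}, {1.f, 3}};
      Entry best = {-std::numeric_limits<float>::infinity(), 0};
      // Feed the elements in reverse to mimic an arbitrary merge order.
      for (int i = 3; i >= 0; --i) reduce_argmax(best, data[i]);
      std::printf("argmax idx = %d\n", best.idx);  // prints 1, the first maximum
      return 0;
    }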
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
index da30192..0df0db2 100644
--- a/src/operator/nn/group_norm-inl.h
+++ b/src/operator/nn/group_norm-inl.h
@@ -113,24 +113,29 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
 
   Tensor<xpu, 1, char> workspace;
 
-  size_t workspace_size = 0;
-  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
-    workspace_size =
-      broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0],
-                                     red_src_shape, sizeof(DType));
-  });
+  size_t workspace_size = broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0],
+                                                         red_src_shape);
 
   workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
 
   // Calculate mean
+#if !defined(__CUDACC__)
   MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
         s, mean_, req[0], workspace, data_);
-      Tensor<xpu, 1, DType> mean_data_tensor = mean_.FlatTo1D<xpu, DType>(s);
-      mean_data_tensor /= scalar<DType>(channel_size);
     });
   });
+#else
+  BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+    broadcast::RTCReduce(ctx, mean_, req[0], workspace,
+                         data_, "red::sum{}", NDim, "identity");
+  });
+#endif  // !defined(__CUDACC__)
+  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+    Tensor<xpu, 1, DType> mean_data_tensor = mean_.FlatTo1D<xpu, DType>(s);
+    mean_data_tensor /= scalar<DType>(channel_size);
+  });
 
   TBlob data_grp = data.reshape(temp_data_shape);
   const TBlob& mean_grp = mean.reshape(moments_shape);
@@ -150,15 +155,25 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
 
   // Calculate std
   const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
+#if !defined(__CUDACC__)
   MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
         s, std_, req[0], workspace, centered_out);
-      Tensor<xpu, 1, DType> std_data_tensor = std_.FlatTo1D<xpu, DType>(s);
-      std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
-                        + scalar<DType>(param.eps));
     });
   });
+#else
+  BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+    broadcast::RTCReduce(ctx, std_, req[0],
+                         workspace, centered_out,
+                         "red::sum{}", NDim, "square");
+  });
+#endif
+  MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
+    Tensor<xpu, 1, DType> std_data_tensor = std_.FlatTo1D<xpu, DType>(s);
+    std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+                      + scalar<DType>(param.eps));
+  });
 
   // Calculate data = data / std
 #if !defined(__CUDACC__)
@@ -263,26 +278,17 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
 
   // Initialize the workspace + Construct the temporary TBlobs
   Tensor<xpu, 1, char> workspace;
-  size_t reduce_workspace_size = 0;
-  size_t data_size = 0;
-  size_t red_out_size = 0;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    data_size = sizeof(DType) * data.Size();
-    red_out_size = sizeof(DType) * mean.Size();
-    // There are two types of reduction workloads: reduce over axis and reduce exclude axis
-    // We take the maximum of the workspace sizes required by these workloads.
-    // Also, we explicitly set the req_type=kAddto in case we want to use it.
-    reduce_workspace_size =
-      std::max(reduce_workspace_size,
-               broadcast::ReduceWorkspaceSize(s, red_dst_shape,
-                                              kAddTo, red_src_shape,
-                                              sizeof(DType)));
-    reduce_workspace_size =
-      std::max(reduce_workspace_size,
-               broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo,
-                                              red_exclude_src_shape,
-                                              sizeof(DType)));
-  });
+  size_t dtype_size = common::mshadow_type_info(outputs[0].type_flag_).size;
+  size_t data_size = data.Size() * dtype_size;
+  size_t red_out_size = mean.Size() * dtype_size;
+  // There are two types of reduction workloads: reduce over axis and reduce exclude axis
+  // We take the maximum of the workspace sizes required by these workloads.
+  // Also, we explicitly set the req_type=kAddto in case we want to use it.
+  size_t reduce_workspace_size =
+    std::max(broadcast::ReduceWorkspaceSize(s, red_dst_shape,
+                                            kAddTo, red_src_shape),
+             broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo,
+                                            red_exclude_src_shape));
   workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
     Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
   const TBlob normalized_data =
@@ -300,14 +306,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
                                                    {normalized_data, std_},
                                                    {kWriteTo}, {normalized_data});
-#else
-  BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
-                                    {data_, mean_},
-                                    {kWriteTo}, {normalized_data});
-  BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
-                                    {normalized_data, std_},
-                                    {kWriteTo}, {normalized_data});
-#endif  // !defined(__CUDACC__)
   // Calculate grad_beta
   if (req[2] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
@@ -319,13 +317,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
     });
   }
   // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
-#if !defined(__CUDACC__)
   ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
                                                       {kWriteTo}, {ograd_mult});
-#else
-  ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
-                                   {kWriteTo}, {ograd_mult});
-#endif  // !defined(__CUDACC__)
   if (req[1] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
@@ -335,6 +328,32 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
       });
     });
   }
+#else
+  BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+                                    {data_, mean_},
+                                    {kWriteTo}, {normalized_data});
+  BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+                                    {normalized_data, std_},
+                                    {kWriteTo}, {normalized_data});
+  // Calculate grad_beta
+  if (req[2] != kNullOp) {
+    BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape),
+                           req[2], workspace, ograd.reshape(red_exclude_src_shape),
+                           "red::sum{}", NDim, "identity");
+    });
+  }
+  // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
+  ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
+                                   {kWriteTo}, {ograd_mult});
+  if (req[1] != kNullOp) {
+    BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape),
+                           req[1], workspace, ograd_mult.reshape(red_exclude_src_shape),
+                           "red::sum{}", NDim, "identity");
+    });
+  }
+#endif  // !defined(__CUDACC__)
 
   // Calculate grad_data:
   //   ograd_mult = ograd * gamma / std
@@ -350,15 +369,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
     BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
                                                     {ograd_mult, std_},
                                                     {kWriteTo}, {ograd_mult});
-#else
-    BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
-                                      {inputs[0], gamma},
-                                      {kWriteTo},
-                                      {ograd_mult.reshape(data.shape_)});
-    BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
-                                      {ograd_mult, std_},
-                                      {kWriteTo}, {ograd_mult});
-#endif  // !defined(__CUDACC__)
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
         broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
@@ -368,19 +378,11 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
       Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
       red_out_tensor /= scalar<DType>(N);
     });
-#if !defined(__CUDACC__)
     BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
                                                       {ograd_mult, red_out},
                                                       {req[0]}, {output_});
     ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
                                                         {kWriteTo}, {ograd_mult});
-#else
-    BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
-                                      {ograd_mult, red_out},
-                                      {req[0]}, {output_});
-    ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
-                                     {kWriteTo}, {ograd_mult});
-#endif  // !defined(__CUDACC__)
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
         broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
@@ -390,12 +392,39 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
       Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
       red_out_tensor /= scalar<DType>(-N);
     });
-#if !defined(__CUDACC__)
     BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
                                                      {normalized_data, red_out},
                                                      {kAddTo}, {output_});
 #else
     BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
+                                      {inputs[0], gamma},
+                                      {kWriteTo},
+                                      {ograd_mult.reshape(data.shape_)});
+    BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+                                      {ograd_mult, std_},
+                                      {kWriteTo}, {ograd_mult});
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+                           ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+    });
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /= scalar<DType>(N);
+    });
+    BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+                                      {ograd_mult, red_out},
+                                      {req[0]}, {output_});
+    ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
+                                     {kWriteTo}, {ograd_mult});
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+                           ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+    });
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /= scalar<DType>(-N);
+    });
+    BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
                                       {normalized_data, red_out},
                                       {kAddTo}, {output_});
 #endif  // !defined(__CUDACC__)
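The GPU branch above swaps the templated Reduce calls for RTCReduce, but the moments computed in GroupNormCompute are unchanged: a sum reduction with the "identity" mapping for the mean, then a sum reduction with the "square" mapping over the centered output for the variance. A plain CPU reference of those two passes, with illustrative values only:

    // CPU reference for the two reductions in GroupNormCompute:
    // mean = sum(x) / C, std = sqrt(sum((x - mean)^2) / C + eps).
    #include <cmath>
    #include <cstdio>

    int main() {
      const float x[] = {1.f, 2.f, 3.f, 4.f};
      const int C = 4;           // channel_size in the code above
      const float eps = 1e-5f;   // param.eps
      float mean = 0.f;
      for (float v : x) mean += v;
      mean /= C;
      float var = 0.f;
      for (float v : x) var += (v - mean) * (v - mean);
      const float stddev = std::sqrt(var / C + eps);
      std::printf("mean=%f std=%f\n", mean, stddev);  // mean=2.5, std~1.118
      return 0;
    }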
diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index d8c8dbc..79d0906 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -38,6 +38,7 @@
 #include "../operator_common.h"
 #include "../mxnet_op.h"
 #include "../tensor/broadcast_reduce_op.h"
+#include "mxnet/tuple.h"
 
 namespace mxnet {
 namespace op {
@@ -115,14 +116,11 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
   int channel_size = red_src_shape.Size() / red_dst_shape.Size();
   // Initialize the workspace
   Tensor<xpu, 1, char> workspace;
-  size_t workspace_size = 0;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    workspace_size =
-      broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0],
-                                     in_data.shape_, sizeof(DType));
-  });
+  size_t workspace_size = broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0],
+                                                         in_data.shape_);
   workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
 
+#if !defined(__CUDACC__)
   bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
   if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
     common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for float16 inputs for LayerNorm. "
@@ -145,15 +143,9 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
     });
   });
   // Calculate data = data - mean
-#if !defined(__CUDACC__)
   BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
                                                      {inputs[0], outputs[layernorm::kMean]},
                                                      {kWriteTo}, {outputs[0]});
-#else
-  BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
-                                    {inputs[0], outputs[layernorm::kMean]},
-                                    {kWriteTo}, {outputs[0]});
-#endif  // !defined(__CUDACC__)
   // Calculate std
   const TBlob centered_out = outputs[0].reshape(red_src_shape);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
@@ -170,7 +162,6 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
                         + scalar<DType>(param.eps));
     });
   });
-#if !defined(__CUDACC__)
   // Calculate data = data / std
   BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
                                                {outputs[0], outputs[layernorm::kStd]},
@@ -184,6 +175,30 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
                                                 {outputs[0], beta},
                                                 {kWriteTo}, {outputs[0]});
 #else
+  // Calculate mean
+  BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+    broadcast::RTCReduce(ctx, mean_data, req[0], workspace, in_data,
+                         "red::sum{}", NDim, "identity");
+  });
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
+      mean_data_tensor /= scalar<DType>(channel_size);
+  });
+  // Calculate data = data - mean
+  BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+                                    {inputs[0], outputs[layernorm::kMean]},
+                                    {kWriteTo}, {outputs[0]});
+  // Calculate std
+  const TBlob centered_out = outputs[0].reshape(red_src_shape);
+  BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+    broadcast::RTCReduce(ctx, std_data, req[0], workspace, centered_out,
+                         "red::sum{}", NDim, "square");
+  });
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
+    std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+                      + scalar<DType>(param.eps));
+  });
   // Calculate data = data / std
   BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
                                     {outputs[0], outputs[layernorm::kStd]},
@@ -196,7 +211,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
   BinaryBroadcastRTCCompute {"add"}(attrs, ctx,
                                     {outputs[0], beta},
                                     {kWriteTo}, {outputs[0]});
-#endif  // !defined(__CUDACC__)
+#endif
 }
 
 template<typename xpu>
@@ -205,6 +220,26 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs);
 
+template<typename xpu>
+void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs,
+                                     const OpContext& ctx,
+                                     const TBlob& ograd,
+                                     const TBlob& data,
+                                     const TBlob& gamma,
+                                     const TBlob& mean,
+                                     const TBlob& std,
+                                     const TBlob& normalized_data,
+                                     const TBlob& ograd_mult,
+                                     const TBlob& red_out,
+                                     const std::vector<OpReqType>& req,
+                                     const std::vector<TBlob>& outputs,
+                                     const mshadow::Tensor<xpu, 1, char>& workspace,
+                                     const mxnet::TShape& red_dst_shape,
+                                     const mxnet::TShape& red_src_shape,
+                                     const mxnet::TShape& red_exclude_dst_shape,
+                                     const mxnet::TShape& red_exclude_src_shape,
+                                     const int channel_size);
+
 /*
 Calculate the gradient of layer normalization.
 We have the following gradient for gamma, beta and x:
@@ -250,26 +285,17 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
   int channel_size = red_src_shape.Size() / red_dst_shape.Size();
   // Initialize the workspace + Construct the temporary TBlobs
   Tensor<xpu, 1, char> workspace;
-  size_t reduce_workspace_size = 0;
-  size_t data_size = 0;
-  size_t red_out_size = 0;
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    data_size = sizeof(DType) * data.Size();
-    red_out_size = sizeof(DType) * mean.Size();
-    // There are two types of reduction workloads: reduce over axis and reduce exclude axis
-    // We take the maximum of the workspace sizes required by these workloads.
-    // Also, we explicitly set the req_type=kAddto in case we want to use it.
-    reduce_workspace_size =
-      std::max(reduce_workspace_size,
-               broadcast::ReduceWorkspaceSize(s, red_dst_shape,
-                                              kAddTo, red_src_shape,
-                                              sizeof(DType)));
-    reduce_workspace_size =
-      std::max(reduce_workspace_size,
+  size_t dtype_size = common::mshadow_type_info(outputs[0].type_flag_).size;
+  size_t data_size = data.Size() * dtype_size;
+  size_t red_out_size = mean.Size() * dtype_size;
+  // There are two types of reduction workloads: reduce over axis and reduce exclude axis
+  // We take the maximum of the workspace sizes required by these workloads.
+  // Also, we explicitly set the req_type=kAddto in case we want to use it.
+  size_t reduce_workspace_size =
+      std::max(broadcast::ReduceWorkspaceSize(s, red_dst_shape,
+                                              kAddTo, red_src_shape),
                broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo,
-                                              red_exclude_src_shape,
-                                              sizeof(DType)));
-  });
+                                              red_exclude_src_shape));
   workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
     Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
   const TBlob normalized_data = TBlob(workspace.dptr_ + reduce_workspace_size,
@@ -278,135 +304,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
                                  ograd.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id());
   const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
                               mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
-  // Compute normalized_data = (data - mean) / std
-#if !defined(__CUDACC__)
-  BinaryBroadcastCompute<xpu, mshadow_op::minus>(attrs, ctx,
-                                                 {data, mean},
-                                                 {kWriteTo}, {normalized_data});
-  BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
-                                               {normalized_data, std},
-                                               {kWriteTo}, {normalized_data});
-#else
-  BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
-                                    {data, mean},
-                                    {kWriteTo}, {normalized_data});
-  BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
-                                    {normalized_data, std},
-                                    {kWriteTo}, {normalized_data});
-#endif  // !defined(__CUDACC__)
-  // Calculate grad_beta
-  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
-  if (req[2] != kNullOp) {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
-      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        if (!safe_acc) {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
-            s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
-            ograd.reshape(red_exclude_src_shape));
-        } else {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
-            s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
-            ograd.reshape(red_exclude_src_shape));
-        }
-      });
-    });
-  }
-  // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
-#if !defined(__CUDACC__)
-  ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
-                                                      {kWriteTo}, {ograd_mult});
-#else
-  ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
-                                   {kWriteTo}, {ograd_mult});
-#endif  // !defined(__CUDACC__)
-  if (req[1] != kNullOp) {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
-      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        if (!safe_acc) {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
-            s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
-            ograd_mult.reshape(red_exclude_src_shape));
-        } else {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
-            s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
-            ograd_mult.reshape(red_exclude_src_shape));
-        }
-      });
-    });
-  }
-  // Calculate grad_data:
-  //   ograd_mult = ograd * gamma / std
-  //   grad_data = ograd_mult - mean(ograd_mult, axis)
-  //               + normalized_data * (-mean(normalized_data * ograd_mult, axis))
-  if (req[0] != kNullOp) {
-#if !defined(__CUDACC__)
-    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
-                                                    {ograd, gamma},
-                                                    {kWriteTo}, {ograd_mult});
-    BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
-                                                    {ograd_mult, std},
-                                                    {kWriteTo}, {ograd_mult});
-#else
-    BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
-                                      {ograd, gamma},
-                                      {kWriteTo}, {ograd_mult});
-    BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
-                                      {ograd_mult, std},
-                                      {kWriteTo}, {ograd_mult});
-#endif  // !defined(__CUDACC__)
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-        if (!safe_acc) {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
-            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
-            ograd_mult.reshape(red_src_shape));
-        } else {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
-            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
-            ograd_mult.reshape(red_src_shape));
-        }
-      });
-      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
-      red_out_tensor /= scalar<DType>(channel_size);
-    });
-#if !defined(__CUDACC__)
-    BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
-                                                      {ograd_mult, red_out},
-                                                      {req[0]}, {outputs[0]});
-    ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
-                                                        {kWriteTo}, {ograd_mult});
-#else
-    BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
-                                      {ograd_mult, red_out},
-                                      {req[0]}, {outputs[0]});
-    ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
-                                     {kWriteTo}, {ograd_mult});
-#endif  // !defined(__CUDACC__)
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-        if (!safe_acc) {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
-            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
-            ograd_mult.reshape(red_src_shape));
-        } else {
-          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
-            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
-            ograd_mult.reshape(red_src_shape));
-        }
-      });
-      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
-      red_out_tensor /=  scalar<DType>(- channel_size);
-    });
-#if !defined(__CUDACC__)
-    BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
-                                                 {normalized_data, red_out},
-                                                 {kAddTo}, {outputs[0]});
-#else
-    BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
-                                      {normalized_data, red_out},
-                                      {kAddTo}, {outputs[0]});
-#endif  // !defined(__CUDACC__)
-  }
+
+  LayerNormGradComputeGeneralImpl(attrs, ctx, ograd, data, gamma, mean, std, normalized_data,
+                                  ograd_mult, red_out, req, outputs, workspace, red_dst_shape,
+                                  red_src_shape, red_exclude_dst_shape, red_exclude_src_shape,
+                                  channel_size);
 }
 
 }  // namespace op
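
A note on the workspace bookkeeping above: the gradient path makes a single allocation and then carves the reduction scratch area plus the three temporaries (normalized_data, ograd_mult, red_out) out of it by byte offset, reserving only the maximum of the two reduction workspace sizes because one scratch region serves both reduction workloads. Below is a minimal standalone C++ sketch of that layout; the element counts and the two ReduceWorkspaceSize results are made-up stand-ins rather than values computed by MXNet.

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical stand-ins for the quantities computed in LayerNormGradComputeGeneral.
      const std::size_t dtype_size          = 4;          // e.g. float32
      const std::size_t data_elements       = 2 * 3 * 8;  // flattened data.Size()
      const std::size_t mean_elements       = 2 * 3;      // mean.Size()
      const std::size_t ws_reduce_over_axis = 512;        // stand-in for ReduceWorkspaceSize(red_dst, red_src)
      const std::size_t ws_reduce_excl_axis = 768;        // stand-in for ReduceWorkspaceSize(red_exclude_dst, red_exclude_src)

      // One scratch region serves both reduction workloads, so take the maximum.
      const std::size_t reduce_ws    = std::max(ws_reduce_over_axis, ws_reduce_excl_axis);
      const std::size_t data_size    = data_elements * dtype_size;
      const std::size_t red_out_size = mean_elements * dtype_size;

      // Single allocation; the temporaries are views at fixed byte offsets.
      std::vector<char> workspace(reduce_ws + 2 * data_size + red_out_size);
      char* normalized_data = workspace.data() + reduce_ws;
      char* ograd_mult      = workspace.data() + reduce_ws + data_size;
      char* red_out         = workspace.data() + reduce_ws + 2 * data_size;

      std::cout << "total bytes:               " << workspace.size() << "\n"
                << "normalized_data at offset: " << (normalized_data - workspace.data()) << "\n"
                << "ograd_mult at offset:      " << (ograd_mult - workspace.data()) << "\n"
                << "red_out at offset:         " << (red_out - workspace.data()) << "\n";
      return 0;
    }
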
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 0884720..1a040fa 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -268,6 +268,122 @@ void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
   LayerNormComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
 }
 
+template <>
+void LayerNormGradComputeGeneralImpl<cpu>(const nnvm::NodeAttrs& attrs,
+                                          const OpContext& ctx,
+                                          const TBlob& ograd,
+                                          const TBlob& data,
+                                          const TBlob& gamma,
+                                          const TBlob& mean,
+                                          const TBlob& std,
+                                          const TBlob& normalized_data,
+                                          const TBlob& ograd_mult,
+                                          const TBlob& red_out,
+                                          const std::vector<OpReqType>& req,
+                                          const std::vector<TBlob>& outputs,
+                                          const mshadow::Tensor<cpu, 1, char>& workspace,
+                                          const mxnet::TShape& red_dst_shape,
+                                          const mxnet::TShape& red_src_shape,
+                                          const mxnet::TShape& red_exclude_dst_shape,
+                                          const mxnet::TShape& red_exclude_src_shape,
+                                          const int channel_size) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  // Compute normalized_data = (data - mean) / std
+  BinaryBroadcastCompute<cpu, mshadow_op::minus>(attrs, ctx,
+                                                 {data, mean},
+                                                 {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<cpu, mshadow_op::div>(attrs, ctx,
+                                               {normalized_data, std},
+                                               {kWriteTo}, {normalized_data});
+  // Calculate grad_beta
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+  if (req[2] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        if (!safe_acc) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+            ograd.reshape(red_exclude_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+            ograd.reshape(red_exclude_src_shape));
+        }
+      });
+    });
+  }
+  // Calculate grad_gamma; it equals sum(ograd * normalized_data, exclude_axis)
+  ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
+                                                      {kWriteTo}, {ograd_mult});
+  if (req[1] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        if (!safe_acc) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+            ograd_mult.reshape(red_exclude_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+            ograd_mult.reshape(red_exclude_src_shape));
+        }
+      });
+    });
+  }
+  // Calculate grad_data:
+  //   ograd_mult = ograd * gamma / std
+  //   grad_data = ograd_mult - mean(ograd_mult, axis)
+  //               + normalized_data * (-mean(normalized_data * ograd_mult, axis))
+  if (req[0] != kNullOp) {
+    BinaryBroadcastCompute<cpu, op::mshadow_op::mul>(attrs, ctx,
+                                                    {ograd, gamma},
+                                                    {kWriteTo}, {ograd_mult});
+    BinaryBroadcastCompute<cpu, op::mshadow_op::div>(attrs, ctx,
+                                                    {ograd_mult, std},
+                                                    {kWriteTo}, {ograd_mult});
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        if (!safe_acc) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        }
+      });
+      Tensor<cpu, 1, DType> red_out_tensor = red_out.FlatTo1D<cpu, DType>(s);
+      red_out_tensor /= scalar<DType>(channel_size);
+    });
+    BinaryBroadcastCompute<cpu, op::mshadow_op::minus>(attrs, ctx,
+                                                      {ograd_mult, red_out},
+                                                      {req[0]}, {outputs[0]});
+    ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
+                                                        {kWriteTo}, {ograd_mult});
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        if (!safe_acc) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        }
+      });
+      Tensor<cpu, 1, DType> red_out_tensor = red_out.FlatTo1D<cpu, DType>(s);
+      red_out_tensor /= scalar<DType>(-channel_size);
+    });
+    BinaryBroadcastCompute<cpu, mshadow_op::mul>(attrs, ctx,
+                                                 {normalized_data, red_out},
+                                                 {kAddTo}, {outputs[0]});
+  }
+}
+
 template<>
 void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx, const std::vector<TBlob>& inputs,
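
To make the gradient recipe spelled out in the comments above concrete, here is a minimal standalone C++ sketch of the same formulas for a hypothetical (2, 3) input normalized over the last axis. It uses plain arrays instead of TBlob and recomputes mean and std on the host (with an assumed epsilon inside the square root), so it illustrates the math only, not the MXNet kernels or their accumulation order.

    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical example: 2 rows, channel_size = 3, normalization over the last axis.
      const std::size_t B = 2, C = 3;
      const double x[2][3]     = {{1.0, 2.0, 4.0}, {-1.0, 0.5, 2.5}};   // data
      const double ograd[2][3] = {{0.1, -0.2, 0.3}, {0.4, 0.0, -0.1}};  // output gradient
      const double gamma[3]    = {1.0, 0.5, 2.0};
      const double eps = 1e-5;  // assumed epsilon inside the forward std

      std::vector<double> grad_beta(C, 0.0), grad_gamma(C, 0.0);
      double grad_data[2][3];

      for (std::size_t b = 0; b < B; ++b) {
        // Forward statistics of this row (recomputed here; MXNet reuses the saved mean/std).
        double mean = 0.0, var = 0.0;
        for (std::size_t c = 0; c < C; ++c) mean += x[b][c];
        mean /= C;
        for (std::size_t c = 0; c < C; ++c) var += (x[b][c] - mean) * (x[b][c] - mean);
        var /= C;
        const double stddev = std::sqrt(var + eps);

        double normalized[3], ograd_mult[3];
        double mean_om = 0.0, mean_nm_om = 0.0;
        for (std::size_t c = 0; c < C; ++c) {
          // normalized_data = (data - mean) / std
          normalized[c] = (x[b][c] - mean) / stddev;
          // grad_beta  = sum(ograd, exclude_axis)
          // grad_gamma = sum(ograd * normalized_data, exclude_axis)
          grad_beta[c]  += ograd[b][c];
          grad_gamma[c] += ograd[b][c] * normalized[c];
          // ograd_mult = ograd * gamma / std
          ograd_mult[c] = ograd[b][c] * gamma[c] / stddev;
          mean_om    += ograd_mult[c];
          mean_nm_om += normalized[c] * ograd_mult[c];
        }
        mean_om    /= C;  // mean(ograd_mult, axis)
        mean_nm_om /= C;  // mean(normalized_data * ograd_mult, axis)

        // grad_data = ograd_mult - mean(ograd_mult, axis)
        //             + normalized_data * (-mean(normalized_data * ograd_mult, axis))
        for (std::size_t c = 0; c < C; ++c)
          grad_data[b][c] = ograd_mult[c] - mean_om - normalized[c] * mean_nm_om;
      }

      for (std::size_t c = 0; c < C; ++c)
        std::cout << "grad_beta[" << c << "] = " << grad_beta[c]
                  << "  grad_gamma[" << c << "] = " << grad_gamma[c] << "\n";
      for (std::size_t b = 0; b < B; ++b) {
        std::cout << "grad_data[" << b << "] =";
        for (std::size_t c = 0; c < C; ++c) std::cout << " " << grad_data[b][c];
        std::cout << "\n";
      }
      return 0;
    }
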
diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu
index a60df41..9a33e06 100644
--- a/src/operator/nn/layer_norm.cu
+++ b/src/operator/nn/layer_norm.cu
@@ -29,6 +29,89 @@ using namespace mshadow::cuda;
 namespace mxnet {
 namespace op {
 
+template <>
+void LayerNormGradComputeGeneralImpl<gpu>(const nnvm::NodeAttrs& attrs,
+                                          const OpContext& ctx,
+                                          const TBlob& ograd,
+                                          const TBlob& data,
+                                          const TBlob& gamma,
+                                          const TBlob& mean,
+                                          const TBlob& std,
+                                          const TBlob& normalized_data,
+                                          const TBlob& ograd_mult,
+                                          const TBlob& red_out,
+                                          const std::vector<OpReqType>& req,
+                                          const std::vector<TBlob>& outputs,
+                                          const mshadow::Tensor<gpu, 1, char>& workspace,
+                                          const mxnet::TShape& red_dst_shape,
+                                          const mxnet::TShape& red_src_shape,
+                                          const mxnet::TShape& red_exclude_dst_shape,
+                                          const mxnet::TShape& red_exclude_src_shape,
+                                          const int channel_size) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  // Compute normalized_data = (data - mean) / std
+  BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+                                    {data, mean},
+                                    {kWriteTo}, {normalized_data});
+  BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+                                    {normalized_data, std},
+                                    {kWriteTo}, {normalized_data});
+  // Calculate grad_beta
+  if (req[2] != kNullOp) {
+    BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+                           ograd.reshape(red_exclude_src_shape), "red::sum{}", NDim, "identity");
+    });
+  }
+  // Calculate grad_gamma; it equals sum(ograd * normalized_data, exclude_axis)
+  ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
+                                   {kWriteTo}, {ograd_mult});
+  if (req[1] != kNullOp) {
+    BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+                           ograd_mult.reshape(red_exclude_src_shape), "red::sum{}", NDim,
+                           "identity");
+    });
+  }
+  // Calculate grad_data:
+  //   ograd_mult = ograd * gamma / std
+  //   grad_data = ograd_mult - mean(ograd_mult, axis)
+  //               + normalized_data * (-mean(normalized_data * ograd_mult, axis))
+  if (req[0] != kNullOp) {
+    BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
+                                      {ograd, gamma},
+                                      {kWriteTo}, {ograd_mult});
+    BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+                                      {ograd_mult, std},
+                                      {kWriteTo}, {ograd_mult});
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+                             ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+    });
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      Tensor<gpu, 1, DType> red_out_tensor = red_out.FlatTo1D<gpu, DType>(s);
+      red_out_tensor /= scalar<DType>(channel_size);
+    });
+    BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+                                      {ograd_mult, red_out},
+                                      {req[0]}, {outputs[0]});
+    ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
+                                     {kWriteTo}, {ograd_mult});
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+                           ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+    });
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      Tensor<gpu, 1, DType> red_out_tensor = red_out.FlatTo1D<gpu, DType>(s);
+      red_out_tensor /= scalar<DType>(-channel_size);
+    });
+    BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
+                                      {normalized_data, red_out},
+                                      {kAddTo}, {outputs[0]});
+  }
+}
 template <typename DType>
 __device__ __forceinline__ DType warp_shfl(DType value, int src_lane,
                                            int width = 32, unsigned int mask = 0xffffffff) {
diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h
index ca78b65..78c7e4a 100644
--- a/src/operator/nn/moments-inl.h
+++ b/src/operator/nn/moments-inl.h
@@ -126,7 +126,12 @@ inline void MomentsForwardImpl(const OpContext& ctx,
     small = ReduceAxesShapeImpl(inputs[0].shape_, axes, true, false);
   }
 
+#if !defined(__CUDACC__)
   ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(ctx, {data}, {req[0]}, {mean}, small);
+#else
+  ReduceAxesRTCComputeImpl(ctx, {data}, {req[0]}, {mean}, small, "red::sum{}", nullptr, true);
+#endif
+  TBlob temp;
   MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
     Shape<6> data_shape, mean_shape;
     for (int i = 0; i < 6; ++i) {
@@ -137,9 +142,15 @@ inline void MomentsForwardImpl(const OpContext& ctx,
       ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(data.shape_.Size()), s);;
     Kernel<VarBroadcastKernel, xpu>::Launch(s, data.shape_.Size(), temp_data.dptr_,
       data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
-    ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
-      ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small);
+    temp = TBlob(temp_data);
   });
+#if !defined(__CUDACC__)
+  ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+    ctx, {temp.reshape(data.shape_)}, {kWriteTo}, {var}, small);
+#else
+  ReduceAxesRTCComputeImpl(ctx, {temp.reshape(data.shape_)},
+                           {kWriteTo}, {var}, small, "red::sum{}", nullptr, true);
+#endif
 }
 
 template<typename xpu>
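
The moments path above is a two-pass scheme: a normalized sum-reduction produces the mean (assuming the trailing `true` argument of ReduceAxesComputeImpl / ReduceAxesRTCComputeImpl requests normalization), VarBroadcastKernel then writes (data - mean)^2 element-wise, and a second normalized sum-reduction produces the variance. A minimal standalone C++ sketch of the same scheme on a made-up (2, 4) array, reducing over the last axis:

    #include <cstddef>
    #include <iostream>

    int main() {
      // Hypothetical input: shape (2, 4), reduce over the last axis (keepdims).
      const std::size_t R = 2, N = 4;
      const double data[2][4] = {{1.0, 2.0, 3.0, 4.0}, {2.0, 2.0, 8.0, 0.0}};

      double mean[2], var[2], temp[2][4];

      // Pass 1: mean = sum(data, axis) / N, i.e. a normalized sum-reduction.
      for (std::size_t r = 0; r < R; ++r) {
        mean[r] = 0.0;
        for (std::size_t n = 0; n < N; ++n) mean[r] += data[r][n];
        mean[r] /= N;
      }

      // Element-wise: temp = (data - broadcast(mean))^2, the job of VarBroadcastKernel.
      for (std::size_t r = 0; r < R; ++r)
        for (std::size_t n = 0; n < N; ++n) {
          const double d = data[r][n] - mean[r];
          temp[r][n] = d * d;
        }

      // Pass 2: var = sum(temp, axis) / N, another normalized sum-reduction.
      for (std::size_t r = 0; r < R; ++r) {
        var[r] = 0.0;
        for (std::size_t n = 0; n < N; ++n) var[r] += temp[r][n];
        var[r] /= N;
      }

      for (std::size_t r = 0; r < R; ++r)
        std::cout << "row " << r << ": mean = " << mean[r] << ", var = " << var[r] << "\n";
      return 0;
    }
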
diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh
deleted file mode 100644
index d4374ed..0000000
--- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh
+++ /dev/null
@@ -1,415 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015-2020 by Contributors
- * \file broadcast_reduce_customized-inl.cuh
- * \brief Customized CUDA implementations for binary broadcast and reduce
- * \author MXNet contributors
-*/
-#ifndef MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_
-#define MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_
-
-#include "../../tensor/broadcast_reduce-inl.cuh"
-
-using namespace mshadow::cuda;
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP, int unroll>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel_wr(const int N, const int M, const bool addto,
-                                 const DType* __restrict big, OType *small,
-                                 const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
-                                 const Shape<ndim> big_shape, const Shape<ndim> big_stride,
-                                 const int Mnext, const bool do_transpose,
-                                 Reducer* reducer) {
-  extern __shared__ char shTileChar[];
-  AType* shTile = (AType*)(shTileChar);
-  const int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  const int bx = (do_transpose) ? blockDim.y : blockDim.x;
-  const int by = (do_transpose) ? blockDim.x : blockDim.y;
-  const int tidx = (do_transpose) ? tid / by : threadIdx.x;
-  const int tidy = (do_transpose) ? tid % by : threadIdx.y;
-  // bool need_clean = !reducer;
-  // reducer = reducer ? reducer : new Reducer();
-  for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
-    // This TB handles M range [Mstart, ...., Mend - 1]
-    const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
-    const int Mend   = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
-    for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
-      int idx = idx0 + tidx;
-      Shape<ndim> coord = unravel(idx, small_shape);
-      int idx_big0 = ravel(coord, big_shape0);
-
-      AType val, residual;
-      reducer->SetInitValue(val, residual);
-      if (idx < N) {
-        for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
-          int idx_big[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride);
-          }
-          DType tmp[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) {
-              tmp[u] = OP::Map(big[idx_big[u]]);
-            }
-          }
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) reducer->Reduce(val, AType(tmp[u]), residual);
-          }
-        }
-      }
-
-      // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
-      if (by > 1) {
-        // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
-        const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
-        const int it0 = tidx + tidy*fbx;
-        shTile[it0 * 2] = val;
-        shTile[it0 * 2 + 1] = residual;
-        __syncthreads();
-        for (int t=1;t < by;t <<= 1) {
-          AType tmp, tmp_residual;
-          reducer->SetInitValue(tmp, tmp_residual);
-          if (tidy + t < by) {
-            tmp = shTile[(it0 + t*fbx) * 2];
-            tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
-          }
-          __syncthreads();
-          reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
-          __syncthreads();
-        }
-        if (idx < N && tidy == 0) {
-          reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
-          assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2]));
-        }
-      } else {
-        if (idx < N) {
-          reducer->Finalize(val, residual);
-          assign(&small[idx + m0*N], addto, OType(val));
-        }
-      }
-    }
-  }
-  // if (need_clean) {
-  //  delete reducer;
-  // }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2, int unroll>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel_wr(const int N, const int M, const bool addto,
-                                 const DType* __restrict big, const DType* __restrict lhs,
-                                 const DType* __restrict rhs, DType *small,
-                                 const Shape<ndim> big_shape0, const Shape<ndim> lhs_shape0,
-                                 const Shape<ndim> rhs_shape0, const Shape<ndim> small_shape,
-                                 const Shape<ndim> big_shape, const Shape<ndim> lhs_shape,
-                                 const Shape<ndim> rhs_shape, const Shape<ndim> big_stride,
-                                 const Shape<ndim> lhs_stride, const Shape<ndim> rhs_stride,
-                                 const int Mnext, const bool do_transpose,
-                                 Reducer* reducer) {
-  extern __shared__ char shTileChar[];
-  DType* shTile = (DType*)(shTileChar);
-  const int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  const int bx = (do_transpose) ? blockDim.y : blockDim.x;
-  const int by = (do_transpose) ? blockDim.x : blockDim.y;
-  const int tidx = (do_transpose) ? tid / by : threadIdx.x;
-  const int tidy = (do_transpose) ? tid % by : threadIdx.y;
-  // bool need_clean = !reducer;
-  // reducer = reducer ? reducer : new Reducer();
-  for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
-    // This TB handles M range [Mstart, ...., Mend - 1]
-    const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
-    const int Mend   = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
-    for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
-      int idx = idx0 + tidx;
-      Shape<ndim> coord = unravel(idx, small_shape);
-      int idx_big0 = ravel(coord, big_shape0);
-      int idx_lhs0 = ravel(coord, lhs_shape0);
-      int idx_rhs0 = ravel(coord, rhs_shape0);
-
-      DType val, residual;
-      reducer->SetInitValue(val, residual);
-      if (idx < N) {
-        for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
-          int idx_big[unroll];
-          int idx_lhs[unroll];
-          int idx_rhs[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride);
-            idx_lhs[u] = idx_lhs0 + unravel_dot(k + u*by, lhs_shape, lhs_stride);
-            idx_rhs[u] = idx_rhs0 + unravel_dot(k + u*by, rhs_shape, rhs_stride);
-          }
-          DType tmp[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) {
-              tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]]));
-            }
-          }
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) reducer->Reduce(val, tmp[u], residual);
-          }
-        }
-      }
-
-      // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
-      if (by > 1) {
-        // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
-        const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
-        const int it0 = tidx + tidy*fbx;
-        shTile[it0 * 2] = val;
-        shTile[it0 * 2 + 1] = residual;
-        __syncthreads();
-        for (int t=1;t < by;t <<= 1) {
-          DType tmp, tmp_residual;
-          reducer->SetInitValue(tmp, tmp_residual);
-          if (tidy + t < by) {
-            tmp = shTile[(it0 + t*fbx) * 2];
-            tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
-          }
-          __syncthreads();
-          reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
-          __syncthreads();
-        }
-        if (idx < N && tidy == 0) {
-          reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
-          assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
-        }
-      } else {
-        if (idx < N) {
-          reducer->Finalize(val, residual);
-          assign(&small[idx + m0*N], addto, val);
-        }
-      }
-    }
-  }
-  // if (need_clean) {
-  //   delete reducer;
-  // }
-}
-
-// Simple reduction of lines when M is small
-template<typename Reducer, typename DType>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_lines_kernel_wr(const int N, const int M, const bool addto,
-  const int small_in_stride, const DType* __restrict small_in, DType *small_out,
-  Reducer* reducer) {
-  for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-
-    DType val, residual;
-    reducer->SetInitValue(val, residual);
-    for (int k = 0; k < M; k++) {
-      reducer->Reduce(val, small_in[idx + k*small_in_stride], residual);
-    }
-
-    if (idx < N) {
-      reducer->Finalize(val, residual);
-      assign(&small_out[idx], addto, val);
-    }
-
-  }
-}
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1_wr(const int N, const bool addto,
-                                    const DType* __restrict big, OType *small, const Shape<ndim> bshape,
-                                    const Shape<ndim> sshape, Reducer* reducer) {
-  for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-    Shape<ndim> coord = unravel(idx, sshape);
-    int j = ravel(coord, bshape);
-    AType val, residual;
-    reducer->SetInitValue(val, residual);
-    reducer->Reduce(val, AType(OP::Map(big[j])), residual);
-    reducer->Finalize(val, residual);
-    assign(&small[idx], addto, OType(val));
-  }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1_wr(const int N, const bool addto,
-                                    const DType* __restrict big,
-                                    const DType* __restrict lhs,
-                                    const DType* __restrict rhs,
-                                    DType *small,
-                                    const Shape<ndim> big_shape,
-                                    const Shape<ndim> lhs_shape,
-                                    const Shape<ndim> rhs_shape,
-                                    const Shape<ndim> small_shape,
-                                    Reducer* reducer) {
-  for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-    Shape<ndim> coord = unravel(idx, small_shape);
-    int idx_big = ravel(coord, big_shape);
-    int idx_lhs = ravel(coord, lhs_shape);
-    int idx_rhs = ravel(coord, rhs_shape);
-    DType val, residual;
-    reducer->SetInitValue(val, residual);
-    reducer->Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
-    reducer->Finalize(val, residual);
-    assign(&small[idx], addto, val);
-  }
-}
-
-#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
-  if (do_unroll) {                                                    \
-    const int unrollVar = unrollAmount;                               \
-    {__VA_ARGS__}                                                     \
-  } else {                                                            \
-    const int unrollVar = 1;                                          \
-    {__VA_ARGS__}                                                     \
-  }
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
-void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqType req,
-                           const TBlob& big, const Tensor<gpu, 1, char>& workspace,
-                           const ReduceImplConfig& config,
-                           Reducer* reducer = nullptr) {
-  bool need_clean = !reducer;
-  reducer = reducer ? reducer : new Reducer();
-  if (config.M == 1) {
-    reduce_kernel_M1_wr<Reducer, ndim, AType, DType, OType, OP>
-    <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
-      config.N, req == kAddTo, big.dptr<DType>(), small.dptr<OType>(), big.shape_.get<ndim>(),
-      small.shape_.get<ndim>(), reducer);
-    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr);
-  } else {
-    OType* small_dptr = small.dptr<OType>();
-    bool addto = (req == kAddTo);
-    if (config.Mnext > 1) {
-      // small_dptr[] is N*Mnext*sizeof(DType) bytes
-      small_dptr = reinterpret_cast<OType*>(workspace.dptr_);
-      addto = false;
-      // Check that the workspace is contigiuous
-      CHECK_EQ(workspace.CheckContiguous(), true);
-      // Check that we have enough storage
-      CHECK_GE(workspace.size(0), config.workspace_size);
-    }
-
-    const int by = (config.kernel_1.do_transpose) ?
-      config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
-    const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
-    KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
-      reduce_kernel_wr<Reducer, ndim, AType, DType, OType, OP, UNROLL>
-      <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
-        config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
-        small.shape_.get<ndim>(), config.rshape.get<ndim>(), config.rstride.get<ndim>(),
-        config.Mnext, config.kernel_1.do_transpose, reducer);
-    });
-    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr);
-
-    if (config.Mnext > 1) {
-      reduce_lines_kernel_wr<Reducer, OType>
-      <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
-        (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<OType>(), reducer);
-      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr);
-    }
-  }
-  if (need_clean) {
-    delete reducer;
-  }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs,
-                           const OpReqType req, const TBlob& big, const Tensor<gpu, 1, char>& workspace,
-                           const ReduceImplConfig& config, Reducer* reducer = nullptr) {
-  bool need_clean = !reducer;
-  reducer = reducer ? reducer : new Reducer();
-  if (config.M == 1) {
-    reduce_kernel_M1_wr<Reducer, ndim, DType, OP1, OP2>
-    <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
-      config.N, req == kAddTo, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
-      small.dptr<DType>(), big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
-      rhs.shape_.get<ndim>(), small.shape_.get<ndim>());
-    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr);
-  } else {
-    DType* small_dptr = small.dptr<DType>();
-    bool addto = (req == kAddTo);
-    if (config.Mnext > 1) {
-      // small_dptr[] is N*Mnext*sizeof(DType) bytes
-      small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
-      addto = false;
-      // Check that the workspace is contigiuous
-      CHECK_EQ(workspace.CheckContiguous(), true);
-      // Check that we have enough storage
-      CHECK_GE(workspace.size(0), config.workspace_size);
-    }
-
-    const int by = (config.kernel_1.do_transpose) ?
-      config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
-    const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
-    KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
-      reduce_kernel_wr<Reducer, ndim, DType, OP1, OP2, UNROLL>
-      <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
-        config.N, config.M, addto, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
-        small_dptr, big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
-        rhs.shape_.get<ndim>(), small.shape_.get<ndim>(), config.rshape, config.lhs_shape,
-        config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext,
-        config.kernel_1.do_transpose, reducer);
-      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr);
-    });
-
-    if (config.Mnext > 1) {
-      reduce_lines_kernel_wr<Reducer, DType>
-      <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
-        (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>(), reducer);
-      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr);
-    }
-  }
-  if (need_clean) {
-    delete reducer;
-  }
-}
-
-#undef KERNEL_UNROLL_SWITCH
-
-template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
-void ReduceWithReducer(Stream<gpu> *s, const TBlob& small, const OpReqType req,
-                       const Tensor<gpu, 1, char>& workspace, const TBlob& big, Reducer* reducer = nullptr) {
-  if (req == kNullOp) return;
-  cudaStream_t stream = Stream<gpu>::GetStream(s);
-  bool need_clean = !reducer;
-  reducer = reducer ? reducer : new Reducer();
-  ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType));
-  if (safe_acc) {
-    MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
-      typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
-        typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
-        config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, sizeof(AccType));
-        ReduceImplWithReducer<Reducer, ndim, AccType, DataType, OutType, OP>(
-          stream, small, req, big, workspace, config, reducer);
-      });
-    });
-  } else {
-    ReduceImplWithReducer<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config, reducer);
-  }
-  if (need_clean) {
-    delete reducer;
-  }
-}
-
-#endif  // MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_
diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h
index 0226df4..2941d54 100644
--- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h
+++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h
@@ -54,12 +54,6 @@ MSHADOW_XINLINE void seq_reduce_assign_wr(const index_t idx, const size_t M, con
   assign(&small[idx], addto, OType(val));
 }
 
-#ifdef __CUDACC__
-#include "broadcast_reduce_customized-inl.cuh"
-#include "../../tensor/broadcast_reduce-inl.cuh"
-
-#else
-
 template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
 void seq_reduce_compute_wr(const size_t N, const size_t M, const bool addto,
                            const DType *big, OType *small, const Shape<ndim> bshape,
@@ -177,7 +171,6 @@ void ReduceWithReducer(Stream<cpu> *s, const TBlob& small, const OpReqType req,
     reducer);
 }
 
-#endif
 }  // namespace broadcast
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h
index 8e1c0b3..976991f 100644
--- a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h
+++ b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h
@@ -46,19 +46,17 @@ void ReduceAxesComputeImplWithReducer(const OpContext& ctx,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
-      const TBlob in_data = inputs[0].reshape(src_shape);
-      const TBlob out_data = outputs[0].reshape(dst_shape);
-      BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
-        size_t workspace_size = broadcast::ReduceWorkspaceSize(
-            s, out_data.shape_, req[0], in_data.shape_, sizeof(OType));
-        Tensor<xpu, 1, char> workspace =
-            ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-        broadcast::ReduceWithReducer<Reducer, NDim, OType, OP, safe_acc>(
-            s, out_data, req[0], workspace, in_data, reducer);
-        // no normalization
-      });
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
+    const TBlob in_data = inputs[0].reshape(src_shape);
+    const TBlob out_data = outputs[0].reshape(dst_shape);
+    BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
+      size_t workspace_size = broadcast::ReduceWorkspaceSize(
+          s, out_data.shape_, req[0], in_data.shape_);
+      Tensor<xpu, 1, char> workspace =
+          ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+      broadcast::ReduceWithReducer<Reducer, NDim, OType, OP, safe_acc>(
+          s, out_data, req[0], workspace, in_data, reducer);
+      // no normalization
     });
   });
 }
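
The reducer instances threaded through ReduceWithReducer follow the SetInitValue / Reduce / Finalize protocol (plus Merge on the GPU) visible in the removed kernels, where the extra residual slot allows compensated accumulation. Below is a minimal standalone C++ sketch of that protocol with a hypothetical Kahan-style sum reducer; it illustrates the interface only, not the mshadow implementation, and unlike stateful reducers such as nrmlp it carries no parameters.

    #include <iostream>
    #include <vector>

    // Hypothetical reducer following the SetInitValue / Reduce / Finalize protocol.
    struct CompensatedSum {
      void SetInitValue(double& val, double& residual) const { val = 0.0; residual = 0.0; }
      // Kahan-style compensated add: residual tracks the lost low-order bits.
      void Reduce(double& val, double x, double& residual) const {
        const double y = x - residual;
        const double t = val + y;
        residual = (t - val) - y;
        val = t;
      }
      void Finalize(double& /*val*/, double& /*residual*/) const {
        // Nothing to do for a sum; a norm reducer would e.g. take a root here.
      }
    };

    template <typename Reducer>
    double ReduceAll(const std::vector<double>& data, const Reducer& reducer) {
      double val, residual;
      reducer.SetInitValue(val, residual);
      for (double x : data) reducer.Reduce(val, x, residual);
      reducer.Finalize(val, residual);
      return val;
    }

    int main() {
      const std::vector<double> data(1000, 0.1);
      std::cout << "sum = " << ReduceAll(data, CompensatedSum{}) << "\n";
      return 0;
    }
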
diff --git a/src/operator/numpy/linalg/np_norm-inl.h b/src/operator/numpy/linalg/np_norm-inl.h
index b26e680..60dee6a 100644
--- a/src/operator/numpy/linalg/np_norm-inl.h
+++ b/src/operator/numpy/linalg/np_norm-inl.h
@@ -285,18 +285,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs,
     } else if (param.ord == std::numeric_limits<double>::infinity()) {  // inf norm
       LOG(FATAL) << "inf norm handled in front-end.";
     } else {
+#ifndef __CUDACC__
       mshadow_op::nrmlp host_reducer(param.ord);
       mshadow_op::nrmlp *reducer_instance = nullptr;
-#ifdef __CUDACC__
-      Stream<xpu> *s = ctx.get_stream<xpu>();
-      cudaStream_t copy_stream = mshadow::Stream<gpu>::GetStream(s);
-      cudaMalloc(reinterpret_cast<void**>(&reducer_instance), sizeof(mshadow_op::nrmlp));
-      cudaMemcpyAsync(reducer_instance, &host_reducer, sizeof(mshadow_op::nrmlp),
-                      cudaMemcpyHostToDevice, copy_stream);
-      cudaStreamSynchronize(copy_stream);
-#else
       reducer_instance = &host_reducer;
-#endif
       if (safe_acc) {
         ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrmlp, true, mshadow_op::abs>(
           ctx, inputs, req, outputs, small, reducer_instance);
@@ -304,8 +296,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs,
         ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrmlp, false, mshadow_op::abs>(
           ctx, inputs, req, outputs, small, reducer_instance);
       }
-#ifdef __CUDACC__
-      cudaFree(reducer_instance);
+#else
+      ReduceAxesRTCComputeImpl(
+        ctx, inputs, req, outputs, small, "red::nrmlp{" + std::to_string(param.ord) + "}",
+        nullptr, false, "abs");
 #endif
     }
   }
@@ -443,8 +437,13 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
   }
 
   if (param.flag == 1) {  // Frobenius norm
-    ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
+#if !defined(__CUDACC__)
+    ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, false, mshadow_op::identity>(
       ctx, inputs, req, outputs, reduced_shape);
+#else
+    ReduceAxesRTCComputeImpl(
+      ctx, inputs, req, outputs, reduced_shape, "red::nrm2{}", nullptr, false, "identity");
+#endif
     return;
   }
 
@@ -453,19 +452,29 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
   if (param.ord != 2 && param.ord != -2) {  // row norm or col norm
     TShape sum_shape = inputs[0].shape_;
     sum_shape[mat_axis[!(param.ord == 1 || param.ord == -1)]] = 1;
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      TBlob temp = outputs[1].reshape(sum_shape);
-      std::vector<TBlob> sum_output({temp});
-      ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::abs>(
-        ctx, inputs, req, sum_output, sum_shape);
-      if (param.ord > 0) {
-        ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::identity>(
-          ctx, sum_output, req, outputs, reduced_shape);
-      } else {
-        ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::identity>(
-          ctx, sum_output, req, outputs, reduced_shape);
-      }
-    });
+    TBlob temp = outputs[1].reshape(sum_shape);
+    std::vector<TBlob> sum_output({temp});
+#if !defined(__CUDACC__)
+    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::abs>(
+      ctx, inputs, req, sum_output, sum_shape);
+    if (param.ord > 0) {
+      ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::identity>(
+        ctx, sum_output, req, outputs, reduced_shape);
+    } else {
+      ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::identity>(
+        ctx, sum_output, req, outputs, reduced_shape);
+    }
+#else
+    ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape,
+                             "red::sum{}", nullptr, false, "abs");
+    if (param.ord > 0) {
+      ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape,
+                               "red::maximum{}", nullptr, false);
+    } else {
+      ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape,
+                               "red::minimum{}", nullptr, false);
+    }
+#endif
     return;
   }
 
@@ -500,6 +509,7 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
     L_trans[mat_axis[1]] = 1;
   }
 
+  std::vector<TBlob> eigen;
   MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     Tensor<xpu, 3, DType> UT =
       outputs[1].get_with_shape<xpu, 3, DType>(Shape3(batch_dim, row_dim, row_dim), s);
@@ -523,32 +533,46 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 3, DType> svd_input =
       workspace.get_with_shape<xpu, 3, DType>(Shape3(batch_dim, row_dim, col_dim), s);
     gesvd::op(svd_input, UT, L, V, ctx, attrs, &svd_workspace);
-
     TBlob workspace0(reinterpret_cast<DType*>(temp.dptr_), L_trans,
                      temp.dev_mask(), temp.dev_id());
     TransposeImpl<xpu>(ctx.run_ctx, TBlob(L).reshape(L_shape), workspace0, reduce_axes);
-    std::vector<TBlob> eigen({ workspace0 });
-    if (param.flag == 2) {  // nuclear norm
-      ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::identity>(
+    eigen.emplace_back(workspace0);
+  });
+
+#if !defined(__CUDACC__)
+  if (param.flag == 2) {  // nuclear norm
+    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::identity>(
+      ctx, eigen, req, outputs, reduced_shape);
+  } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) {
+    if (ord == 2) {
+      ReduceAxesComputeImpl<xpu, mshadow::red::maximum, true, false, mshadow_op::abs>(
+        ctx, eigen, req, outputs, reduced_shape);
+    } else if (ord == -2) {
+      ReduceAxesComputeImpl<xpu, mshadow::red::minimum, true, false, mshadow_op::abs>(
         ctx, eigen, req, outputs, reduced_shape);
-    } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) {
-      if (ord == 2) {
-        ReduceAxesComputeImpl<xpu, mshadow::red::maximum, true, false, mshadow_op::abs>(
-          ctx, eigen, req, outputs, reduced_shape);
-      } else if (ord == -2) {
-        ReduceAxesComputeImpl<xpu, mshadow::red::minimum, true, false, mshadow_op::abs>(
-          ctx, eigen, req, outputs, reduced_shape);
-      }
-    } else {
-      if (ord == 2) {
-        ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::abs>(
-          ctx, eigen, req, outputs, reduced_shape);
-      } else if (ord == -2) {
-        ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::abs>(
-          ctx, eigen, req, outputs, reduced_shape);
-      }
     }
-  });
+  } else {
+    if (ord == 2) {
+      ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::abs>(
+        ctx, eigen, req, outputs, reduced_shape);
+    } else if (ord == -2) {
+      ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::abs>(
+        ctx, eigen, req, outputs, reduced_shape);
+    }
+  }
+#else
+  if (param.flag == 2) {  // nuclear norm
+    ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, "red::sum{}", nullptr, false);
+  } else {
+    if (ord == 2) {
+      ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape,
+                               "red::maximum{}", nullptr, false, "abs");
+    } else if (ord == -2) {
+      ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape,
+                               "red::minimum{}", nullptr, false, "abs");
+    }
+  }
+#endif
 }
 
 template<typename xpu>
@@ -784,8 +808,13 @@ void NumpyNormComputeForward(const nnvm::NodeAttrs& attrs,
     std::vector<TBlob> flat_outputs({
       outputs[0].reshape(TShape(1, 1))
     });
-    ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
+#if !defined(__CUDACC__)
+    ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, false, mshadow_op::identity>(
       ctx, flat_inputs, req, flat_outputs, TShape(1, 1));
+#else
+    ReduceAxesRTCComputeImpl(
+      ctx, flat_inputs, req, flat_outputs, TShape(1, 1), "red::nrm2{}", nullptr, false, "identity");
+#endif
     return;
   }
 
diff --git a/src/operator/numpy/np_broadcast_reduce_op.cc b/src/operator/numpy/np_broadcast_reduce_op.cc
new file mode 100644
index 0000000..4b64a1a
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file np_broadcast_reduce_op.cc
+ * \brief Function definitions of NumPy-compatible
+ *        broadcast and reduce operators
+ */
+
+#include "np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+#if MXNET_USE_CUDA
+
+void NumpyArgMinMaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+                                          const OpContext& ctx,
+                                          const std::vector<TBlob>& inputs,
+                                          const std::vector<OpReqType>& req,
+                                          const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  if (req[0] == kNullOp) return;
+  // parse param
+  const auto& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  TBlob out = outputs[0];
+  TBlob in = inputs[0];
+  // do some shape checks
+  if (in.shape_.ndim() != 0) {
+    if (param.axis.has_value()) {
+      // cannot do argmax in an empty dimension
+      int axis = param.axis.value();
+      axis = CheckAxis(axis, in.shape_.ndim());
+      CHECK_NE(in.shape_[axis], 0)
+          << "searching input tensor of shape " << inputs[0].shape_
+          << " along axis = " << axis << " of zero dim-size is not allowed";
+    } else {
+      // cannot do argmax on an empty array
+      CHECK_NE(in.shape_.Size(), 0U) << "attempt to search an empty sequence";
+    }
+  }
+  if (in.shape_.Size() == 0U) return;  // zero-size tensor
+  // prepare shape
+  dmlc::optional<mxnet::Tuple<int>> axes;
+  if (param.axis.has_value()) {
+    mxnet::Tuple<int> t({param.axis.value()});
+    axes = dmlc::optional<mxnet::Tuple<int>>(t);
+  }
+  TShape small;
+  small = NumpyReduceAxesShapeImpl(in.shape_, axes, true);
+  mxnet::TShape src_shape, dst_shape;
+  BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape);
+  const TBlob in_data = in.reshape(src_shape);
+  // request a work space
+  size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape);
+  Tensor<gpu, 1, char> workspace =
+            ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_size), s);
+  BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
+      broadcast::RTCReduce(ctx, outputs[0].reshape(dst_shape), req[0], workspace, in_data,
+                           reducer, NDim, "identity", true);
+  });
+}
+
+#endif  // MXNET_USE_CUDA
+
+}  // namespace op
+}  // namespace mxnet
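
NumpyArgMinMaxRTCCompute reduces values tagged with the index they came from, so the position of the extremum survives an ordinary reduction; that is the role mshadow_op::IndexedNum and arg_min_max_set_index play in the CPU path and played in the removed .cuh path. A minimal standalone C++ sketch of the idea, doing a row-wise argmax on a made-up (2, 4) array; IndexedValue is a hypothetical stand-in, not the MXNet type.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>

    // A value tagged with its index: reducing pairs instead of raw values lets
    // the winning index fall out of an ordinary max-reduction.
    struct IndexedValue {
      double value;
      int64_t index;
    };

    IndexedValue reduce_max(IndexedValue a, IndexedValue b) {
      return (b.value > a.value) ? b : a;  // strict '>' keeps the first occurrence on ties
    }

    int main() {
      const std::size_t R = 2, N = 4;
      const double data[2][4] = {{1.0, 7.0, 3.0, 7.0}, {-2.0, -5.0, -1.0, -3.0}};

      for (std::size_t r = 0; r < R; ++r) {
        IndexedValue best{data[r][0], 0};
        for (std::size_t n = 1; n < N; ++n)
          best = reduce_max(best, IndexedValue{data[r][n], static_cast<int64_t>(n)});
        std::cout << "argmax(row " << r << ") = " << best.index << "\n";
      }
      return 0;
    }
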
diff --git a/src/operator/numpy/np_broadcast_reduce_op.cuh b/src/operator/numpy/np_broadcast_reduce_op.cuh
deleted file mode 100644
index f97aa78..0000000
--- a/src/operator/numpy/np_broadcast_reduce_op.cuh
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015-2020 by Contributors
- * \file np_broadcast_reduce-inl.cuh
- * \brief GPU implementations for numpy binary broadcast ops
- * \author Zhaoqi Zhu
-*/
-#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
-#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
-
-using namespace mshadow::cuda;
-using namespace mshadow;
-using namespace broadcast;
-
-template<typename Reducer, int NDim, typename DType, typename OType>
-void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,
-                          const Tensor<gpu, 1, char>& workspace) {
-  cudaStream_t stream = Stream<gpu>::GetStream(s);
-  ReduceImplConfig config(out_data.shape_, in_data.shape_, nullptr, nullptr, sizeof(OType));
-
-  ReduceImpl<Reducer, NDim, OType, DType, OType, mxnet::op::mshadow_op::identity,
-             mxnet::op::mshadow_op::arg_min_max_set_index<OType, int>>
-            (stream, out_data, kWriteTo, in_data, workspace, config);
-}
-
-#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 80714fb6..9ce3967 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -298,7 +298,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
                             const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   if (req[0] == kNullOp) return;
-  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  const auto& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
   if (param.initial.has_value()) {
     LOG(FATAL) << "initial is not supported yet";
   }
@@ -494,10 +494,6 @@ void NumpyArgMinMaxReduce(mshadow::Stream<cpu> *s, const TBlob& in_data, const T
     in_data.shape_.get<NDim>(), out_data.shape_.get<NDim>(), rshape, rstride);
 }
 
-#ifdef __CUDACC__
-#include "np_broadcast_reduce_op.cuh"
-#endif
-
 template<typename Reducer, typename xpu, typename IType>
 void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
@@ -508,7 +504,7 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   if (req[0] == kNullOp) return;
   // parse param
-  const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
+  const auto& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   TBlob out = outputs[0];
   TBlob in = inputs[0];
@@ -537,36 +533,52 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs,
   small = NumpyReduceAxesShapeImpl(in.shape_, axes, true);
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape);
+  const TBlob in_data = in.reshape(src_shape);
+  // request a work space
+  size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape);
   MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, DType, {
     // define OType
     typedef mxnet::op::mshadow_op::IndexedNum<IType, DType> OType;
-    // request a work space
-    size_t workspace_size = sizeof(OType) * out.shape_.Size();
-    Tensor<xpu, 1, char> workspace =
-              ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-    // set up intermediate output
-    TBlob intermediate = out;
-    intermediate.dptr_ = reinterpret_cast<int64_t*>(workspace.dptr_);
-    // reshape the input and intermediate output tensor
-    const TBlob in_data = in.reshape(src_shape);
-    const TBlob intermediate_out_data = intermediate.reshape(dst_shape);
     // switch dim
     BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
-      size_t workspace_size = broadcast::ReduceWorkspaceSize(
-        s, intermediate_out_data.shape_, req[0], in_data.shape_, sizeof(OType));
+      constexpr size_t align_size = 1024;
+      const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size)
+                                                  * align_size;
+      workspace_size = aligned_first_workspace_size +
+                       sizeof(OType) * out.shape_.Size();
       Tensor<xpu, 1, char> workspace =
-        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+                ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+      // set up intermediate output
+      TBlob intermediate = out;
+      intermediate.dptr_ = reinterpret_cast<int64_t*>(workspace.dptr_ +
+                                                      aligned_first_workspace_size);
+      // reshape the input and intermediate output tensor
+      const TBlob intermediate_out_data = intermediate.reshape(dst_shape);
       NumpyArgMinMaxReduce<Reducer, NDim, DType, OType>(s, in_data,
         intermediate_out_data, workspace);
+      // parse the indices from the intermediate tensor back to the actual output tensor
+      using namespace mxnet_op;
+      Kernel<arg_min_max_parse, xpu>::Launch(
+          s, out.shape_.Size(), outputs[0].dptr<int64_t>(),
+          static_cast<OType*>(intermediate_out_data.dptr_));
     });
-    // parse the indices from the intermediate tensor back to the actual output tensor
-    using namespace mxnet_op;
-    Kernel<arg_min_max_parse, xpu>::Launch(
-        s, out.shape_.Size(), outputs[0].dptr<int64_t>(),
-        static_cast<OType*>(intermediate_out_data.dptr_));
   });
 }
 
+#if MXNET_USE_CUDA
+
+struct NumpyArgMinMaxRTCCompute {
+  std::string reducer;
+
+  void operator()(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs);
+};
+
+#endif
+
 template<typename xpu, bool normalize = false>
 inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                            const OpContext& ctx,
@@ -663,36 +675,6 @@ struct NumpyMomentsParam : public dmlc::Parameter<NumpyMomentsParam> {
   }
 };
 
-template<typename xpu, typename reducer, bool safe_acc, bool normalize = false,
-         typename OP = op::mshadow_op::identity>
-void ReduceAxesComputeWithWorkspaceImpl(const OpContext& ctx,
-                                        const std::vector<TBlob>& inputs,
-                                        const std::vector<OpReqType>& req,
-                                        const std::vector<TBlob>& outputs,
-                                        const mshadow::Tensor<xpu, 1, char>& workspace,
-                                        const mxnet::TShape& src_shape,
-                                        const mxnet::TShape& dst_shape,
-                                        const int ddof = 0) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-      const TBlob in_data = inputs[0].reshape(src_shape);
-      const TBlob out_data = outputs[0].reshape(dst_shape);
-      BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
-        broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
-            s, out_data, req[0], workspace, in_data);
-        if (normalize) {
-          auto out = out_data.FlatTo2D<xpu, OType>(s);
-          out /= scalar<OType>(src_shape.Size()/dst_shape.Size() - ddof);
-        }
-      });
-    });
-  });
-}
-
 struct NumpyWeightedAverageParam : public dmlc::Parameter<NumpyWeightedAverageParam> {
   dmlc::optional<mxnet::Tuple<int>> axis;
   bool returned;
@@ -871,13 +853,6 @@ struct avg_grad_w_1D_kernel {
   }
 };
 
-// Windows has issues with #ifdefs inside MSHADOW_TYPE_SWITCH
-#ifndef __CUDACC__
-#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastCompute<xpu, mshadow_op::OP>
-#else
-#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastRTCCompute {#OP}
-#endif
-
 template<typename xpu, bool back = false>
 void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
                                      const OpContext& ctx,
@@ -914,6 +889,9 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
     weights = weights.reshape(new_w_shape);
     small2 = TShape(new_w_shape.ndim(), 1);
   }
+  TBlob wa;
+  TBlob sum_of_wa;
+  Tensor<xpu, 1, char> workspace;
   MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
     // Get temp space
     size_t temp_data_size = data.shape_.Size() * sizeof(DType);
@@ -922,38 +900,53 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
     BroadcastReduceShapeCompact(data.shape_, small1, &src_shape, &dst_shape);
     size_t workspace_size = 0;
     workspace_size = broadcast::ReduceWorkspaceSize(
-      s, dst_shape, {kWriteTo}, src_shape, sizeof(DType));
+      s, dst_shape, {kWriteTo}, src_shape);
     size_t temp_mem_size = temp_data_size + temp_sum_size + workspace_size;
     Tensor<xpu, 1, char> temp_mem =
     ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
-    DType *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
-    DType *temp_sum_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + temp_data_size);
+    auto *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
+    auto *temp_sum_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + temp_data_size);
     char *workspace_ptr = temp_mem.dptr_ + temp_data_size + temp_sum_size;
-    Tensor<xpu, 1, char> workspace(workspace_ptr, Shape1(workspace_size), s);
+    workspace = Tensor<xpu, 1, char>(workspace_ptr, Shape1(workspace_size), s);
 
     // Compute weighted data
-    TBlob wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask);
-    NP_BROADCAST_REDUCE_OP_BROADCAST(mul)(
-      attrs, ctx, {data, weights}, {kWriteTo}, {wa});
-
-    // Compute sum of weighted data
-    TBlob sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask);
-    ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true>(
-      ctx, {wa}, {kWriteTo}, {sum_of_wa}, workspace, src_shape, dst_shape);
-    if (!back) {
-      const TBlob& avg = outputs[0];
-      const TBlob& sum_of_weights = outputs[1];
-      TShape w_src_shape, w_dst_shape;
-      BroadcastReduceShapeCompact(weights.shape_, small2, &w_src_shape, &w_dst_shape);
-      // Compute sum of weight
-      TBlob scl = sum_of_weights.reshape(small2);
-      ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true>(
-        ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape);
-
-      // Compute avg and assign output
-      NP_BROADCAST_REDUCE_OP_BROADCAST(div)(
-        attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)});
-    } else {
+    wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask);
+    sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask);
+  });
+#if !defined(__CUDACC__)
+  BinaryBroadcastCompute<xpu, mshadow_op::mul>(
+    attrs, ctx, {data, weights}, {kWriteTo}, {wa});
+
+  // Compute sum of weighted data
+  ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
+    ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, &workspace);
+#else
+  BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {data, weights}, {kWriteTo}, {wa});
+
+  // Compute sum of weighted data
+  ReduceAxesRTCComputeImpl(ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, "red::sum{}",
+                           &workspace, false, "identity");
+#endif
+  if (!back) {
+    const TBlob& avg = outputs[0];
+    const TBlob& sum_of_weights = outputs[1];
+    // Compute sum of weight
+    TBlob scl = sum_of_weights.reshape(small2);
+#if !defined(__CUDACC__)
+    ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
+      ctx, {weights}, {kWriteTo}, {scl}, small2, &workspace);
+    // Compute avg and assign output
+    BinaryBroadcastCompute<xpu, mshadow_op::div>(
+      attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)});
+#else
+    ReduceAxesRTCComputeImpl(ctx, {weights}, {kWriteTo}, {scl}, small2, "red::sum{}",
+                             &workspace, false, "identity");
+    // Compute avg and assign output
+    BinaryBroadcastRTCCompute {"div"}(
+      attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)});
+#endif
+  } else {
+    MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
       // Compute and assign the derivatives of a and weights
       const TBlob& igrad_a = outputs[0];
       const TBlob& igrad_w = outputs[1];
@@ -992,12 +985,10 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
           }
         });
       })
-    }
-  });
+    });
+  }
 }
 
-#undef NP_BROADCAST_REDUCE_OP_BROADCAST
-
 template<typename xpu>
 void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
@@ -1010,21 +1001,29 @@ void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs,
   CHECK_NE(req[0], kWriteInplace) << "Average does not support write in-place";
   const auto& param = nnvm::get<NumpyWeightedAverageParam>(attrs.parsed);
   const TBlob& data = inputs[0];
+  TShape small;
   MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
     if (!param.weighted) {
-      TShape small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true);
+      small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true);
       // Compute sum of weights which equals to the product of sizes of reduced axes
       Stream<xpu>* s = ctx.get_stream<xpu>();
       auto ret = outputs[1].FlatTo1D<xpu, DType>(s);
       ret = scalar<DType>(data.shape_.Size()/small.Size());
-      // Compute mean
-      ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
-        ctx, inputs, req, {outputs[0]}, small);
-    } else {
-      NumpyWeightedAverageComputeImpl<xpu>(
-        attrs, ctx, inputs, req, outputs, param.axis);
     }
   });
+  if (!param.weighted) {
+    // Compute mean
+#if !defined(__CUDACC__)
+    ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+      ctx, inputs, req, {outputs[0]}, small);
+#else
+    ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0]}, small,
+                             "red::sum{}", nullptr, true);
+#endif
+  } else {
+    NumpyWeightedAverageComputeImpl<xpu>(
+      attrs, ctx, inputs, req, outputs, param.axis);
+  }
 }
 
 template<typename xpu>
@@ -1090,40 +1089,58 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(data.shape_, small, &src_shape, &dst_shape);
 
+  // Get workspace and temp space for data - mean
+  size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape);
+  size_t temp_data_size = data.shape_.Size() * common::mshadow_type_info(inputs[0].type_flag_).size;
+  size_t temp_mem_size = temp_data_size + workspace_size;
+  Tensor<xpu, 1, char> temp_mem =
+    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
+  char *workspace_ptr = temp_mem.dptr_ + temp_data_size;
+  Tensor<xpu, 1, char> workspace(workspace_ptr, Shape1(workspace_size), s);
+  // Compute mean
+#if !defined(__CUDACC__)
+  ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+    ctx, inputs, {kWriteTo}, {mean}, small, &workspace);
+#else
+  ReduceAxesRTCComputeImpl(ctx, inputs, {kWriteTo}, {mean}, small, "red::sum{}",
+                           &workspace, true, "identity");
+#endif
+  // Compute data - mean
+  Shape<6> data_shape, mean_shape;
+  for (int i = 0; i < 6; ++i) {
+    data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1;
+    mean_shape[i] = (i < small.ndim()) ? small[i] : 1;
+  }
+#if !defined(__CUDACC__)
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-      // Get workspace and temp space for data - mean
-      size_t workspace_size = 0;
-      workspace_size = broadcast::ReduceWorkspaceSize(
-        s, dst_shape, req[0], src_shape, sizeof(DType));
-      size_t temp_data_size = data.shape_.Size() * sizeof(DType);
-      size_t temp_mem_size = temp_data_size + workspace_size;
-      Tensor<xpu, 1, char> temp_mem =
-        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
       DType *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
-      char *workspace_ptr = temp_mem.dptr_ + temp_data_size;
-      Tensor<xpu, 1, char> workspace(workspace_ptr, Shape1(workspace_size), s);
-      // Compute mean
-      ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true, true>(
-        ctx, inputs, {kWriteTo}, {mean}, workspace, src_shape, dst_shape);
-      // Compute data - mean
-      Shape<6> data_shape, mean_shape;
-      for (int i = 0; i < 6; ++i) {
-        data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1;
-        mean_shape[i] = (i < small.ndim()) ? small[i] : 1;
-      }
       Kernel<VarBroadcastKernel, xpu>::Launch(s, data_shape.Size(), temp_data_ptr,
         data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
       Tensor<xpu, 1, DType> temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s);
       TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_);
-      ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true, true>(
-        ctx, {temp_data_blob}, {req[0]}, {moment}, workspace, src_shape, dst_shape, param.ddof);
-      if (sqrt) {
+      ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+        ctx, {temp_data_blob}, {req[0]}, {moment}, small, &workspace, param.ddof);
+      if (sqrt && req[0] != kNullOp) {
         Tensor<xpu, 1, OType> moment_tensor = moment.FlatTo1D<xpu, OType>(s);
         moment_tensor = F<mshadow_op::square_root>(moment_tensor);
       }
     });
   });
+#else
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      DType *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
+      Kernel<VarBroadcastKernel, xpu>::Launch(s, data_shape.Size(), temp_data_ptr,
+        data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
+      Tensor<xpu, 1, DType> temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s);
+      TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_);
+      ReduceAxesRTCComputeImpl(ctx, {temp_data_blob}, {req[0]}, {moment}, small,
+                               "red::sum{}", &workspace, true, "identity", param.ddof);
+      if (sqrt && req[0] != kNullOp) {
+        UnaryRTCCompute {"sqrt"}({}, ctx, {moment}, {kWriteInplace}, {moment});
+      }
+  });
+#endif
 }
 
 template<typename xpu>
@@ -1159,6 +1176,7 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
   for (int i = 0; i < igrad_shape.ndim(); ++i) {
     expanded_igrad_shape[i + ndim_delta] = igrad_shape[i];
   }
+#if !defined(__CUDACC__)
   if (NeedSafeAcc<true>(inputs[0].type_flag_, outputs[0].type_flag_)) {
     ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
         ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
@@ -1166,6 +1184,10 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
     ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(
         ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
   }
+#else
+  ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)},
+                           expanded_igrad_shape, "red::sum{}", nullptr, false);
+#endif
 }
 
 template<typename xpu, typename OP>
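
NumpyArgMinMaxCompute above now carves two regions out of a single requested buffer:
the reduction workspace first, then the IndexedNum intermediate, with the second region
starting at a 1024-byte-aligned offset. A condensed sketch of that layout arithmetic
(OType stands for mshadow_op::IndexedNum<IType, DType> as in the hunk; the variable
names are illustrative):

    // Single allocation: [reduce workspace | padding | OType intermediate].
    constexpr size_t align_size = 1024;  // alignment used by the patch
    size_t reduce_ws  = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape);
    size_t aligned_ws = ((reduce_ws + align_size - 1) / align_size) * align_size;  // round up
    size_t total      = aligned_ws + sizeof(OType) * out.shape_.Size();
    Tensor<xpu, 1, char> buf =
        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(total), s);
    // The indexed intermediate lives right after the aligned reduction workspace.
    OType* intermediate = reinterpret_cast<OType*>(buf.dptr_ + aligned_ws);
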
diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
index d3247b7..405ae4b 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
@@ -29,12 +29,12 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_any)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBoolCompute<gpu,
-  mshadow_op::sum, mshadow_op::NonZero, 0>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 0>
+                                     {"NonZero", "red::sum{}", false});
 
 NNVM_REGISTER_OP(_npi_all)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBoolCompute<gpu,
-  mshadow_op::product, mshadow_op::NonZero, 1>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 1>
+                                     {"NonZero", "red::product{}", false});
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu
index 892d046..eb6086c 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_index.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu
@@ -28,10 +28,10 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_argmax)
-.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxCompute<mshadow_op::argmax, gpu, int>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxRTCCompute{"red::argmax{}"});
 
 NNVM_REGISTER_OP(_npi_argmin)
-.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxCompute<mshadow_op::argmin, gpu, int>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxRTCCompute{"red::argmin{}"});
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index 422097d..6020573 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -27,25 +27,30 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(_npi_sum)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true>);
+.set_attr<FCompute>("FCompute<gpu>",
+                    ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>{"identity", "red::sum{}", false});
 
 NNVM_REGISTER_OP(_backward_npi_sum)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);
 
 NNVM_REGISTER_OP(_npi_max)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesNoDTypeParam, 0>
+                                     {"identity", "red::maximum{}", false});
 
 NNVM_REGISTER_OP(_backward_npi_max)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
 
 NNVM_REGISTER_OP(_npi_min)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::minimum>);
+.set_attr<FCompute>("FCompute<gpu>",
+                    ReduceAxesRTCCompute<NumpyReduceAxesNoDTypeParam, 0>{"identity",
+                      "red::minimum{}", false});
 
 NNVM_REGISTER_OP(_backward_npi_min)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
 
 NNVM_REGISTER_OP(_npi_prod)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::product, true>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesParam, 1>{"identity",
+                                       "red::product{}", false});
 
 NNVM_REGISTER_OP(_backward_npi_prod)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseInOut<gpu, mshadow_op::rdiv>);
@@ -57,7 +62,8 @@ NNVM_REGISTER_OP(_backward_np_average)
 .set_attr<FCompute>("FCompute<gpu>", NumpyWeightedAverageBackward<gpu>);
 
 NNVM_REGISTER_OP(_npi_mean)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true, true>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>{"identity",
+                                       "red::sum{}", true});
 
 NNVM_REGISTER_OP(_backward_np_mean)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);
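
The registrations above replace compile-time template arguments with fields of the
ReduceAxesRTCCompute functor: a mapper name, a reducer name, and a normalize flag
(true for _npi_mean, false for _npi_sum). The integer template argument (0 for sum and
mean, 1 for product) appears to carry the reduction's initial value; treat that reading
as an inference from the pairs above rather than a documented contract. An annotated
sketch of the functor construction, using the _npi_mean values:

    // Illustrative only: the three fields of the RTC reduction functor.
    auto mean_fcompute = ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>{
        "identity",    // element-wise op applied before reduction ("identity", "NonZero", ...)
        "red::sum{}",  // reducer, compiled at runtime by name
        true};         // normalize: divide the reduced sum by the number of reduced elements
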
diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h
index 80beaa3..01c54b6 100644
--- a/src/operator/numpy/np_constraint_check.h
+++ b/src/operator/numpy/np_constraint_check.h
@@ -56,9 +56,14 @@ void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
   CHECK_EQ(outputs.size(), 1U);
   const ConstraintCheckParam& param =
       nnvm::get<ConstraintCheckParam>(attrs.parsed);
+#if !defined(__CUDACC__)
   ReduceAxesComputeImpl<xpu, mshadow_op::product, false, false,
                         op::mshadow_op::identity>(ctx, inputs, req, outputs,
                                                   outputs[0].shape_);
+#else
+  ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs,
+                           outputs[0].shape_, "red::product{}");
+#endif
   std::string msg = param.msg;
   bool red_output = true;
   GetReduceOutput(ctx.get_stream<xpu>(), outputs[0], &red_output);
diff --git a/src/operator/numpy/np_cross-inl.h b/src/operator/numpy/np_cross-inl.h
index ab64564..7d5dd30 100644
--- a/src/operator/numpy/np_cross-inl.h
+++ b/src/operator/numpy/np_cross-inl.h
@@ -677,8 +677,7 @@ struct ReduceImplWrap {
     std::vector<int> reduce_axis = GetReduceAxis(out_move_shape, in_move_shape);
     if (reduce_axis.empty() || req == kNullOp) { return 0U; }
     ws_reduce = broadcast::ReduceWorkspaceSize(ctx.get_stream<xpu>(),
-                                               out_shape, req, in_shape,
-                                               sizeof(DType));
+                                               out_shape, req, in_shape);
     return ws_reduce;
   }
 
@@ -690,10 +689,17 @@ struct ReduceImplWrap {
                  const Tensor<xpu, 1, char> workspace_tensor) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     // Reduce work_in to work_out.
+#if !defined(__CUDACC__)
     SUM_NDIM_SWITCH(work_out.ndim(), NDim, {
       op::broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, false>(
         s, work_out, kWriteTo, workspace_tensor, work_in);
     });
+#else
+    SUM_NDIM_SWITCH(work_out.ndim(), NDim, {
+      op::broadcast::RTCReduce(ctx, work_out, kWriteTo, workspace_tensor, work_in,
+                               "red::sum{}", NDim, "identity");
+    });
+#endif
     // Copy work_out to out_data.
     MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, {
       mxnet_op::Kernel<ResAssign<req_type>, xpu>::Launch(
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index 1fa5890..be19b38 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -413,9 +413,9 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
     MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
       if (need_bc) {
           workspace_size_l = ReduceWorkspaceSize(
-            s, new_lshape, req[0], new_oshape, new_lshape, new_rshape, sizeof(OType));
+            s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
           workspace_size_r = ReduceWorkspaceSize(
-            s, new_rshape, req[1], new_oshape, new_lshape, new_rshape, sizeof(OType));
+            s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
       }
       size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
       size_t cast_tensor_size = tensor_size * sizeof(OType);
diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h
index f357cec..bf0983f 100644
--- a/src/operator/numpy/np_kron-inl.h
+++ b/src/operator/numpy/np_kron-inl.h
@@ -188,6 +188,14 @@ void KronOpForwardImpl(const OpContext& ctx,
   });
 }
 
+#if !defined(__CUDACC__)
+#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \
+  ReduceAxesComputeImpl<xpu, mshadow_op::sum, safe_acc>(__VA_ARGS__, &workspace)
+#else
+#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \
+  ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}", &workspace)
+#endif
+
 template<typename xpu>
 void KronOpBackwardImpl(const OpContext& ctx,
                         const std::vector<OpReqType>& req,
@@ -226,12 +234,23 @@ void KronOpBackwardImpl(const OpContext& ctx,
       const OpReqType& scalar_req = (ashape.ndim() == 0) ? req[0] : req[1];
       ASSIGN_DISPATCH(tensor_grad_, tensor_req,
                       broadcast_scalar(scalar_, tensor_grad_.shape_) * ograd_);
-      Tensor<xpu, 1, DType> workspace =
-        ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(ograd.shape_.Size()), s);
-      ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_);
-
-      ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
-        ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
+      TShape src_shape, dst_shape;
+      BroadcastReduceShapeCompact(ograd.shape_, scalar_grad_.shape_, &src_shape, &dst_shape);
+      size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape,
+                                                             {scalar_req}, src_shape);
+      constexpr size_t align_size = 1024;
+      const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size)
+                                                  * align_size;
+      workspace_size = aligned_first_workspace_size + ograd.shape_.Size() * sizeof(DType);
+      Tensor<xpu, 1, char> workspace =
+        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+      Tensor<xpu, 1, DType> temp(reinterpret_cast<DType*>(workspace.dptr_ +
+                                                          aligned_first_workspace_size),
+                                 Shape1(ograd.shape_.Size()), s);
+      ASSIGN_DISPATCH(temp, kWriteTo, tensor_ * ograd_);
+
+      NP_KRON_REDUCE_AXES(true, workspace, ctx, {TBlob(temp)}, {scalar_req},
+                          {TBlob(scalar_grad_)}, scalar_grad_.shape_);
     } else {
       MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
         Shape<ndim> ashape_ = oshape.get<ndim>();
@@ -276,6 +295,8 @@ void KronOpBackwardImpl(const OpContext& ctx,
   });
 }
 
+#undef NP_KRON_REDUCE_AXES
+
 template<typename xpu>
 inline void KronOpForward(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
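
NP_KRON_REDUCE_AXES above keeps a single call site that resolves to the templated CPU
reduction under the host compiler and to ReduceAxesRTCComputeImpl under nvcc; the same
device split recurs in np_tensordot_op-inl.h and np_where_op-inl.h below. A generic
sketch of the pattern (the macro and variable names here are hypothetical, not part of
the patch):

    // Hypothetical illustration of the per-translation-unit dispatch used above.
    #if !defined(__CUDACC__)
    #define MY_SUM_REDUCE(...) \
      ReduceAxesComputeImpl<xpu, mshadow_op::sum, /*safe_acc=*/true>(__VA_ARGS__)
    #else
    #define MY_SUM_REDUCE(...) \
      ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}")
    #endif

    // One call site, two compilation paths:
    MY_SUM_REDUCE(ctx, {TBlob(temp)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);

    #undef MY_SUM_REDUCE
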
diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h
index 1bfc6d1..47e42d0 100644
--- a/src/operator/numpy/np_tensordot_op-inl.h
+++ b/src/operator/numpy/np_tensordot_op-inl.h
@@ -370,6 +370,14 @@ inline mxnet::TShape GetReverseShape(const mxnet::Tuple<int>& shape) {
   return shape2;
 }
 
+#if !defined(__CUDACC__)
+#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \
+  ReduceAxesComputeImpl<xpu, mshadow_op::sum, safe_acc>(__VA_ARGS__)
+#else
+#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \
+  ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}")
+#endif
+
 /**
  * calculates tensordot derivative.
  */
@@ -424,8 +432,8 @@ void TensordotBackwardImpl(const Tuple<int>& a_axes_summed,
                               workspace.stream_);
       ASSIGN_DISPATCH(dtypespace, kWriteTo, tensor_ * out_grad_);
 
-      ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
-        ctx, {TBlob(dtypespace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
+      NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(dtypespace)}, {scalar_req},
+                               {TBlob(scalar_grad_)}, scalar_grad_.shape_);
     } else {
       // Two tensors of at least 1 dimensions.
       Tuple<int> a_axes_remained;
@@ -734,8 +742,8 @@ void TensordotIntAxesBackwardImpl(const int axes,
         ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(out_grad.shape_.Size()), s);
       ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_);
 
-      ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
-        ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
+      NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {scalar_req},
+                               {TBlob(scalar_grad_)}, scalar_grad_.shape_);
     } else {
       // Two tensors of at least 1 dimensions.
       Tuple<int> a_axes_summed;
@@ -759,6 +767,8 @@ void TensordotIntAxesBackwardImpl(const int axes,
   });
 }
 
+#undef NP_TENSORDOT_REDUCE_AXES
+
 /**
  * backward function.
  */
diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h
index 10ec081..43af21d 100644
--- a/src/operator/numpy/np_where_op-inl.h
+++ b/src/operator/numpy/np_where_op-inl.h
@@ -175,6 +175,13 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+#if !defined(__CUDACC__)
+#define NP_WHERE_REDUCE_AXES(safe_acc, ...) \
+  ReduceAxesComputeImpl<xpu, mshadow_op::sum, safe_acc>(__VA_ARGS__)
+#else
+#define NP_WHERE_REDUCE_AXES(safe_acc, ...) ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}")
+#endif
+
 template<typename xpu>
 inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
@@ -226,9 +233,9 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
       size_t ws_size = 0;
       if (ograd.shape_ != dx.shape_ || ograd.shape_ != dy.shape_) {
         size_t ws_size1 = broadcast::ReduceWorkspaceSize(
-            s, expanded_lshape, req[0], expanded_oshape, sizeof(DType));
+            s, expanded_lshape, req[0], expanded_oshape);
         size_t ws_size2 = broadcast::ReduceWorkspaceSize(
-            s, expanded_rshape, req[1], expanded_oshape, sizeof(DType));
+            s, expanded_rshape, req[1], expanded_oshape);
         ws_size = std::max(ws_size1, ws_size2);
       }
       // process left output
@@ -246,10 +253,10 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
           s, ograd.Size(), req[0], cstride, oshape,
           cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
         if (NeedSafeAcc<true>(dx.type_flag_, dx.type_flag_)) {
-          ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[0]},
+          NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]},
               {dx.reshape(expanded_lshape)}, expanded_lshape);
         } else {
-          ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[0]},
+          NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]},
               {dx.reshape(expanded_lshape)}, expanded_lshape);
         }
       }
@@ -268,10 +275,10 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
           s, ograd.Size(), req[1], cstride, oshape,
           cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
         if (NeedSafeAcc<true>(dy.type_flag_, dy.type_flag_)) {
-          ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[1]},
+          NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[1]},
               {dy.reshape(expanded_rshape)}, expanded_rshape);
         } else {
-          ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[1]},
+          NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[1]},
               {dy.reshape(expanded_rshape)}, expanded_rshape);
         }
       }
@@ -367,7 +374,7 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
       size_t ws_size = 0;
       if (ograd.shape_ != dx.shape_) {
         ws_size = broadcast::ReduceWorkspaceSize(s, expanded_lshape, req[0],
-                                                 expanded_oshape, sizeof(DType));
+                                                 expanded_oshape);
       }
       // If lscalar, then process right output, `is_left` should be false
       if (ograd.shape_ == dx.shape_) {
@@ -384,10 +391,10 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
           s, ograd.Size(), req[0], cstride, oshape,
           cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
         if (NeedSafeAcc<true>(dx.type_flag_, dx.type_flag_)) {
-          ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[0]},
+          NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]},
               {dx.reshape(expanded_lshape)}, expanded_lshape);
         } else {
-          ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[0]},
+          NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]},
               {dx.reshape(expanded_lshape)}, expanded_lshape);
         }
       }
@@ -395,6 +402,8 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+#undef NP_WHERE_REDUCE_AXES
+
 template<typename xpu>
 inline void NumpyWhereScalar2OpForward(const nnvm::NodeAttrs& attrs,
                                       const OpContext& ctx,
diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h
index ab8afe9..1e43134 100644
--- a/src/operator/numpy/random/dist_common.h
+++ b/src/operator/numpy/random/dist_common.h
@@ -277,6 +277,86 @@ inline bool TwoparamsDistOpConcatShape(const nnvm::NodeAttrs &attrs,
   return true;
 }
 
+template<typename xpu, int ndim, typename DType>
+inline void CommonReparamBackwardImpl(const OpContext& ctx,
+                                      const std::vector<TBlob>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<TBlob>& outputs,
+                                      const mxnet::TShape& new_lshape,
+                                      const mxnet::TShape& new_rshape,
+                                      const mxnet::TShape& new_oshape) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob lgrad = outputs[0].reshape(new_lshape);
+  const TBlob rgrad = outputs[1].reshape(new_rshape);
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  // Mean
+  const TBlob lhs = inputs[2].reshape(new_lshape);
+  // Scale
+  const TBlob rhs = inputs[3].reshape(new_rshape);
+  const TBlob samples = inputs[4].reshape(new_oshape);
+  const TBlob noise = inputs[5].reshape(new_oshape);
+  size_t workspace_size_l = ReduceWorkspaceSize(
+    s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_);
+  size_t workspace_size_r = ReduceWorkspaceSize(
+    s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
+  size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
+  Tensor<xpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
+  Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(
+    s, lgrad, req[0], workspace, ograd);
+  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
+    s, rgrad, req[1], workspace, ograd, noise, rhs);
+#else
+  RTCReduce(ctx, lgrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity");
+  RTCReduce(ctx, rgrad, req[1], workspace, ograd, noise, rhs, "red::sum{}", ndim, "mul", "left");
+#endif
+}
+
+template<typename xpu, int ndim, typename DType>
+inline void CommonScalarReparamBackwardImpl(const OpContext& ctx,
+                                            const std::vector<TBlob>& inputs,
+                                            const std::vector<OpReqType>& req,
+                                            const std::vector<TBlob>& outputs,
+                                            const mxnet::TShape& new_ishape,
+                                            const mxnet::TShape& new_oshape,
+                                            const bool loc_is_tensor = false) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob igrad = outputs[0].reshape(new_ishape);
+  // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
+  //          samples, noise]
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob itensor = inputs[2].reshape(new_ishape);
+  const TBlob samples = inputs[3].reshape(new_oshape);
+  const TBlob noise = inputs[4].reshape(new_oshape);
+  size_t workspace_size =
+    ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_);
+  Tensor<xpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
+  if (loc_is_tensor) {
+    Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s, igrad, req[0],
+                                                            workspace, ograd);
+  } else {
+    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
+      s, igrad, req[0], workspace, ograd, noise, noise);
+  }
+#else
+  if (loc_is_tensor) {
+    RTCReduce(ctx, igrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity");
+  } else {
+    RTCReduce(ctx, igrad, req[0], workspace, ograd, noise, noise, "red::sum{}",
+              ndim, "mul", "left");
+  }
+#endif
+}
+
 }  // namespace op
 }  // namespace mxnet
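
The two helpers added above consolidate the near-identical reparameterization backward
passes that the hunks below delete from the normal, logistic/gumbel, pareto, rayleigh
and weibull ops. Call sites keep their existing type and ndim dispatch; a condensed
sketch of the tensor-tensor case, mirroring the updated LocationScaleReparamBackward
below (new_lshape, new_rshape and new_oshape are the compacted broadcast shapes already
computed by the caller):

    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
      BROADCAST_NDIM_SWITCH(ndim, NDim, {
        CommonReparamBackwardImpl<xpu, NDim, DType>(
            ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
      });
    });
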
 
diff --git a/src/operator/numpy/random/np_exponential_op.h b/src/operator/numpy/random/np_exponential_op.h
index 374b3b4..0f0d462 100644
--- a/src/operator/numpy/random/np_exponential_op.h
+++ b/src/operator/numpy/random/np_exponential_op.h
@@ -171,11 +171,16 @@ inline void ExponentialReparamBackwardImpl(const OpContext& ctx,
   const TBlob samples = inputs[3].reshape(new_oshape);
   const TBlob noise = inputs[4].reshape(new_oshape);
   size_t workspace_size =
-      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
+      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_);
   Tensor<xpu, 1, char> workspace =
       ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
   Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
       s, igrad, req[0], workspace, ograd, noise, noise);
+#else
+  RTCReduce(ctx, igrad, req[0], workspace, ograd, noise, noise,
+            "red::sum{}", ndim, "mul", "left");
+#endif
 }
 
 template<typename xpu>
diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h
index a0f3299..55041a9 100644
--- a/src/operator/numpy/random/np_gamma_op.h
+++ b/src/operator/numpy/random/np_gamma_op.h
@@ -420,14 +420,19 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx,
   const TBlob alpha = inputs[1].reshape(new_ishape);
   TBlob samples = inputs[2].reshape(new_oshape);
   size_t workspace_size =
-      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
+      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_);
   // Convert samples to standard gamma
   Kernel<op_with_req<mshadow_op::div, kWriteTo>, xpu>::Launch(
     s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
   Tensor<xpu, 1, char> workspace =
       ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
   Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::gamma_implicit_grad>(
       s, igrad, req[0], workspace, ograd, alpha, samples);
+#else
+  RTCReduce(ctx, igrad, req[0], workspace, ograd, alpha, samples, "red::sum{}", ndim,
+            "mul", "gamma_implicit_grad");
+#endif
   Kernel<op_with_req<mshadow_op::mul, kWriteTo>, xpu>::Launch(
     s, igrad.Size(), igrad.dptr<DType>(), igrad.dptr<DType>(), DType(scale));
   // Convert samples back, otherwise the output would be corrupted.
diff --git a/src/operator/numpy/random/np_location_scale_op.h b/src/operator/numpy/random/np_location_scale_op.h
index 73403f3..0179a57 100644
--- a/src/operator/numpy/random/np_location_scale_op.h
+++ b/src/operator/numpy/random/np_location_scale_op.h
@@ -275,72 +275,6 @@ void NumpyLocationScaleForward(const nnvm::NodeAttrs &attrs,
   }
 }
 
-template<typename xpu, int ndim, typename DType>
-inline void LocationScaleReparamBackwardImpl(const OpContext& ctx,
-                                             const std::vector<TBlob>& inputs,
-                                             const std::vector<OpReqType>& req,
-                                             const std::vector<TBlob>& outputs,
-                                             const mxnet::TShape& new_lshape,
-                                             const mxnet::TShape& new_rshape,
-                                             const mxnet::TShape& new_oshape) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace broadcast;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const TBlob lgrad = outputs[0].reshape(new_lshape);
-  const TBlob rgrad = outputs[1].reshape(new_rshape);
-  const TBlob ograd = inputs[0].reshape(new_oshape);
-  // Mean
-  const TBlob lhs = inputs[2].reshape(new_lshape);
-  // Scale
-  const TBlob rhs = inputs[3].reshape(new_rshape);
-  const TBlob samples = inputs[4].reshape(new_oshape);
-  const TBlob noise = inputs[5].reshape(new_oshape);
-  size_t workspace_size_l = ReduceWorkspaceSize(
-    s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
-  size_t workspace_size_r = ReduceWorkspaceSize(
-    s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
-  size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
-  Tensor<xpu, 1, char> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(
-    s, lgrad, req[0], workspace, ograd);
-  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
-    s, rgrad, req[1], workspace, ograd, noise, rhs);
-}
-
-template<typename xpu, int ndim, typename DType>
-inline void ScalarLocationScaleReparamBackwardImpl(const OpContext& ctx,
-                                                   const std::vector<TBlob>& inputs,
-                                                   const std::vector<OpReqType>& req,
-                                                   const std::vector<TBlob>& outputs,
-                                                   const mxnet::TShape& new_ishape,
-                                                   const mxnet::TShape& new_oshape,
-                                                   const bool loc_is_tensor) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace broadcast;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const TBlob igrad = outputs[0].reshape(new_ishape);
-  // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
-  //          samples, noise]
-  const TBlob ograd = inputs[0].reshape(new_oshape);
-  const TBlob itensor = inputs[2].reshape(new_ishape);
-  const TBlob samples = inputs[3].reshape(new_oshape);
-  const TBlob noise = inputs[4].reshape(new_oshape);
-  size_t workspace_size =
-    ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
-  Tensor<xpu, 1, char> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  if (loc_is_tensor) {
-    Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s, igrad, req[0],
-                                                            workspace, ograd);
-  } else {
-    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
-      s, igrad, req[0], workspace, ograd, noise, noise);
-  }
-}
-
 // Allow logistic and gumbel sampling to be differentiable,
 // using reparameterization trick described in:
 // Auto-encoding variational bayes.
@@ -359,7 +293,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs,
   if (outputs.size() == 0U) {
     return;
   }
-  const NumpyLocationScaleParam &param = nnvm::get<NumpyLocationScaleParam>(attrs.parsed);
+  const auto &param = nnvm::get<NumpyLocationScaleParam>(attrs.parsed);
   // [tensor tensor] case
   if (inputs.size() == 6U) {
     mxnet::TShape new_lshape, new_rshape, new_oshape;
@@ -367,7 +301,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs,
                          &new_lshape, &new_rshape, &new_oshape);
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        LocationScaleReparamBackwardImpl<xpu, NDim, DType>(
+        CommonReparamBackwardImpl<xpu, NDim, DType>(
           ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
       });
     });
@@ -380,7 +314,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs,
     bool loc_is_tensor = !param.loc.has_value();
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        ScalarLocationScaleReparamBackwardImpl<xpu, NDim, DType>(
+        CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
           ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor);
       });
     });
diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h
index e43d98d..06b5bfa 100644
--- a/src/operator/numpy/random/np_normal_op.h
+++ b/src/operator/numpy/random/np_normal_op.h
@@ -161,7 +161,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs,
                          const std::vector<TBlob> &outputs) {
   using namespace mshadow;
   using namespace mxnet_op;
-  const NumpyNormalParam &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
+  const auto &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   // Generate base random number.
   Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
@@ -240,72 +240,6 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs,
   }
 }
 
-template<typename xpu, int ndim, typename DType>
-inline void NormalReparamBackwardImpl(const OpContext& ctx,
-                                      const std::vector<TBlob>& inputs,
-                                      const std::vector<OpReqType>& req,
-                                      const std::vector<TBlob>& outputs,
-                                      const mxnet::TShape& new_lshape,
-                                      const mxnet::TShape& new_rshape,
-                                      const mxnet::TShape& new_oshape) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace broadcast;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const TBlob lgrad = outputs[0].reshape(new_lshape);
-  const TBlob rgrad = outputs[1].reshape(new_rshape);
-  const TBlob ograd = inputs[0].reshape(new_oshape);
-  // Mean
-  const TBlob lhs = inputs[2].reshape(new_lshape);
-  // Variance
-  const TBlob rhs = inputs[3].reshape(new_rshape);
-  const TBlob samples = inputs[4].reshape(new_oshape);
-  const TBlob noise = inputs[5].reshape(new_oshape);
-  size_t workspace_size_l = ReduceWorkspaceSize(
-      s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
-  size_t workspace_size_r = ReduceWorkspaceSize(
-      s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
-  size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
-  Tensor<xpu, 1, char> workspace =
-      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s,
-          lgrad, req[0], workspace, ograd);
-  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
-      s, rgrad, req[1], workspace, ograd, noise, rhs);
-}
-
-template<typename xpu, int ndim, typename DType>
-inline void ScalarNormalReparamBackwardImpl(const OpContext& ctx,
-                                            const std::vector<TBlob>& inputs,
-                                            const std::vector<OpReqType>& req,
-                                            const std::vector<TBlob>& outputs,
-                                            const mxnet::TShape& new_ishape,
-                                            const mxnet::TShape& new_oshape,
-                                            const bool loc_is_tensor) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace broadcast;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const TBlob igrad = outputs[0].reshape(new_ishape);
-  // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
-  //          samples, noise]
-  const TBlob ograd = inputs[0].reshape(new_oshape);
-  const TBlob itensor = inputs[2].reshape(new_ishape);
-  const TBlob samples = inputs[3].reshape(new_oshape);
-  const TBlob noise = inputs[4].reshape(new_oshape);
-  size_t workspace_size =
-      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
-  Tensor<xpu, 1, char> workspace =
-      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  if (loc_is_tensor) {
-    Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s, igrad, req[0],
-                                                            workspace, ograd);
-  } else {
-    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
-        s, igrad, req[0], workspace, ograd, noise, noise);
-  }
-}
-
 // Allow normal sampling to be differentiable,
 // using reparameterization trick described in:
 // Auto-encoding variational bayes.
@@ -324,7 +258,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs,
   if (outputs.size() == 0U) {
     return;
   }
-  const NumpyNormalParam &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
+  const auto &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
   // [tensor tensor] case
   if (inputs.size() == 6U) {
     mxnet::TShape new_lshape, new_rshape, new_oshape;
@@ -332,7 +266,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs,
                          &new_lshape, &new_rshape, &new_oshape);
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        NormalReparamBackwardImpl<xpu, NDim, DType>(
+        CommonReparamBackwardImpl<xpu, NDim, DType>(
           ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
       });
     });
@@ -345,7 +279,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs,
     bool loc_is_tensor = !param.loc.has_value();
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        ScalarNormalReparamBackwardImpl<xpu, NDim, DType>(
+        CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
           ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor);
       });
     });
diff --git a/src/operator/numpy/random/np_pareto_op.h b/src/operator/numpy/random/np_pareto_op.h
index 5e5d26a..16731c1 100644
--- a/src/operator/numpy/random/np_pareto_op.h
+++ b/src/operator/numpy/random/np_pareto_op.h
@@ -155,32 +155,6 @@ void NumpyParetoForward(const nnvm::NodeAttrs &attrs,
   }
 }
 
-template<typename xpu, int ndim, typename DType>
-inline void ScalarParetoReparamBackwardImpl(const OpContext& ctx,
-                                             const std::vector<TBlob>& inputs,
-                                             const std::vector<OpReqType>& req,
-                                             const std::vector<TBlob>& outputs,
-                                             const mxnet::TShape& new_ishape,
-                                             const mxnet::TShape& new_oshape) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace broadcast;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const TBlob igrad = outputs[0].reshape(new_ishape);
-  // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
-  //          samples, noise]
-  const TBlob ograd = inputs[0].reshape(new_oshape);
-  const TBlob itensor = inputs[2].reshape(new_ishape);
-  const TBlob samples = inputs[3].reshape(new_oshape);
-  const TBlob noise = inputs[4].reshape(new_oshape);
-  size_t workspace_size =
-      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
-  Tensor<xpu, 1, char> workspace =
-      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
-      s, igrad, req[0], workspace, ograd, noise, noise);
-  }
-
 template<typename xpu>
 void ParetoReparamBackward(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
@@ -202,7 +176,7 @@ if (inputs.size() == 5U) {
                       &new_ishape, &new_ishape, &new_oshape);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(ndim, NDim, {
-      ScalarParetoReparamBackwardImpl<xpu, NDim, DType>(
+      CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
         ctx, inputs, reqs, outputs, new_ishape, new_oshape);
       });
     });
diff --git a/src/operator/numpy/random/np_rayleigh_op.h b/src/operator/numpy/random/np_rayleigh_op.h
index 0f940e5..75c4784 100644
--- a/src/operator/numpy/random/np_rayleigh_op.h
+++ b/src/operator/numpy/random/np_rayleigh_op.h
@@ -153,32 +153,6 @@ void NumpyRayleighForward(const nnvm::NodeAttrs &attrs,
   }
 }
 
-template<typename xpu, int ndim, typename DType>
-inline void ScalarRayleighReparamBackwardImpl(const OpContext& ctx,
-                                              const std::vector<TBlob>& inputs,
-                                              const std::vector<OpReqType>& req,
-                                              const std::vector<TBlob>& outputs,
-                                              const mxnet::TShape& new_ishape,
-                                              const mxnet::TShape& new_oshape) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace broadcast;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const TBlob igrad = outputs[0].reshape(new_ishape);
-  // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
-  //          samples, noise]
-  const TBlob ograd = inputs[0].reshape(new_oshape);
-  const TBlob itensor = inputs[2].reshape(new_ishape);
-  const TBlob samples = inputs[3].reshape(new_oshape);
-  const TBlob noise = inputs[4].reshape(new_oshape);
-  size_t workspace_size =
-      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
-  Tensor<xpu, 1, char> workspace =
-      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
-      s, igrad, req[0], workspace, ograd, noise, noise);
-}
-
 template<typename xpu>
 void RayleighReparamBackward(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
@@ -200,7 +174,7 @@ void RayleighReparamBackward(const nnvm::NodeAttrs& attrs,
                          &new_ishape, &new_ishape, &new_oshape);
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        ScalarRayleighReparamBackwardImpl<xpu, NDim, DType>(
+        CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
           ctx, inputs, req, outputs, new_ishape, new_oshape);
       });
     });
diff --git a/src/operator/numpy/random/np_weibull_op.h b/src/operator/numpy/random/np_weibull_op.h
index 970dc85..a7d6d5d 100644
--- a/src/operator/numpy/random/np_weibull_op.h
+++ b/src/operator/numpy/random/np_weibull_op.h
@@ -155,32 +155,6 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
   }
 }
 
-template<typename xpu, int ndim, typename DType>
-inline void ScalarWeibullReparamBackwardImpl(const OpContext& ctx,
-                                             const std::vector<TBlob>& inputs,
-                                             const std::vector<OpReqType>& req,
-                                             const std::vector<TBlob>& outputs,
-                                             const mxnet::TShape& new_ishape,
-                                             const mxnet::TShape& new_oshape) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace broadcast;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const TBlob igrad = outputs[0].reshape(new_ishape);
-  // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
-  //          samples, noise]
-  const TBlob ograd = inputs[0].reshape(new_oshape);
-  const TBlob itensor = inputs[2].reshape(new_ishape);
-  const TBlob samples = inputs[3].reshape(new_oshape);
-  const TBlob noise = inputs[4].reshape(new_oshape);
-  size_t workspace_size =
-      ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
-  Tensor<xpu, 1, char> workspace =
-      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
-  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
-      s, igrad, req[0], workspace, ograd, noise, noise);
-  }
-
 template<typename xpu>
 void WeibullReparamBackward(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
@@ -202,7 +176,7 @@ if (inputs.size() == 5U) {
                       &new_ishape, &new_ishape, &new_oshape);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(ndim, NDim, {
-      ScalarWeibullReparamBackwardImpl<xpu, NDim, DType>(
+      CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
         ctx, inputs, reqs, outputs, new_ishape, new_oshape);
       });
     });
diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h
index 2c5c1eb..0d89570 100644
--- a/src/operator/quantization/quantization_utils.h
+++ b/src/operator/quantization/quantization_utils.h
@@ -184,7 +184,7 @@ inline size_t ConfigReduce(mshadow::Stream<xpu>* s,
   CHECK_EQ(src_shape->ndim(), NDim);
   CHECK_EQ(dst_shape->ndim(), NDim);
 
-  return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape, sizeof(DType));
+  return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape);
 }
 
 enum QuantizeOutType { kAuto = 0, kInt8, kUint8 };
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index d8814cc..cfbdb7f 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -205,10 +205,17 @@ class QuantizeV2Operator {
                        dev_id);
         Tensor<xpu, 1, char> workspace(temp_space.dptr_ + 2 * actual_float_size,
                                        Shape1(temp_reduce_size), s);
+#if !defined(__CUDACC__)
         broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
             s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
         broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
             s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+#else
+        broadcast::RTCReduce(ctx, in_min_t.reshape(dst_shape), kWriteTo, workspace,
+                             inputs[0].reshape(src_shape), "red::minimum{}", 2, "identity");
+        broadcast::RTCReduce(ctx, in_max_t.reshape(dst_shape), kWriteTo, workspace,
+                             inputs[0].reshape(src_shape), "red::maximum{}", 2, "identity");
+#endif
         if (out_type == mshadow::kUint8) {
           Kernel<quantize_v2_unsigned, xpu>::Launch(
               s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
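
The quantize_v2 change above shows the pattern used throughout this commit: the same operator source is compiled once by the host compiler (CPU path) and once by nvcc (GPU path), and only the nvcc pass, where __CUDACC__ is defined, replaces the template-based broadcast::Reduce with the RTC-based broadcast::RTCReduce that takes the reducer and the mapping OP as strings. A minimal sketch of that dispatch, assuming a caller that already has an OpContext ctx, a stream s, reshaped TBlobs out and in, and a workspace tensor (float data used for concreteness):

#if !defined(__CUDACC__)
  // CPU / host pass: reducer and OP are template parameters.
  broadcast::Reduce<red::minimum, 2, float, mshadow::op::identity>(
      s, out, kWriteTo, workspace, in);
#else
  // GPU pass: reducer and OP are strings compiled at runtime via RTC.
  broadcast::RTCReduce(ctx, out, kWriteTo, workspace, in,
                       "red::minimum{}", 2, "identity");
#endif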
diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h
index 2bdc3a7..5668670 100644
--- a/src/operator/quantization/requantize-inl.h
+++ b/src/operator/quantization/requantize-inl.h
@@ -148,6 +148,7 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs,
           temp_space.dptr_ + 8) + 1, Shape1(1), xpu::kDevMask, dev_id);
     Tensor<xpu, 1, char> workspace(
         temp_space.dptr_+2*actual_float_size+2*actual_quantized_size, Shape1(temp_reduce_size), s);
+#if !defined(__CUDACC__)
     broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
       s, actual_min_quantized.reshape(dst_shape),
       kWriteTo, workspace, inputs[0].reshape(src_shape));
@@ -158,6 +159,18 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs,
     broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
       s, actual_max_quantized.reshape(dst_shape),
       kWriteTo, workspace, inputs[0].reshape(src_shape));
+#else
+    broadcast::RTCReduce(ctx, actual_min_quantized.reshape(dst_shape),
+                         kWriteTo, workspace, inputs[0].reshape(src_shape),
+                         "red::minimum{}", 2, "identity");
+    Kernel<QuantizedToFloatStruct, xpu>::Launch(s, 1,
+        actual_min_float.dptr_, actual_min_quantized.dptr<SrcDType>(),
+        inputs[1].dptr<float>(), inputs[2].dptr<float>());
+
+    broadcast::RTCReduce(ctx, actual_max_quantized.reshape(dst_shape),
+                         kWriteTo, workspace, inputs[0].reshape(src_shape),
+                         "red::maximum{}", 2, "identity");
+#endif
     Kernel<QuantizedToFloatStruct, xpu>::Launch(s, 1,
         actual_max_float.dptr_, actual_max_quantized.dptr<SrcDType>(),
         inputs[1].dptr<float>(), inputs[2].dptr<float>());
diff --git a/src/operator/random/pdf_op.h b/src/operator/random/pdf_op.h
index f6dc777..f53d3a6 100644
--- a/src/operator/random/pdf_op.h
+++ b/src/operator/random/pdf_op.h
@@ -592,10 +592,11 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs,
   const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
   const size_t N(outputs[1].Size());
   const TShape src_shape(Shape2(N, outputs[0].Size() / N)), dst_shape(Shape2(N, 1));
+  const size_t red_work_size(broadcast::ReduceWorkspaceSize(
+          s, dst_shape, kAddTo, src_shape));
+#if !defined(__CUDACC__)
   // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf.
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    const size_t red_work_size(broadcast::ReduceWorkspaceSize(
-            s, dst_shape, kAddTo, src_shape, sizeof(DType)));
     const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size);
     Tensor<xpu, 1, char> tmp_space =
             ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tmp_size), s);
@@ -620,6 +621,35 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs,
        s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape));
     }
   });
+#else
+  // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf.
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size);
+    Tensor<xpu, 1, char> tmp_space =
+            ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tmp_size), s);
+    std::vector<TBlob> grads = {outputs[0]};
+    grads.push_back(TBlob(tmp_space.dptr_, outputs[0].shape_,
+                          outputs[1].dev_mask(), outputs[1].type_flag_, outputs[1].dev_id()));
+    if (pnum == 2) {
+      grads.push_back(TBlob(tmp_space.dptr_ + outputs[0].Size() * sizeof(DType), outputs[0].shape_,
+                            outputs[2].dev_mask(), outputs[2].type_flag_, outputs[2].dev_id()));
+    }
+    if (param.is_log) {
+      PdfGradCaller<xpu, DType, pdfgrad<true>, pnum, vparm>::op(inputs, req, grads, s);
+    } else {
+      PdfGradCaller<xpu, DType, pdfgrad<false>, pnum, vparm>::op(inputs, req, grads, s);
+    }
+    Tensor<xpu, 1, char> red_work(
+            tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s);
+    broadcast::RTCReduce(ctx, outputs[1].reshape(dst_shape), req[1], red_work,
+                         grads[1].reshape(src_shape), "red::sum{}", 2, "identity");
+    if (pnum == 2) {
+      broadcast::RTCReduce(ctx, outputs[2].reshape(dst_shape), req[2], red_work,
+                           grads[2].reshape(src_shape), "red::sum{}", 2, "identity");
+    }
+  });
+
+#endif
 }
 
 }  // namespace op
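
In PdfOpBackward above, the reduction workspace size is now computed once, outside MSHADOW_REAL_TYPE_SWITCH, because ReduceWorkspaceSize no longer depends on the element size; the temporary buffer then holds the per-parameter gradient blobs first and the reduction workspace after them. Two hypothetical helpers (not part of the commit) that spell out that layout:

#include <cstddef>

// Total bytes requested for tmp_space: pnum gradient buffers, then red_work.
inline size_t PdfTmpBytes(size_t out_size, size_t pnum,
                          size_t dtype_size, size_t red_work_bytes) {
  return out_size * pnum * dtype_size + red_work_bytes;
}

// Byte offset at which the reduction workspace (red_work) starts.
inline size_t PdfRedWorkOffset(size_t out_size, size_t pnum, size_t dtype_size) {
  return out_size * pnum * dtype_size;
}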
diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh
deleted file mode 100644
index 53b0ad8..0000000
--- a/src/operator/tensor/broadcast_reduce-inl.cuh
+++ /dev/null
@@ -1,414 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015-2020 by Contributors
- * \file broadcast_reduce-inl.cuh
- * \brief CUDA implementations for binary broadcast and reduce
- * \author Antti-Pekka Hynninen, Przemyslaw Tredak
-*/
-#ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
-#define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
-
-using namespace mshadow::cuda;
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP, int unroll,
-         typename IndexOP = mxnet::op::mshadow_op::set_index_no_op<AType, int>>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel(const int N, const int M, const bool addto,
-                              const DType* __restrict big, OType *small,
-                              const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
-                              const Shape<ndim> big_shape, const Shape<ndim> big_stride,
-                              const int Mnext, const bool do_transpose) {
-  extern __shared__ char shTileChar[];
-  AType* shTile = (AType*)(shTileChar);
-  const int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  const int bx = (do_transpose) ? blockDim.y : blockDim.x;
-  const int by = (do_transpose) ? blockDim.x : blockDim.y;
-  const int tidx = (do_transpose) ? tid / by : threadIdx.x;
-  const int tidy = (do_transpose) ? tid % by : threadIdx.y;
-  for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
-    // This TB handles M range [Mstart, ...., Mend - 1]
-    const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
-    const int Mend   = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
-    for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
-      int idx = idx0 + tidx;
-      Shape<ndim> coord = mxnet_op::unravel(idx, small_shape);
-      int idx_big0 = mxnet_op::ravel(coord, big_shape0);
-
-      AType val, residual;
-      Reducer::SetInitValue(val, residual);
-      if (idx < N) {
-        for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
-          int idx_big[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
-          }
-          AType tmp[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) {
-              tmp[u] = OP::Map(big[idx_big[u]]);
-              // argmin/max, set IndexedNum.idx
-              if (IndexOP::do_op)
-                IndexOP::Op(&tmp[u], k+u*by);
-            }
-          }
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
-          }
-        }
-      }
-
-      // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
-      if (by > 1) {
-        // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
-        const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
-        const int it0 = tidx + tidy*fbx;
-        shTile[it0 * 2] = val;
-        shTile[it0 * 2 + 1] = residual;
-        __syncthreads();
-        for (int t=1;t < by;t <<= 1) {
-          AType tmp, tmp_residual;
-          Reducer::SetInitValue(tmp, tmp_residual);
-          if (tidy + t < by) {
-            tmp = shTile[(it0 + t*fbx) * 2];
-            tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
-          }
-          __syncthreads();
-          Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
-          __syncthreads();
-        }
-        if (idx < N && tidy == 0) {
-          Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
-          assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2]));
-        }
-      } else {
-        if (idx < N) {
-          Reducer::Finalize(val, residual);
-          assign(&small[idx + m0*N], addto, OType(val));
-        }
-      }
-    }
-  }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2, int unroll>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel(const int N, const int M, const bool addto,
-                              const DType* __restrict big, const DType* __restrict lhs,
-                              const DType* __restrict rhs, DType *small,
-                              const Shape<ndim> big_shape0, const Shape<ndim> lhs_shape0,
-                              const Shape<ndim> rhs_shape0, const Shape<ndim> small_shape,
-                              const Shape<ndim> big_shape, const Shape<ndim> lhs_shape,
-                              const Shape<ndim> rhs_shape, const Shape<ndim> big_stride,
-                              const Shape<ndim> lhs_stride, const Shape<ndim> rhs_stride,
-                              const int Mnext, const bool do_transpose) {
-  extern __shared__ char shTileChar[];
-  DType* shTile = (DType*)(shTileChar);
-  const int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  const int bx = (do_transpose) ? blockDim.y : blockDim.x;
-  const int by = (do_transpose) ? blockDim.x : blockDim.y;
-  const int tidx = (do_transpose) ? tid / by : threadIdx.x;
-  const int tidy = (do_transpose) ? tid % by : threadIdx.y;
-  for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
-    // This TB handles M range [Mstart, ...., Mend - 1]
-    const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
-    const int Mend   = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
-    for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
-      int idx = idx0 + tidx;
-      Shape<ndim> coord = mxnet_op::unravel(idx, small_shape);
-      int idx_big0 = mxnet_op::ravel(coord, big_shape0);
-      int idx_lhs0 = mxnet_op::ravel(coord, lhs_shape0);
-      int idx_rhs0 = mxnet_op::ravel(coord, rhs_shape0);
-
-      DType val, residual;
-      Reducer::SetInitValue(val, residual);
-      if (idx < N) {
-        for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
-          int idx_big[unroll];
-          int idx_lhs[unroll];
-          int idx_rhs[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
-            idx_lhs[u] = idx_lhs0 + mxnet_op::unravel_dot(k + u*by, lhs_shape, lhs_stride);
-            idx_rhs[u] = idx_rhs0 + mxnet_op::unravel_dot(k + u*by, rhs_shape, rhs_stride);
-          }
-          DType tmp[unroll];
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) {
-              tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]]));
-            }
-          }
-          #pragma unroll
-          for (int u=0;u < unroll;u++) {
-            if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
-          }
-        }
-      }
-
-      // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
-      if (by > 1) {
-        // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
-        const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
-        const int it0 = tidx + tidy*fbx;
-        shTile[it0 * 2] = val;
-        shTile[it0 * 2 + 1] = residual;
-        __syncthreads();
-        for (int t=1;t < by;t <<= 1) {
-          DType tmp, tmp_residual;
-          Reducer::SetInitValue(tmp, tmp_residual);
-          if (tidy + t < by) {
-            tmp = shTile[(it0 + t*fbx) * 2];
-            tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
-          }
-          __syncthreads();
-          Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
-          __syncthreads();
-        }
-        if (idx < N && tidy == 0) {
-          Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
-          assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
-        }
-      } else {
-        if (idx < N) {
-          Reducer::Finalize(val, residual);
-          assign(&small[idx + m0*N], addto, val);
-        }
-      }
-    }
-  }
-}
-
-// Simple reduction of lines when M is small
-template<typename Reducer, typename DType>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_lines_kernel(const int N, const int M, const bool addto,
-  const int small_in_stride, const DType* __restrict small_in, DType *small_out) {
-  for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-
-    DType val, residual;
-    Reducer::SetInitValue(val, residual);
-    for (int k = 0; k < M; k++) {
-      Reducer::Reduce(val, small_in[idx + k*small_in_stride], residual);
-    }
-
-    if (idx < N) {
-      Reducer::Finalize(val, residual);
-      assign(&small_out[idx], addto, val);
-    }
-
-  }
-}
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1(const int N, const bool addto,
-                                const DType* __restrict big, OType *small, const Shape<ndim> bshape,
-                                const Shape<ndim> sshape) {
-  for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-    Shape<ndim> coord = mxnet_op::unravel(idx, sshape);
-    int j = mxnet_op::ravel(coord, bshape);
-    AType val, residual, temp = OP::Map(big[j]);
-    Reducer::SetInitValue(val, residual);
-    Reducer::Reduce(val, temp, residual);
-    Reducer::Finalize(val, residual);
-    assign(&small[idx], addto, OType(val));
-  }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1(const int N, const bool addto,
-                                 const DType* __restrict big,
-                                 const DType* __restrict lhs,
-                                 const DType* __restrict rhs,
-                                 DType *small,
-                                 const Shape<ndim> big_shape,
-                                 const Shape<ndim> lhs_shape,
-                                 const Shape<ndim> rhs_shape,
-                                 const Shape<ndim> small_shape) {
-  for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-    Shape<ndim> coord = mxnet_op::unravel(idx, small_shape);
-    int idx_big = mxnet_op::ravel(coord, big_shape);
-    int idx_lhs = mxnet_op::ravel(coord, lhs_shape);
-    int idx_rhs = mxnet_op::ravel(coord, rhs_shape);
-    DType val, residual;
-    Reducer::SetInitValue(val, residual);
-    Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
-    Reducer::Finalize(val, residual);
-    assign(&small[idx], addto, val);
-  }
-}
-
-#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
-  if (do_unroll) {                                                    \
-    const int unrollVar = unrollAmount;                               \
-    {__VA_ARGS__}                                                     \
-  } else {                                                            \
-    const int unrollVar = 1;                                          \
-    {__VA_ARGS__}                                                     \
-  }
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP,
-	 typename IndexOP = mxnet::op::mshadow_op::set_index_no_op<AType, int>>
-void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
-                const TBlob& big, const Tensor<gpu, 1, char>& workspace,
-                const ReduceImplConfig& config) {
-  if (config.M == 1) {
-    reduce_kernel_M1<Reducer, ndim, AType, DType, OType, OP>
-    <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
-      config.N, req == kAddTo, big.dptr<DType>(), reinterpret_cast<OType*>(small.dptr_),
-      big.shape_.get<ndim>(), small.shape_.get<ndim>());
-    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
-  } else {
-    OType* small_dptr = reinterpret_cast<OType*>(small.dptr_);
-    bool addto = (req == kAddTo);
-    if (config.Mnext > 1) {
-      // small_dptr[] is N*Mnext*sizeof(DType) bytes
-      small_dptr = reinterpret_cast<OType*>(workspace.dptr_);
-      addto = false;
-      // Check that the workspace is contiguous
-      CHECK_EQ(workspace.CheckContiguous(), true);
-      // Check that we have enough storage
-      CHECK_GE(workspace.size(0), config.workspace_size);
-    }
-
-    const int by = (config.kernel_1.do_transpose) ?
-      config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
-    const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
-    KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
-      reduce_kernel<Reducer, ndim, AType, DType, OType, OP, UNROLL, IndexOP>
-      <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
-        config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
-        small.shape_.get<ndim>(), config.rshape.get<ndim>(), config.rstride.get<ndim>(),
-        config.Mnext, config.kernel_1.do_transpose);
-    });
-    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
-
-    if (config.Mnext > 1) {
-      reduce_lines_kernel<Reducer, OType>
-      <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
-        (config.N, config.Mnext, req == kAddTo, config.N, small_dptr,
-         reinterpret_cast<OType*>(small.dptr_));
-      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
-    }
-  }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs,
-                const OpReqType req, const TBlob& big, const Tensor<gpu, 1, char>& workspace,
-                const ReduceImplConfig& config) {
-  if (config.M == 1) {
-    reduce_kernel_M1<Reducer, ndim, DType, OP1, OP2>
-    <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
-      config.N, req == kAddTo, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
-      small.dptr<DType>(), big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
-      rhs.shape_.get<ndim>(), small.shape_.get<ndim>());
-    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
-  } else {
-    DType* small_dptr = small.dptr<DType>();
-    bool addto = (req == kAddTo);
-    if (config.Mnext > 1) {
-      // small_dptr[] is N*Mnext*sizeof(DType) bytes
-      small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
-      addto = false;
-      // Check that the workspace is contiguous
-      CHECK_EQ(workspace.CheckContiguous(), true);
-      // Check that we have enough storage
-      CHECK_GE(workspace.size(0), config.workspace_size);
-    }
-
-    const int by = (config.kernel_1.do_transpose) ?
-      config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
-    const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
-    KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
-      reduce_kernel<Reducer, ndim, DType, OP1, OP2, UNROLL>
-      <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
-        config.N, config.M, addto, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
-        small_dptr, big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
-        rhs.shape_.get<ndim>(), small.shape_.get<ndim>(), config.rshape.get<ndim>(),
-        config.lhs_shape.get<ndim>(), config.rhs_shape.get<ndim>(), config.rstride.get<ndim>(),
-        config.lhs_stride.get<ndim>(), config.rhs_stride.get<ndim>(), config.Mnext,
-        config.kernel_1.do_transpose);
-      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
-    });
-
-    if (config.Mnext > 1) {
-      reduce_lines_kernel<Reducer, DType>
-      <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
-        (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
-      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
-    }
-  }
-}
-
-#undef KERNEL_UNROLL_SWITCH
-
-template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
-void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
-            const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
-  if (req == kNullOp) return;
-  cudaStream_t stream = Stream<gpu>::GetStream(s);
-  ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType));
-  if (safe_acc) {
-    MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
-      typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
-        typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
-        config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr,
-                                  sizeof(AccType));
-        ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>(
-          stream, small, req, big, workspace, config);
-      });
-    });
-  } else {
-    ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config);
-  }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
-void ReduceBool(Stream<gpu> *s, const TBlob& small, const OpReqType req,
-                const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
-  if (req == kNullOp) return;
-  cudaStream_t stream = Stream<gpu>::GetStream(s);
-  ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType));
-  ReduceImpl<Reducer, ndim, bool, DType, bool, OP>(stream, small, req, big, workspace, config);
-}
-
-template <typename Reducer, int ndim, typename DType, typename OP>
-void ReduceWithExtraMem(Stream<gpu>* s, const TBlob& small, const OpReqType req,
-                        const Tensor<gpu, 1, char>& workspace, const TBlob& big) {};
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
-            const Tensor<gpu, 1, char>& workspace, const TBlob& big,
-            const TBlob& lhs, const TBlob& rhs) {
-  if (req == kNullOp) return;
-  cudaStream_t stream = Stream<gpu>::GetStream(s);
-  ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, sizeof(DType));
-  ReduceImpl<Reducer, ndim, DType, OP1, OP2>(stream, small, lhs, rhs, req, big, workspace, config);
-}
-
-#endif  //MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
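
The deleted broadcast_reduce-inl.cuh implemented hand-written kernels in which each thread first accumulates a strided slice of the reduced dimension into registers and the thread block then merges the per-thread partials through shared memory; the RTC path regenerates essentially the same structure at runtime (see reduce_rtc.cc below). A minimal, self-contained CUDA sketch of that pattern, not MXNet code and without the unrolling, transpose handling, and residual (Kahan-style) terms of the original:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void row_sum(const float* in, float* out, int M) {
  extern __shared__ float sh[];
  const float* row = in + blockIdx.x * M;
  float val = 0.f;
  for (int k = threadIdx.x; k < M; k += blockDim.x)  // strided accumulation
    val += row[k];
  sh[threadIdx.x] = val;
  __syncthreads();
  for (int t = blockDim.x / 2; t > 0; t >>= 1) {     // tree merge in shared memory
    if (threadIdx.x < t) sh[threadIdx.x] += sh[threadIdx.x + t];
    __syncthreads();
  }
  if (threadIdx.x == 0) out[blockIdx.x] = sh[0];
}

int main() {
  const int N = 4, M = 1024;
  float *in, *out;
  cudaMallocManaged(&in, N * M * sizeof(float));
  cudaMallocManaged(&out, N * sizeof(float));
  for (int i = 0; i < N * M; ++i) in[i] = 1.f;
  row_sum<<<N, 256, 256 * sizeof(float)>>>(in, out, M);
  cudaDeviceSynchronize();
  printf("%f\n", out[0]);  // expect 1024
  cudaFree(in);
  cudaFree(out);
  return 0;
}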
diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h
index a7d5d67..1907c02 100644
--- a/src/operator/tensor/broadcast_reduce-inl.h
+++ b/src/operator/tensor/broadcast_reduce-inl.h
@@ -445,13 +445,13 @@ void ReduceWithExtraMem(Stream<cpu>* s, const TBlob& small, const OpReqType req,
 }
 
 inline size_t ReduceWorkspaceSize(Stream<cpu> *s, const mxnet::TShape& small, const OpReqType req,
-                                  const mxnet::TShape& big, const int type_size) {
+                                  const mxnet::TShape& big) {
   return 0;
 }
 
 inline size_t ReduceWorkspaceSize(Stream<cpu> *s, const mxnet::TShape& small, const OpReqType req,
                                   const mxnet::TShape& big, const mxnet::TShape& lhs,
-                                  const mxnet::TShape& rhs, const int type_size) {
+                                  const mxnet::TShape& rhs) {
   return 0;
 }
 
@@ -539,11 +539,13 @@ struct ReduceImplConfig {
 
   inline ReduceImplConfig(const ::mxnet::TShape& small, const ::mxnet::TShape& big,
                           const ::mxnet::TShape* lhs,
-                          const ::mxnet::TShape* rhs,
-                          const size_t type_size) :
+                          const ::mxnet::TShape* rhs) :
     rshape(small.ndim(), 1), rstride(small.ndim(), 1),
     lhs_shape(small.ndim(), 1), lhs_stride(small.ndim(), 1),
     rhs_shape(small.ndim(), 1), rhs_stride(small.ndim(), 1) {
+    // The largest reduction type currently is (index_t, double) struct
+    // aligned to 16B
+    constexpr size_t max_type_size = 2 * sizeof(double);
     constexpr int maxLoopPerTB = 64;
     int ndim = small.ndim();
 
@@ -646,7 +648,7 @@ struct ReduceImplConfig {
           by++;
         }
         kernel_1.shMemSize = (kernel_1.blockDim.x > 1) ?
-          kernel_1.blockDim.x*by*type_size * 2 : 0;
+          kernel_1.blockDim.x*by*max_type_size * 2 : 0;
         // Maximum number of times we want TB to loop in M
         // Max size of M-block each TB can handle
         int maxMblock = kernel_1.blockDim.x*maxLoopPerTB;
@@ -657,7 +659,7 @@ struct ReduceImplConfig {
             ceil_idiv<unsigned int>(N, kernel_1.blockDim.x));
         kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext);
         kernel_1.shMemSize = (kernel_1.blockDim.y > 1) ?
-          kernel_1.blockDim.x*kernel_1.blockDim.y*type_size * 2 : 0;
+          kernel_1.blockDim.x*kernel_1.blockDim.y*max_type_size * 2 : 0;
         // Maximum number of times we want TB to loop in M
         // Max size of M-block each TB can handle
         int maxMblock = kernel_1.blockDim.y*maxLoopPerTB;
@@ -666,7 +668,7 @@ struct ReduceImplConfig {
 
       if (Mnext > 1) {
         // small_dptr[] is N*Mnext*type_size bytes
-        workspace_size += N*Mnext*sizeof(double);
+        workspace_size += N * Mnext * max_type_size;
         // Set gridDim.y to Mnext
         kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext);
       }
@@ -681,24 +683,20 @@ struct ReduceImplConfig {
 };
 
 inline size_t ReduceWorkspaceSize(Stream<gpu> *s, const ::mxnet::TShape& small, const OpReqType req,
-                                  const ::mxnet::TShape& big, const int type_size) {
+                                  const ::mxnet::TShape& big) {
   if (req == kNullOp) return 0;
-  ReduceImplConfig config(small, big, nullptr, nullptr, type_size);
+  ReduceImplConfig config(small, big, nullptr, nullptr);
   return config.workspace_size;
 }
 
 inline size_t ReduceWorkspaceSize(Stream<gpu> *s, const ::mxnet::TShape& small, const OpReqType req,
                                   const ::mxnet::TShape& big, const ::mxnet::TShape& lhs,
-                                  const ::mxnet::TShape& rhs, const int type_size) {
+                                  const ::mxnet::TShape& rhs) {
   if (req == kNullOp) return 0;
-  ReduceImplConfig config(small, big, &lhs, &rhs, type_size);
+  ReduceImplConfig config(small, big, &lhs, &rhs);
   return config.workspace_size;
 }
 
-#ifdef __CUDACC__
-#include "broadcast_reduce-inl.cuh"
-#endif
-
 #endif  // MXNET_USE_CUDA
 
 template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
@@ -784,7 +782,8 @@ void RTCReduce(const OpContext& ctx,
                const TBlob& big,
                const std::string& reducer,
                int ndim,
-               const std::string& OP);
+               const std::string& OP,
+               const bool use_index = false);
 
 void RTCReduce(const OpContext& ctx,
                const TBlob& small,
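
With the type_size argument gone, ReduceWorkspaceSize sizes the scratch buffer for the largest accumulation type internally (the 16-byte (index_t, double) pair noted in the broadcast_reduce-inl.h change above), so a single query works for any dtype. A sketch of pairing it with the RTC reduction, assuming ctx, s, and TBlobs out and in already reshaped so that their ndim matches the 2 passed to the kernel:

// Query once, allocate, then reduce; reducer and OP are RTC strings.
size_t ws_bytes = broadcast::ReduceWorkspaceSize(s, out.shape_, kWriteTo, in.shape_);
mshadow::Tensor<gpu, 1, char> ws =
    ctx.requested[0].get_space_typed<gpu, 1, char>(mshadow::Shape1(ws_bytes), s);
broadcast::RTCReduce(ctx, out, kWriteTo, ws, in, "red::sum{}", 2, "identity");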
diff --git a/src/operator/tensor/broadcast_reduce_minmax_value.cu b/src/operator/tensor/broadcast_reduce_minmax_value.cu
index baf79fe..c8cb757 100644
--- a/src/operator/tensor/broadcast_reduce_minmax_value.cu
+++ b/src/operator/tensor/broadcast_reduce_minmax_value.cu
@@ -28,13 +28,15 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(max)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+                                     {"identity", "red::maximum{}", false});
 
 NNVM_REGISTER_OP(_backward_max)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::eq>);
 
 NNVM_REGISTER_OP(min)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::minimum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+                                     {"identity", "red::minimum{}", false});
 
 NNVM_REGISTER_OP(_backward_min)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::eq>);
diff --git a/src/operator/tensor/broadcast_reduce_op.cc b/src/operator/tensor/broadcast_reduce_op.cc
new file mode 100644
index 0000000..483787e
--- /dev/null
+++ b/src/operator/tensor/broadcast_reduce_op.cc
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "broadcast_reduce_op.h"
+#include <limits>
+#include "../numpy/np_broadcast_reduce_op.h"
+#include "elemwise_binary_scalar_op.h"
+#include "mxnet/tuple.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUDA
+
+void ReduceAxesRTCComputeImpl(const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs,
+                              const mxnet::TShape& small,
+                              const std::string& reducer,
+                              const mshadow::Tensor<gpu, 1, char>* workspace,
+                              const bool normalize,
+                              const std::string& OP,
+                              const int ddof) {
+  using namespace mshadow;
+
+  mxnet::TShape src_shape, dst_shape;
+  BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  Tensor<gpu, 1, char> w;
+  if (workspace == nullptr) {
+    size_t workspace_size = broadcast::ReduceWorkspaceSize(
+        s, dst_shape, req[0], src_shape);
+    w = ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_size), s);
+    workspace = &w;
+  }
+  const TBlob in_data = inputs[0].reshape(src_shape);
+  const TBlob out_data = outputs[0].reshape(dst_shape);
+  BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
+    broadcast::RTCReduce(ctx, out_data, req[0], *workspace, in_data, reducer, NDim, OP);
+  });
+  if (normalize) {
+    NumpyBinaryScalarParam p{};
+    p.scalar = static_cast<double>(src_shape.Size()/dst_shape.Size() - ddof);
+    NodeAttrs a;
+    a.parsed = p;
+    BinaryScalarRTCCompute {"div"}(a, ctx, {out_data}, {kWriteInplace}, {out_data});
+  }
+}
+
+namespace {
+template <typename Param>
+void PrepareReduce(const Param& param,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<TBlob>& outputs,
+                   mxnet::TShape* shape, int* ddof);
+
+template <>
+void PrepareReduce<ReduceAxesParam>(const ReduceAxesParam& param,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<TBlob>& outputs,
+                                    mxnet::TShape* small, int* ddof) {
+  if (param.keepdims) {
+    *small = outputs[0].shape_;
+  } else {
+    *small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude);
+  }
+
+  *ddof = 0;
+}
+
+template <>
+void PrepareReduce<NumpyReduceAxesNoDTypeParam>(const NumpyReduceAxesNoDTypeParam& param,
+                                                const std::vector<TBlob>& inputs,
+                                                const std::vector<TBlob>& outputs,
+                                                mxnet::TShape* small, int* ddof) {
+  if (param.initial.has_value()) {
+    LOG(FATAL) << "initial is not supported yet";
+  }
+  if (param.keepdims) {
+    *small = outputs[0].shape_;
+  } else {
+    *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+
+  *ddof = 0;
+}
+
+template <>
+void PrepareReduce<NumpyReduceAxesParam>(const NumpyReduceAxesParam& param,
+                                         const std::vector<TBlob>& inputs,
+                                         const std::vector<TBlob>& outputs,
+                                         mxnet::TShape* small, int* ddof) {
+  if (param.initial.has_value()) {
+    LOG(FATAL) << "initial is not supported yet";
+  }
+  if (param.keepdims) {
+    *small = outputs[0].shape_;
+  } else {
+    *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+
+  *ddof = 0;
+}
+
+template <>
+void PrepareReduce<NumpyReduceAxesBoolParam>(const NumpyReduceAxesBoolParam& param,
+                                             const std::vector<TBlob>& inputs,
+                                             const std::vector<TBlob>& outputs,
+                                             mxnet::TShape* small, int* ddof) {
+  if (param.keepdims) {
+    *small = outputs[0].shape_;
+  } else {
+    *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+
+  *ddof = 0;
+}
+
+}  // namespace
+
+template <typename Param, int init>
+void ReduceAxesRTCCompute<Param, init>::operator()(const nnvm::NodeAttrs& attrs,
+                                                   const OpContext& ctx,
+                                                   const std::vector<TBlob>& inputs,
+                                                   const std::vector<OpReqType>& req,
+                                                   const std::vector<TBlob>& outputs) {
+  if (req[0] == kNullOp) return;
+  mxnet::TShape small;
+  int ddof;
+  const auto& param = nnvm::get<Param>(attrs.parsed);
+  CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place";
+  PrepareReduce(param, inputs, outputs, &small, &ddof);
+  if (outputs[0].shape_.Size() == 0U) return;  // zero-size tensor
+  if (inputs[0].shape_.Size() == 0) {
+    if (normalize && mxnet::common::is_float(outputs[0].type_flag_)) {
+      LOG(WARNING) << "WARNING: Mean of empty slice.";
+      NumpyBinaryScalarParam p{};
+      p.scalar = std::numeric_limits<float>::quiet_NaN();
+      NodeAttrs a;
+      a.parsed = p;
+      BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {kWriteTo}, outputs);
+    } else {
+      if (normalize) {
+        LOG(WARNING) << "WARNING: nan is outside the range of " <<
+                        "representable values of type 'int'";
+      }
+      if (init == 0 && req[0] == kAddTo) return;
+      NumpyBinaryScalarParam p{};
+      p.scalar = init;
+      NodeAttrs a;
+      a.parsed = p;
+      BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {req[0]}, outputs);
+    }
+    return;
+  }
+
+  ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, reducer, nullptr, normalize, OP, ddof);
+}
+
+template struct ReduceAxesRTCCompute<ReduceAxesParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesParam, 1>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesNoDTypeParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 1>;
+
+#endif
+
+}  // namespace op
+}  // namespace mxnet
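
The new broadcast_reduce_op.cc keeps ReduceAxesRTCCompute::operator() out of the headers and explicitly instantiates it for the parameter types that need it, so the .cu registration files can simply aggregate-initialize the functor with its OP/reducer/normalize values. A small standalone analogue of that pattern (plain C++, not MXNet code):

#include <iostream>
#include <string>

// Header-side declaration: only the members and the operator() signature.
template <typename Param, int init>
struct Compute {
  std::string OP;
  std::string reducer;
  bool normalize;
  void operator()() const;  // defined out of line, once
};

struct DemoParam {};

// .cc-side definition of the body...
template <typename Param, int init>
void Compute<Param, init>::operator()() const {
  std::cout << reducer << " of " << OP << "(x), normalize=" << normalize << '\n';
}

// ...followed by an explicit instantiation, mirroring
// `template struct ReduceAxesRTCCompute<ReduceAxesParam, 0>;` above.
template struct Compute<DemoParam, 0>;

int main() {
  // Registration-style use: aggregate-initialize the functor and call it.
  Compute<DemoParam, 0>{"identity", "red::sum{}", true}();
  return 0;
}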
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 872a898..7baf255 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -658,7 +658,9 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs,
-                           const mxnet::TShape& small) {
+                           const mxnet::TShape& small,
+                           const mshadow::Tensor<xpu, 1, char>* workspace = nullptr,
+                           const int ddof = 0) {
   using namespace mshadow;
   using namespace mshadow::expr;
 
@@ -670,15 +672,18 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
       const TBlob in_data = inputs[0].reshape(src_shape);
       const TBlob out_data = outputs[0].reshape(dst_shape);
       BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
-        size_t workspace_size = broadcast::ReduceWorkspaceSize(
-            s, out_data.shape_, req[0], in_data.shape_, sizeof(OType));
-        Tensor<xpu, 1, char> workspace =
-            ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+        Tensor<xpu, 1, char> w;
+        if (workspace == nullptr) {
+          size_t workspace_size = broadcast::ReduceWorkspaceSize(
+              s, out_data.shape_, req[0], in_data.shape_);
+          w = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+          workspace = &w;
+        }
         broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
-            s, out_data, req[0], workspace, in_data);
+            s, out_data, req[0], *workspace, in_data);
         if (normalize) {
           auto out = out_data.FlatTo2D<xpu, OType>(s);
-          out /= scalar<OType>(src_shape.Size()/dst_shape.Size());
+          out /= scalar<OType>(src_shape.Size()/dst_shape.Size() - ddof);
         }
       });
     });
@@ -704,7 +709,7 @@ void ReduceAxesComputeBoolImpl(const OpContext& ctx,
       const TBlob out_data = outputs[0].reshape(dst_shape);
       BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
         size_t workspace_size = broadcast::ReduceWorkspaceSize(
-            s, out_data.shape_, req[0], in_data.shape_, sizeof(OType));
+            s, out_data.shape_, req[0], in_data.shape_);
         Tensor<xpu, 1, char> workspace =
             ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
         broadcast::ReduceBool<reducer, NDim, DType, OP>(
@@ -736,6 +741,35 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs,
   ReduceAxesComputeImpl<xpu, reducer, false, normalize, OP>(ctx, inputs, req, outputs, small);
 }
 
+#if MXNET_USE_CUDA
+
+template <typename Param, int init>
+struct ReduceAxesRTCCompute {
+  std::string OP;
+  std::string reducer;
+  bool normalize;
+
+  void operator()(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs);
+};
+
+void ReduceAxesRTCComputeImpl(const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs,
+                              const mxnet::TShape& small,
+                              const std::string& reducer,
+                              const mshadow::Tensor<gpu, 1, char>* workspace = nullptr,
+                              const bool normalize = false,
+                              const std::string& OP = "identity",
+                              const int ddof = 0);
+
+#endif
+
 template <typename red_op, int req, int axis>
 struct ReduceCsrKernel;
 
@@ -1516,7 +1550,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
   } else {
     small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
   }
-  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+#if !defined(__CUDACC__)
+  bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
   if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
     common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LpNorm with float16 inputs. "
                     "See https://mxnet.apache.org/api/faq/env_var "
@@ -1539,6 +1574,15 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
         ctx, inputs, req, outputs, small);
     }
   }
+#else
+  const std::string &red = param.ord == 1
+                           ? "red::sum{}"
+                           : "red::nrm2{}";
+  const std::string &op = param.ord == 1
+                          ? "abs"
+                          : "identity";
+  ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, red, nullptr, false, op);
+#endif
 }
 
 template<int req>
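
In the GPU (#else) branch of LpNormCompute above, the ord parameter selects both the reducer and the element-wise map: ord == 1 sums |x|, while ord == 2 feeds the identity map into the nrm2 reducer. A hypothetical helper (not in the commit) that makes that mapping explicit:

#include <string>
#include <utility>

// ord == 1 -> plain sum over |x|; otherwise -> nrm2 reducer over identity(x).
inline std::pair<std::string, std::string> LpNormReducerAndOp(int ord) {
  return ord == 1
      ? std::make_pair(std::string("red::sum{}"), std::string("abs"))
      : std::make_pair(std::string("red::nrm2{}"), std::string("identity"));
}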
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index 35b3c02..f7c2834 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -37,7 +37,8 @@ NNVM_REGISTER_OP(broadcast_like)
 .set_attr<FCompute>("FCompute<gpu>", BroadcastCompute<gpu>);
 
 NNVM_REGISTER_OP(_broadcast_backward)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+                                                                              "red::sum{}", false});
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/broadcast_reduce_prod_value.cu b/src/operator/tensor/broadcast_reduce_prod_value.cu
index 5731de3..7e7a95b 100644
--- a/src/operator/tensor/broadcast_reduce_prod_value.cu
+++ b/src/operator/tensor/broadcast_reduce_prod_value.cu
@@ -28,13 +28,15 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(prod)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow_op::product>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+                                     {"identity", "red::product{}", false});
 
 NNVM_REGISTER_OP(_backward_prod)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::rdiv>);
 
 NNVM_REGISTER_OP(nanprod)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow_op::nanprod>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+                                     {"identity", "red::nanprod{}", false});
 
 NNVM_REGISTER_OP(_backward_nanprod)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::nanprod_grad>);
diff --git a/src/operator/tensor/broadcast_reduce_sum_value.cu b/src/operator/tensor/broadcast_reduce_sum_value.cu
index 2385d36..40a8ed8 100644
--- a/src/operator/tensor/broadcast_reduce_sum_value.cu
+++ b/src/operator/tensor/broadcast_reduce_sum_value.cu
@@ -28,19 +28,22 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(sum)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+                                                                           "red::sum{}", false});
 
 NNVM_REGISTER_OP(_backward_sum)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseNone<gpu>);
 
 NNVM_REGISTER_OP(mean)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum, true>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+                                                                           "red::sum{}", true});
 
 NNVM_REGISTER_OP(_backward_mean)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseNone<gpu, true>);
 
 NNVM_REGISTER_OP(nansum)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow_op::nansum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+                                                                           "red::nansum{}", false});
 
 NNVM_REGISTER_OP(_backward_nansum)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::nansum_grad>);
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.cc b/src/operator/tensor/elemwise_binary_broadcast_op.cc
index 2f9832a..9a682dc 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.cc
@@ -376,10 +376,10 @@ void BinaryBroadcastRTCBackwardUseNone::operator()(const nnvm::NodeAttrs& attrs,
       if (out.shape_.Size() != 0) {
         broadcast::RTCReduce(ctx, lhs, req[0],
                              workspace, out,
-                             "red::sum", NDim, LOP);
+                             "red::sum{}", NDim, LOP);
         broadcast::RTCReduce(ctx, rhs, req[1],
                              workspace, out,
-                             "red::sum", NDim, ROP);
+                             "red::sum{}", NDim, ROP);
       } else {
         using namespace common::cuda::rtc::util;
         if (lhs.shape_.Size() != 0) {
@@ -425,21 +425,21 @@ void BinaryBroadcastRTCBackwardUseIn::operator()(const nnvm::NodeAttrs& attrs,
         const TBlob rhs = inputs[2].reshape(new_rshape);
         size_t workspace_size_l = broadcast::ReduceWorkspaceSize(
             s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_,
-            rhs.shape_, common::mshadow_type_info(outputs[0].type_flag_).size);
+            rhs.shape_);
         size_t workspace_size_r = broadcast::ReduceWorkspaceSize(
             s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_,
-            rhs.shape_, common::mshadow_type_info(outputs[1].type_flag_).size);
+            rhs.shape_);
         size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
         Tensor<gpu, 1, char> workspace =
             ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_size), s);
         if (req[0] != kNullOp) {
           broadcast::RTCReduce(ctx, lgrad, req[0], workspace,
-                               ograd, lhs, rhs, "red::sum", NDim,
+                               ograd, lhs, rhs, "red::sum{}", NDim,
                                "mul", LOP);
         }
         if (req[1] != kNullOp) {
           broadcast::RTCReduce(ctx, rgrad, req[1], workspace,
-                               ograd, lhs, rhs, "red::sum", NDim,
+                               ograd, lhs, rhs, "red::sum{}", NDim,
                                "mul", ROP);
         }
     });
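
The reducer strings gain a trailing "{}" here because the source generated by the RTC path presumably substitutes them where an object, not a type, is expected: in reduce_rtc.cc below the kernel calls REDUCER.SetInitValue, REDUCER.Reduce, and REDUCER.Merge, so "red::sum{}" has to name a value-initialized temporary. A minimal standalone illustration of that idiom (not MXNet code):

#include <initializer_list>
#include <iostream>

namespace red {
struct sum {
  void SetInitValue(float& v, float& r) const { v = 0.f; r = 0.f; }
  void Reduce(float& v, float x, float& /*residual*/) const { v += x; }
};
}  // namespace red

// The reducer string is pasted where an object is needed, hence the "{}".
#define REDUCER red::sum{}

int main() {
  float v, r;
  REDUCER.SetInitValue(v, r);
  for (float x : {1.f, 2.f, 3.f}) REDUCER.Reduce(v, x, r);
  std::cout << v << '\n';  // prints 6
  return 0;
}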
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index a5bfdd7..b1700c7 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -629,9 +629,9 @@ inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
   const TBlob lhs = inputs[1].reshape(new_lshape);
   const TBlob rhs = inputs[2].reshape(new_rshape);
   size_t workspace_size_l = ReduceWorkspaceSize(
-      s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
+      s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_);
   size_t workspace_size_r = ReduceWorkspaceSize(
-      s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
+      s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
   size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
   Tensor<xpu, 1, char> workspace =
       ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 49a3ed2..24d6ca8 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -2046,8 +2046,13 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs,
     inputs[0].type_flag_, inputs[0].dev_id());
   std::vector<TBlob> newInputs = {iblob};
 
+#if !defined(__CUDACC__)
   ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false>(
       ctx, newInputs, req, newOutputs, rshapes.first);
+#else
+  ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first,
+                           "red::sum{}", nullptr, false);
+#endif
 }
 
 struct TileParam : public dmlc::Parameter<TileParam> {
@@ -2238,8 +2243,13 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs,
     inputs[0].type_flag_, inputs[0].dev_id());
   std::vector<TBlob> newInputs = {iblob};
 
+#if !defined(__CUDACC__)
   ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false>(
       ctx, newInputs, req, newOutputs, rshapes.first);
+#else
+  ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first,
+                           "red::sum{}", nullptr, false);
+#endif
 }
 
 struct ReverseParam : public dmlc::Parameter<ReverseParam> {
diff --git a/src/operator/tensor/reduce_rtc.cc b/src/operator/tensor/reduce_rtc.cc
index 9e2d6d3..ac39f6f 100644
--- a/src/operator/tensor/reduce_rtc.cc
+++ b/src/operator/tensor/reduce_rtc.cc
@@ -49,12 +49,39 @@ struct reduce_kernel_params {
 
 const char reduce_function_code[] = R"code(
 #define FUNC OP(IType0::from(big[idx_big[u]]))
+using AType = typename AccType<InputType0>::type;
 )code";
 
 const char reduce_function_use_input_code[] = R"code(
 #define FUNC OP1(IType0::from(big[idx_big[u]]),     \
                  OP2(IType1::from(lhs[idx_lhs[u]]), \
                      IType2::from(rhs[idx_rhs[u]])))
+using AType = typename AccType<InputType0>::type;
+)code";
+
+const char reduce_function_index_code[] = R"code(
+#define FUNC AType(OP(IType0::from(big[idx_big[u]])), index)
+
+template <typename T>
+struct AccIndex {
+  index_t idx;
+  T num;
+
+  __device__ inline AccIndex<T>() {}
+  __device__ inline AccIndex<T>(const T& val, const index_t idx) : num(val), idx(idx) {}
+
+  __device__ inline operator index_t() const volatile {
+    return idx;
+  }
+
+  __device__ inline AccIndex<T>& operator=(const AccIndex<T>& other) {
+    idx = other.idx;
+    num = other.num;
+    return *this;
+  }
+};
+
+using AType = AccIndex<typename AccType<InputType0>::type>;
 )code";
 
 const char reduce_kernel_code[] = R"code(
@@ -71,21 +98,107 @@ struct reduce_kernel_params {
   index_t rhs_shape[util::MAX_DIM];
 };
 
-__launch_bounds__(kRTCMaxThreadsPerBlock)
-__global__ void reduce_kernel(const int N, const int M, const bool addto,
-                              const InputType0* __restrict big,
-                              const InputType1* __restrict lhs,
-                              const InputType2* __restrict rhs,
-                              OutputType0 *small,
-                              const reduce_kernel_params params,
-                              const int Mnext) {
+inline __device__ AType reduce(const index_t idx, const int tidx,
+                               const int tidy, const int N,
+                               const index_t Mstart, const index_t Mend,
+                               const InputType0* __restrict big,
+                               const InputType1* __restrict lhs,
+                               const InputType2* __restrict rhs,
+                               const reduce_kernel_params& params) {
   extern __shared__ char shTileChar[];
   using IType0 = AccType<InputType0>;
   using IType1 = AccType<InputType1>;
   using IType2 = AccType<InputType2>;
   using OType = AccType<OutputType0>;
-  using AType = typename IType0::type;
   AType* shTile = (AType*)(shTileChar);
+  const int bx = (do_transpose) ? blockDim.y : blockDim.x;
+  const int by = (do_transpose) ? blockDim.x : blockDim.y;
+  index_t coord[ndim];
+  util::unravel(idx, params.small_shape, coord);
+  index_t idx_big0, idx_lhs0, idx_rhs0;
+  idx_big0 = util::ravel(coord, params.big_shape);
+  if (use_input) {
+    idx_lhs0 = util::ravel(coord, params.lhs_shape0);
+    idx_rhs0 = util::ravel(coord, params.rhs_shape0);
+  }
+
+  AType val, residual;
+  REDUCER.SetInitValue(val, residual);
+  if (idx < N) {
+    for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) {
+      index_t idx_big[UNROLL];
+      index_t idx_lhs[UNROLL];
+      index_t idx_rhs[UNROLL];
+      #pragma unroll
+      for (int u=0;u < UNROLL;u++) {
+        idx_big[u] = idx_big0 + util::unravel_dot<ndim>(k + u*by, params.rshape,
+                                                        params.rstride);
+        if (use_input) {
+          idx_lhs[u] = idx_lhs0 + util::unravel_dot<ndim>(k + u*by, params.lhs_shape,
+                                                          params.lhs_stride);
+          idx_rhs[u] = idx_rhs0 + util::unravel_dot<ndim>(k + u*by, params.rhs_shape,
+                                                          params.rhs_stride);
+        }
+      }
+      AType tmp[UNROLL];
+      #pragma unroll
+      for (int u=0;u < UNROLL;u++) {
+        if (k + u*by < Mend) {
+          const index_t index = k + u*by;
+          tmp[u] = FUNC;
+        }
+      }
+      #pragma unroll
+      for (int u=0;u < UNROLL;u++) {
+        if (k + u*by < Mend) REDUCER.Reduce(val, tmp[u], residual);
+      }
+    }
+  }
+
+  // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
+  if (by > 1) {
+    // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
+    const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
+    const int it0 = tidx + tidy*fbx;
+    shTile[it0 * 2] = val;
+    shTile[it0 * 2 + 1] = residual;
+    __syncthreads();
+    for (int t=1;t < by;t <<= 1) {
+      AType tmp, tmp_residual;
+      REDUCER.SetInitValue(tmp, tmp_residual);
+      if (tidy + t < by) {
+        tmp = shTile[(it0 + t*fbx) * 2];
+        tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
+      }
+      __syncthreads();
+      REDUCER.Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
+      __syncthreads();
+    }
+    if (idx < N && tidy == 0) {
+      REDUCER.Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
+      return shTile[tidx * 2];
+    } else {
+      return AType();
+    }
+  } else {
+    if (idx < N) {
+      REDUCER.Finalize(val, residual);
+      return val;
+    } else {
+      return AType();
+    }
+  }
+}
+
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void reduce_kernel_single(const int N, const int M,
+                                     const InputType0* __restrict big,
+                                     const InputType1* __restrict lhs,
+                                     const InputType2* __restrict rhs,
+                                     OutputType0 *small,
+                                     const reduce_kernel_params params,
+                                     const int Mnext) {
+  using OType = AccType<OutputType0>;
   const int tid = threadIdx.x + threadIdx.y*blockDim.x;
   const int bx = (do_transpose) ? blockDim.y : blockDim.x;
   const int by = (do_transpose) ? blockDim.x : blockDim.y;
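
The shared-memory stage of the refactored reduce() above interleaves (value, residual) pairs
in shTile and folds partials along the tidy dimension with REDUCER.Merge, doubling the merge
distance each step until tidy == 0 holds the block's result. A simplified standalone sketch of
that tree merge for a plain sum (no residual, no transposed layout, no bank-conflict padding):

    // Reduce a rows x cols matrix along its rows with a 2-D thread block;
    // the merge loop mirrors the shTile loop in the kernel above.
    __global__ void column_sum(const float* data, int rows, int cols, float* out) {
      extern __shared__ float tile[];  // blockDim.x * blockDim.y floats
      const int col = blockIdx.x * blockDim.x + threadIdx.x;
      const int it0 = threadIdx.x + threadIdx.y * blockDim.x;

      float val = 0.f;
      if (col < cols)
        for (int r = threadIdx.y; r < rows; r += blockDim.y)
          val += data[r * cols + col];  // strided per-thread partial sum
      tile[it0] = val;
      __syncthreads();

      for (int t = 1; t < (int)blockDim.y; t <<= 1) {
        // Read before anyone writes, then fold in the partial t rows below.
        const float other =
            (threadIdx.y + t < blockDim.y) ? tile[it0 + t * blockDim.x] : 0.f;
        __syncthreads();
        tile[it0] += other;
        __syncthreads();
      }
      if (col < cols && threadIdx.y == 0) out[col] = tile[threadIdx.x];
    }

Launched as column_sum<<<grid, dim3(bx, by), bx * by * sizeof(float)>>>(data, rows, cols, out)
with grid = (cols + bx - 1) / bx; the tidy == 0 row ends up with the full column sums, matching
the tidy == 0 write-out in the kernel above.
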
@@ -96,117 +209,74 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
     const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext);
     const index_t Mend   = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext);
     for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
-      int idx = idx0 + tidx;
-      index_t coord[ndim];
-      util::unravel(idx, params.small_shape, coord);
-      index_t idx_big0, idx_lhs0, idx_rhs0;
-      idx_big0 = util::ravel(coord, params.big_shape);
-      if (use_input) {
-        idx_lhs0 = util::ravel(coord, params.lhs_shape0);
-        idx_rhs0 = util::ravel(coord, params.rhs_shape0);
-      }
-
-      AType val, residual;
-      REDUCER::SetInitValue(val, residual);
-      if (idx < N) {
-        for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) {
-          index_t idx_big[UNROLL];
-          index_t idx_lhs[UNROLL];
-          index_t idx_rhs[UNROLL];
-          #pragma unroll
-          for (int u=0;u < UNROLL;u++) {
-            idx_big[u] = idx_big0 + util::unravel_dot<ndim>(k + u*by, params.rshape,
-                                                            params.rstride);
-            if (use_input) {
-              idx_lhs[u] = idx_lhs0 + util::unravel_dot<ndim>(k + u*by, params.lhs_shape,
-                                                              params.lhs_stride);
-              idx_rhs[u] = idx_rhs0 + util::unravel_dot<ndim>(k + u*by, params.rhs_shape,
-                                                              params.rhs_stride);
-            }
-          }
-          typename OType::type tmp[UNROLL];
-          #pragma unroll
-          for (int u=0;u < UNROLL;u++) {
-            if (k + u*by < Mend) {
-              tmp[u] = FUNC;
-            }
-          }
-          #pragma unroll
-          for (int u=0;u < UNROLL;u++) {
-            if (k + u*by < Mend) REDUCER::Reduce(val, tmp[u], residual);
-          }
+      const index_t idx = idx0 + tidx;
+      AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params);
+      if (idx < N && (by == 1 || tidy == 0)) {
+        if (req == OpReqType::kAddTo) {
+          small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]),
+                                                  static_cast<typename OType::type>(val)));
+        } else {
+          small[idx + m0 * N] = OType::to(val);
         }
       }
+    }
+  }
+}
 
-      // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
-      if (by > 1) {
-        // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
-        const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
-        const int it0 = tidx + tidy*fbx;
-        shTile[it0 * 2] = val;
-        shTile[it0 * 2 + 1] = residual;
-        __syncthreads();
-        for (int t=1;t < by;t <<= 1) {
-          AType tmp, tmp_residual;
-          REDUCER::SetInitValue(tmp, tmp_residual);
-          if (tidy + t < by) {
-            tmp = shTile[(it0 + t*fbx) * 2];
-            tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
-          }
-          __syncthreads();
-          REDUCER::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
-          __syncthreads();
-        }
-        if (idx < N && tidy == 0) {
-          REDUCER::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
-          if (addto) {
-            small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]),
-                                                    shTile[tidx * 2]));
-          } else {
-            small[idx + m0 * N] = OType::to(shTile[tidx * 2]);
-          }
-        }
-      } else {
-        if (idx < N) {
-          REDUCER::Finalize(val, residual);
-          if (addto) {
-            small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]),
-                                                    val));
-          } else {
-            small[idx + m0 * N] = OType::to(val);
-          }
-        }
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void reduce_kernel_multi(const int N, const int M,
+                                    const InputType0* __restrict big,
+                                    const InputType1* __restrict lhs,
+                                    const InputType2* __restrict rhs,
+                                    AType *small,
+                                    const reduce_kernel_params params,
+                                    const int Mnext) {
+  const int tid = threadIdx.x + threadIdx.y*blockDim.x;
+  const int bx = (do_transpose) ? blockDim.y : blockDim.x;
+  const int by = (do_transpose) ? blockDim.x : blockDim.y;
+  const int tidx = (do_transpose) ? tid / by : threadIdx.x;
+  const int tidy = (do_transpose) ? tid % by : threadIdx.y;
+  for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
+    // This TB handles M range [Mstart, ...., Mend - 1]
+    const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext);
+    const index_t Mend   = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext);
+    for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
+      const index_t idx = idx0 + tidx;
+      AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params);
+      if (idx < N && (by == 1 || tidy == 0)) {
+        small[idx + m0 * N] = val;
       }
     }
   }
 }
+
 )code";
 
 const char reduce_lines_kernel_code[] = R"code(
 __launch_bounds__(kRTCMaxThreadsPerBlock)
 __global__ void reduce_lines_kernel(const index_t N, const index_t M,
                                     const index_t small_in_stride,
-                                    const OutputType0* __restrict small_in,
+                                    const AType* __restrict small_in,
                                     OutputType0 *small_out) {
   using OType = AccType<OutputType0>;
   for (index_t idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-    typename OType::type val, residual;
-    REDUCER::SetInitValue(val, residual);
+    AType val, residual;
+    REDUCER.SetInitValue(val, residual);
     for (int k = 0; k < M; k++) {
-      REDUCER::Reduce(val,
-        OType::from(reinterpret_cast<const OutputType0*>(small_in)[idx + k*small_in_stride]),
+      REDUCER.Reduce(val,
+        small_in[idx + k*small_in_stride],
         residual);
     }
 
     if (idx < N) {
-      REDUCER::Finalize(val, residual);
+      REDUCER.Finalize(val, residual);
       if (req == OpReqType::kAddTo) {
-        small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), val));
+        small_out[idx] = OType::to(op::add(OType::from(small_out[idx]),
+                                           static_cast<typename OType::type>(val)));
       } else {
         small_out[idx] = OType::to(val);
       }
     }
-
   }
 }
 )code";
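
Taken together, this replaces the old addto flag threading through the first kernel: when
Mnext > 1, reduce_kernel_multi writes one partial accumulator per output element and per chunk
of the reduced axis into the workspace (typed as AType), and reduce_lines_kernel then folds the
Mnext partials and applies req exactly once. A minimal two-pass sketch of that structure for a
plain float sum (fixed launch shape, no RTC, illustrative names):

    // Pass 1: each (blockIdx.y) chunk reduces rows [Mstart, Mend) of an M x N input,
    // writing a partial per (chunk, column) - the role of reduce_kernel_multi.
    __global__ void partial_sums(const float* big, int N, int M, int Mnext, float* partials) {
      const int idx = blockIdx.x * blockDim.x + threadIdx.x;  // output element in [0, N)
      const int m0  = blockIdx.y;                             // chunk of the reduced axis
      if (idx >= N || m0 >= Mnext) return;
      const int Mstart = (int)((long long)M * m0 / Mnext);
      const int Mend   = (int)((long long)M * (m0 + 1) / Mnext);
      float acc = 0.f;
      for (int k = Mstart; k < Mend; ++k)
        acc += big[(long long)k * N + idx];
      partials[(long long)m0 * N + idx] = acc;
    }

    // Pass 2: fold the Mnext partials per output element - the role of reduce_lines_kernel.
    __global__ void fold_partials(const float* partials, int N, int Mnext, float* small) {
      const int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx >= N) return;
      float acc = 0.f;
      for (int k = 0; k < Mnext; ++k)
        acc += partials[(long long)k * N + idx];
      small[idx] = acc;
    }

The workspace would hold N * Mnext floats here; pass 1 runs on a ((N + 255) / 256, Mnext) grid
of 256-thread blocks and pass 2 on (N + 255) / 256 blocks.
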
@@ -215,14 +285,13 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
                 const TBlob& big, const Tensor<gpu, 1, char>& workspace,
                 const ReduceImplConfig& config, const int ndim,
                 const std::string &common_code, int dev_id,
-                const TBlob *lhs = nullptr, const TBlob *rhs = nullptr) {
+                const TBlob *lhs = nullptr, const TBlob *rhs = nullptr,
+                const bool use_index = false) {
   using namespace common::cuda::rtc;
   void* small_dptr = small.dptr_;
-  bool first_kernel_addto = addto;
   if (config.Mnext > 1) {
     // small_dptr[] is N*Mnext*sizeof(DType) bytes
     small_dptr = workspace.dptr_;
-    first_kernel_addto = false;
     // Check that the workspace is contiguous
     CHECK_EQ(workspace.CheckContiguous(), true);
     // Check that we have enough storage
@@ -281,7 +350,6 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
   std::vector<const void*> args;
   args.emplace_back(&config.N);
   args.emplace_back(&config.M);
-  args.emplace_back(&first_kernel_addto);
   args.emplace_back(&big.dptr_);
   if (lhs != nullptr) {
     args.emplace_back(&(lhs->dptr_));
@@ -295,10 +363,11 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
   args.emplace_back(&config.Mnext);
 
   const auto &function_code = (lhs == nullptr)
-                            ? reduce_function_code
+                            ? (use_index ? reduce_function_index_code : reduce_function_code)
                             : reduce_function_use_input_code;
+  const auto& kernel_name = (config.Mnext > 1) ? "reduce_kernel_multi" : "reduce_kernel_single";
   auto reduce_kernel_func = get_function(code + function_code,
-                                         "reduce_kernel",
+                                         kernel_name,
                                          reduce_kernel_code,
                                          dev_id);
   launch(reduce_kernel_func, config.kernel_1.gridDim,
@@ -313,7 +382,7 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
     args.emplace_back(&small_dptr);
     args.emplace_back(&small.dptr_);
 
-    auto reduce_lines_kernel_func = get_function(code,
+    auto reduce_lines_kernel_func = get_function(code + function_code,
                                                  "reduce_lines_kernel",
                                                  reduce_lines_kernel_code,
                                                  dev_id);
@@ -348,9 +417,11 @@ __global__ void reduce_kernel_M1(const int N,
   using IType1 = AccType<InputType1>;
   using IType2 = AccType<InputType2>;
   using OType = AccType<OutputType0>;
-  for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
+  for (index_t index = threadIdx.x + blockIdx.x*blockDim.x;
+       index < N;
+       index += blockDim.x*gridDim.x) {
     index_t coord[ndim];
-    util::unravel(idx, params.small_shape, coord);
+    util::unravel(index, params.small_shape, coord);
     index_t idx_big[1];
     idx_big[0] = util::ravel(coord, params.big_shape);
     index_t idx_lhs[1], idx_rhs[1];
@@ -358,16 +429,17 @@ __global__ void reduce_kernel_M1(const int N,
       idx_lhs[0] = util::ravel(coord, params.lhs_shape);
       idx_rhs[0] = util::ravel(coord, params.rhs_shape);
     }
-    typename OType::type val, residual;
-    REDUCER::SetInitValue(val, residual);
+    AType val, residual;
+    REDUCER.SetInitValue(val, residual);
     const int u = 0;
-    REDUCER::Reduce(val, FUNC, residual);
-    REDUCER::Finalize(val, residual);
+    REDUCER.Reduce(val, FUNC, residual);
+    REDUCER.Finalize(val, residual);
     if (req == OpReqType::kAddTo) {
-      const auto temp = op::add(val, OType::from(small[idx]));
-      small[idx] = OType::to(temp);
+      const auto temp = op::add(static_cast<typename OType::type>(val),
+                                OType::from(small[index]));
+      small[index] = OType::to(temp);
     } else {
-      small[idx] = OType::to(val);
+      small[index] = OType::to(static_cast<typename OType::type>(val));
     }
   }
 }
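
The switch from REDUCER:: to REDUCER. above reflects that the reducer is now injected as an
object expression (e.g. the "red::sum{}" string passed to ReduceAxesRTCComputeImpl earlier), so
the generated kernels call its members through an instance. Any type exposing SetInitValue /
Reduce / Merge / Finalize with a residual slot fits that contract; below is an illustrative
Kahan-compensated sum with this interface (a sketch, not MXNet's red::sum):

    // Illustrative reducer matching the calls made by the generated kernels.
    // Static members still work with the REDUCER. syntax, since statics may be
    // invoked through an instance such as kahan_sum{}.
    struct kahan_sum {
      template <typename T>
      __device__ static void SetInitValue(T& sum, T& residual) {
        sum = T(0);
        residual = T(0);
      }
      template <typename T>
      __device__ static void Reduce(T& sum, const T& x, T& residual) {
        const T y = x - residual;  // re-inject the rounding error lost so far
        const T t = sum + y;
        residual = (t - sum) - y;
        sum = t;
      }
      template <typename T>
      __device__ static void Merge(T& dst, T& dst_residual, const T& src, const T& src_residual) {
        Reduce(dst, src, dst_residual);
        dst_residual += src_residual;  // carry the other partial's compensation forward
      }
      template <typename T>
      __device__ static void Finalize(T& sum, T& residual) {
        (void)sum; (void)residual;     // nothing left to do for a compensated sum
      }
    };
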
@@ -376,7 +448,8 @@ __global__ void reduce_kernel_M1(const int N,
 void RTCReduceM1Impl(Stream<gpu> *s, const TBlob &small, const TBlob &big,
                      const TBlob *lhs, const TBlob *rhs,
                      const ReduceImplConfig &config, const int ndim,
-                     const std::string &common_code, int dev_id) {
+                     const std::string &common_code, int dev_id,
+                     const bool use_index = false) {
   using namespace common::cuda::rtc;
 
   std::string code = common_code +
@@ -427,7 +500,7 @@ void RTCReduceM1Impl(Stream<gpu> *s, const TBlob &small, const TBlob &big,
   args.emplace_back(&param);
 
   const auto &function_code = (lhs == nullptr)
-                            ? reduce_function_code
+                            ? (use_index ? reduce_function_index_code : reduce_function_code)
                             : reduce_function_use_input_code;
   auto reduce_kernel_M1_func = get_function(code + function_code,
                                             "reduce_kernel_M1",
@@ -447,14 +520,12 @@ void RTCReduce(const OpContext& ctx,
                const TBlob& big,
                const std::string& reducer,
                int ndim,
-               const std::string& OP) {
+               const std::string& OP,
+               const bool use_index) {
   using namespace mxnet::common::cuda::rtc;
   if (req == kNullOp) return;
   Stream<gpu> *s = ctx.get_stream<gpu>();
-  size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size;
-  size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size;
-  size_t type_size = std::max(big_type_size, small_type_size);
-  ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, type_size);
+  ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr);
   std::string common_code = std::string("const OpReqType req = ") +
                             util::to_string(req) +
                             ";\n"
@@ -469,10 +540,12 @@ void RTCReduce(const OpContext& ctx,
                             ";\n";
   if (config.M == 1) {
     RTCReduceM1Impl(s, small, big, nullptr, nullptr, config,
-                    ndim, common_code, ctx.run_ctx.ctx.dev_id);
+                    ndim, common_code, ctx.run_ctx.ctx.dev_id,
+                    use_index);
   } else {
     RTCReduceImpl(s, small, req == kAddTo, big, workspace, config,
-                  ndim, common_code, ctx.run_ctx.ctx.dev_id);
+                  ndim, common_code, ctx.run_ctx.ctx.dev_id,
+                  nullptr, nullptr, use_index);
   }
 }
 
@@ -490,10 +563,7 @@ void RTCReduce(const OpContext& ctx,
   using namespace mxnet::common::cuda::rtc;
   if (req == kNullOp) return;
   Stream<gpu> *s = ctx.get_stream<gpu>();
-  size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size;
-  size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size;
-  size_t type_size = std::max(big_type_size, small_type_size);
-  ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, type_size);
+  ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_);
   std::string common_code = std::string("const OpReqType req = ") +
                             util::to_string(req) +
                             ";\n"
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index ba8e327..1fc7b8e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3006,8 +3006,10 @@ def test_np_binary_funcs():
                 if isinstance(dtype, tuple):
                     assert len(dtype) == 2
                     ldtype, rdtype = dtype
-                np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype)
-                np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype)
+                npldtype = ldtype if dtype != _np.float16 else _np.float32
+                nprdtype = rdtype if dtype != _np.float16 else _np.float32
+                np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype).astype(npldtype)
+                np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype).astype(nprdtype)
                 mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ldtype)
                 mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rdtype)
                 for hybridize in [True, False]:
@@ -4372,7 +4374,7 @@ def test_np_argmin_argmax():
         ((3, 5, 7), 2, False),
         ((3, 5, 7, 9, 11), -3, False),
     ]
-    dtypes = ['float16', 'float32', 'float64']
+    dtypes = ['float16', 'float32', 'float64', 'bool', 'int32']
     ops = ['argmin', 'argmax']
 
     class TestArgExtreme(HybridBlock):
@@ -4387,7 +4389,7 @@ def test_np_argmin_argmax():
     for op_name in ops:
         for shape, axis, throw_exception in workloads:
             for dtype in dtypes:
-                a = np.random.uniform(size=shape, dtype=dtype)
+                a = np.random.uniform(low=0, high=100, size=shape).astype(dtype)
                 if throw_exception:
                     # Cannot use assert_exception because sometimes the main thread
                     # proceeds to `assert False` before the exception is thrown