Posted to commits@mxnet.apache.org by di...@apache.org on 2021/05/25 01:40:31 UTC
[incubator-mxnet] branch master updated: [FEATURE] Use RTC for reduction ops (#19426)
This is an automated email from the ASF dual-hosted git repository.
dickjc123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 57d0ace [FEATURE] Use RTC for reduction ops (#19426)
57d0ace is described below
commit 57d0ace3a9b89032896822dd27af870a9b446d83
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Mon May 24 18:38:16 2021 -0700
[FEATURE] Use RTC for reduction ops (#19426)
* Initial rebase
* Fixes after merge
* Fixes
* Fix lint
* Fix lint for real
* Cleaning and code reuse
* Fix lint
* Try to WAR the maybe-uninitialized warning
* Second try
* Fix Windows compilation
* More fixes for Windows compilation
* Breaking the strings to please Windows compiler
* Do not use the default stream in kron
* Fix argmin/argmax
* Fix layernorm
---
src/common/cuda/rtc.cc | 5 +-
src/common/cuda/rtc/backward_functions-inl.h | 240 ++++++++----
src/common/cuda/rtc/forward_functions-inl.h | 9 +
src/common/cuda/rtc/reducer-inl.h | 316 +++++++++++-----
src/common/cuda/rtc/special_functions-inl.h | 5 -
src/common/cuda/rtc/util-inl.h | 138 +++++++
src/operator/mshadow_op.h | 10 +-
src/operator/nn/group_norm-inl.h | 153 +++++---
src/operator/nn/layer_norm-inl.h | 226 ++++-------
src/operator/nn/layer_norm.cc | 116 ++++++
src/operator/nn/layer_norm.cu | 83 +++++
src/operator/nn/moments-inl.h | 15 +-
.../linalg/broadcast_reduce_customized-inl.cuh | 415 ---------------------
.../numpy/linalg/broadcast_reduce_customized-inl.h | 7 -
.../numpy/linalg/broadcast_reduce_op_customized.h | 24 +-
src/operator/numpy/linalg/np_norm-inl.h | 123 +++---
src/operator/numpy/np_broadcast_reduce_op.cc | 85 +++++
src/operator/numpy/np_broadcast_reduce_op.cuh | 44 ---
src/operator/numpy/np_broadcast_reduce_op.h | 262 +++++++------
.../numpy/np_broadcast_reduce_op_boolean.cu | 8 +-
src/operator/numpy/np_broadcast_reduce_op_index.cu | 4 +-
src/operator/numpy/np_broadcast_reduce_op_value.cu | 16 +-
src/operator/numpy/np_constraint_check.h | 5 +
src/operator/numpy/np_cross-inl.h | 10 +-
src/operator/numpy/np_elemwise_broadcast_op.h | 4 +-
src/operator/numpy/np_kron-inl.h | 33 +-
src/operator/numpy/np_tensordot_op-inl.h | 18 +-
src/operator/numpy/np_where_op-inl.h | 27 +-
src/operator/numpy/random/dist_common.h | 80 ++++
src/operator/numpy/random/np_exponential_op.h | 7 +-
src/operator/numpy/random/np_gamma_op.h | 7 +-
src/operator/numpy/random/np_location_scale_op.h | 72 +---
src/operator/numpy/random/np_normal_op.h | 74 +---
src/operator/numpy/random/np_pareto_op.h | 28 +-
src/operator/numpy/random/np_rayleigh_op.h | 28 +-
src/operator/numpy/random/np_weibull_op.h | 28 +-
src/operator/quantization/quantization_utils.h | 2 +-
src/operator/quantization/quantize_v2-inl.h | 7 +
src/operator/quantization/requantize-inl.h | 13 +
src/operator/random/pdf_op.h | 34 +-
src/operator/tensor/broadcast_reduce-inl.cuh | 414 --------------------
src/operator/tensor/broadcast_reduce-inl.h | 31 +-
.../tensor/broadcast_reduce_minmax_value.cu | 6 +-
src/operator/tensor/broadcast_reduce_op.cc | 187 ++++++++++
src/operator/tensor/broadcast_reduce_op.h | 62 ++-
src/operator/tensor/broadcast_reduce_op_value.cu | 3 +-
src/operator/tensor/broadcast_reduce_prod_value.cu | 6 +-
src/operator/tensor/broadcast_reduce_sum_value.cu | 9 +-
.../tensor/elemwise_binary_broadcast_op.cc | 12 +-
src/operator/tensor/elemwise_binary_broadcast_op.h | 4 +-
src/operator/tensor/matrix_op-inl.h | 10 +
src/operator/tensor/reduce_rtc.cc | 316 ++++++++++------
tests/python/unittest/test_numpy_op.py | 10 +-
53 files changed, 1944 insertions(+), 1907 deletions(-)
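For readers unfamiliar with the feature being extended here: RTC refers to runtime compilation of CUDA kernels with NVRTC, so reduction kernels can be generated on the fly from source strings (such as the reducer and logic_reducer blocks further down) instead of being instantiated at build time for every type and operator combination. The stand-alone sketch below shows the general NVRTC flow only; it is not the repository's get_function implementation, and the kernel string and names are invented for illustration (error checking omitted).

    // Illustrative only: compile and load a CUDA kernel from a source string at runtime.
    #include <nvrtc.h>
    #include <cuda.h>
    #include <cstdio>
    #include <vector>

    int main() {
      const char* src = R"code(
        extern "C" __global__ void scale(float* x, float a, int n) {
          int i = blockIdx.x * blockDim.x + threadIdx.x;
          if (i < n) x[i] *= a;
        }
      )code";

      // Compile the source string to PTX with NVRTC.
      nvrtcProgram prog;
      nvrtcCreateProgram(&prog, src, "scale.cu", 0, nullptr, nullptr);
      const char* opts[] = {"--gpu-architecture=compute_70"};
      nvrtcCompileProgram(prog, 1, opts);
      size_t ptx_size;
      nvrtcGetPTXSize(prog, &ptx_size);
      std::vector<char> ptx(ptx_size);
      nvrtcGetPTX(prog, ptx.data());
      nvrtcDestroyProgram(&prog);

      // Load the PTX with the driver API and look up the kernel.
      cuInit(0);
      CUdevice dev;  cuDeviceGet(&dev, 0);
      CUcontext ctx; cuCtxCreate(&ctx, 0, dev);
      CUmodule mod;  cuModuleLoadData(&mod, ptx.data());
      CUfunction fn; cuModuleGetFunction(&fn, mod, "scale");
      // fn can now be launched with cuLaunchKernel just like a statically built kernel.
      printf("compiled %zu bytes of PTX\n", ptx_size);
      return 0;
    }

The rtc.cc hunk below simply adds the new limits, grad_function_definitions and logic_reducer source fragments to the string that gets compiled this way.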
diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc
index 2284bee..af4abbe 100644
--- a/src/common/cuda/rtc.cc
+++ b/src/common/cuda/rtc.cc
@@ -150,13 +150,16 @@ CUfunction get_function(const std::string &parameters,
std::string(fp16_support_string) + "\n" +
type_support_string + "\n" +
util_string + "\n" +
+ limits + "\n" +
special_functions_definitions + '\n' +
vectorization_support_string + "\n" +
function_definitions_util + "\n" +
function_definitions_binary + "\n" +
function_definitions_unary + "\n" +
backward_function_definitions + "\n" +
- reducer + "\n";
+ grad_function_definitions + "\n" +
+ reducer + "\n" +
+ logic_reducer + "\n";
std::string code_with_header = common_header + parameters + code;
// If verbose mode, output kernel source, though not including the common header
if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index 64ec251..cb1bae8 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -238,6 +238,98 @@ backward_square(const DTypeGrad grad, const DType val) {
}
template <typename DType, typename DType2>
+__device__ inline DType div_rgrad(const DType val,
+ const DType2 val2) {
+ return -val / (val2 * val2);
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_clip(const DTypeGrad grad, const DType val,
+ const float a_min, const float a_max) {
+ if (val > a_max || val < a_min) {
+ return 0;
+ } else {
+ return grad;
+ }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_reciprocal(const DTypeGrad grad, const DType val) {
+ return -grad / (val * val);
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_erf(const DTypeGrad grad, const DType val) {
+ using type = mixed_type<DTypeGrad, DType>;
+ const type v = val;
+ constexpr type my_pi = pi;
+ return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_erfinv(const DTypeGrad grad, const DType val) {
+ using type = mixed_type<DTypeGrad, DType>;
+ constexpr type my_pi = pi;
+ const type g = grad;
+ const type v = val;
+ return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gamma(const DTypeGrad grad, const DType val) {
+ using type = mixed_type<DTypeGrad, DType>;
+ const type v = val;
+ if (type_util::is_same<DTypeGrad, double>::value) {
+ return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
+ } else {
+ return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
+ }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gammaln(const DTypeGrad grad, const DType val) {
+ using type = mixed_type<DTypeGrad, DType>;
+ const type v = val;
+ if (type_util::is_same<DTypeGrad, double>::value) {
+ return grad * op::special_functions::cephes::psi<double>(v);
+ } else {
+ return grad * op::special_functions::cephes::psi<float>(v);
+ }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_digamma(const DTypeGrad grad, const DType val) {
+ using type = mixed_type<DTypeGrad, DType>;
+ const type v = val;
+ if (type_util::is_same<DTypeGrad, double>::value) {
+ return grad * op::special_functions::trigamma<double>(v);
+ } else {
+ return grad * op::special_functions::trigamma<float>(v);
+ }
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gelu(const DTypeGrad grad, const DType val) {
+ return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
+ val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
+}
+
+} // namespace op
+
+)code";
+
+const char grad_function_definitions[] = R"code(
+namespace op {
+
+template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rdiv_grad(const DType val,
const DType2 val2) {
@@ -253,12 +345,6 @@ div_grad(const DType val,
}
template <typename DType, typename DType2>
-__device__ inline DType div_rgrad(const DType val,
- const DType2 val2) {
- return -val / (val2 * val2);
-}
-
-template <typename DType, typename DType2>
__device__ inline DType mod_grad(const DType val,
const DType2 val2) {
if (type_util::is_integral<DType>::value) {
@@ -368,80 +454,6 @@ rldexp_grad(const DType val,
return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_clip(const DTypeGrad grad, const DType val,
- const float a_min, const float a_max) {
- if (val > a_max || val < a_min) {
- return 0;
- } else {
- return grad;
- }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_reciprocal(const DTypeGrad grad, const DType val) {
- return -grad / (val * val);
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_erf(const DTypeGrad grad, const DType val) {
- const mixed_type<DTypeGrad, DType> v = val;
- constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
- return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_erfinv(const DTypeGrad grad, const DType val) {
- constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
- const mixed_type<DTypeGrad, DType> g = grad;
- const mixed_type<DTypeGrad, DType> v = val;
- return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gamma(const DTypeGrad grad, const DType val) {
- const mixed_type<DTypeGrad, DType> v = val;
- if (type_util::is_same<DTypeGrad, double>::value) {
- return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
- } else {
- return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
- }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gammaln(const DTypeGrad grad, const DType val) {
- const mixed_type<DTypeGrad, DType> v = val;
- if (type_util::is_same<DTypeGrad, double>::value) {
- return grad * op::special_functions::cephes::psi<double>(v);
- } else {
- return grad * op::special_functions::cephes::psi<float>(v);
- }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_digamma(const DTypeGrad grad, const DType val) {
- const mixed_type<DTypeGrad, DType> v = val;
- if (type_util::is_same<DTypeGrad, double>::value) {
- return grad * op::special_functions::trigamma<double>(v);
- } else {
- return grad * op::special_functions::trigamma<float>(v);
- }
-}
-
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gelu(const DTypeGrad grad, const DType val) {
- return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
- val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
-}
-
template <typename DType, typename DType2>
__device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
auto bsq = scalar * scalar;
@@ -467,8 +479,74 @@ __device__ inline DType prelu_grad(const DType val,
return (val > 0) ? 0 : val;
}
-} // namespace op
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType2, DType>
+gamma_implicit_grad(const DType a_in, const DType2 x_in) {
+ using OType = mixed_type<DType2, DType>;
+ const OType a = a_in;
+ const OType x = x_in;
+ if (x < 0.8f) {
+ OType numer = 1;
+ OType denom = a;
+ OType series1 = numer / denom;
+ OType series2 = numer / (denom * denom);
+#pragma unroll
+ for (int i = 1; i <= 5; i++) {
+ numer *= -x / static_cast<DType>(i);
+ denom += 1;
+ series1 += numer / denom;
+ series2 += numer / (denom * denom);
+ }
+ OType pow_x_alpha = op::power(x, a);
+ OType gamma_pdf = op::power(x, a - 1) * op::exp(-x);
+ OType gamma_cdf = pow_x_alpha * series1;
+ OType gamma_cdf_alpha =
+ (op::log(x) - OType(special_functions::cephes::psi<float>(a))) *
+ gamma_cdf -
+ pow_x_alpha * series2;
+ OType result = -gamma_cdf_alpha / gamma_pdf;
+ return op::isnan(result) ? 0.f : result;
+ }
+ if (a > 8.0f) {
+ if (0.9f * a <= x && x <= 1.1f * a) {
+ OType numer_1 = 1 + 24 * a * (1 + 12 * a);
+ OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
+ 65 * x * x / a + a * (107 + 3600 * x);
+ OType denom = 1244160 * (a * a) * (a * a);
+ return numer_1 * numer_2 / denom;
+ }
+ OType denom = op::sqrt(8 * a);
+ OType term2 = denom / (a - x);
+ OType term3 =
+ op::power(x - a - a * op::log(x / a), static_cast<OType>(-1.5));
+ OType term23 = (x < a) ? term2 - term3 : term2 + term3;
+ OType term1 = op::log(x / a) * term23 -
+ op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
+ OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a));
+ OType numer = x * term1;
+ return -stirling * numer / denom;
+ }
+ OType u = op::log(x / a);
+ OType v = op::log(a);
+ OType coef_uv[3][8] = {
+ {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
+ 0.10406089, 0.0014179084},
+ {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
+ 0.020070113, -0.0035938915, -0.00058392623},
+ {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
+ -0.0021309326, 0.00085092367, -1.5247877e-07},
+ };
+ OType coef_v[8];
+#pragma unroll
+ for (int i = 0; i < 8; i++) {
+ coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
+ }
+ OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
+ OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
+ return op::exp(p / q);
+}
+} // namespace op
)code";
} // namespace rtc
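For context on gamma_implicit_grad above: it appears to implement the implicit reparameterization gradient for gamma samples, i.e. the derivative of a sample x with respect to the shape parameter a obtained from the CDF F(x; a) rather than from an explicit sampling path. In LaTeX (my summary, not text from the commit):

    \frac{\partial x}{\partial a}
      = -\frac{\partial F(x;a)/\partial a}{\partial F(x;a)/\partial x}
      = -\frac{\partial F(x;a)/\partial a}{p(x;a)},
    \qquad p(x;a) = \frac{x^{a-1}e^{-x}}{\Gamma(a)}.

The x < 0.8f branch evaluates this ratio from a truncated series for the lower incomplete gamma function; the 1/\Gamma(a) factor cancels between numerator and denominator, which is why the code can work with the unnormalized gamma_pdf and gamma_cdf. The a > 8.0f branches use asymptotic approximations, and the final branch is a rational fit in u = log(x/a) and v = log(a).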
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index 9018a5d..4f87db6 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -696,6 +696,10 @@ __device__ inline DType log_sigmoid(const DType val) {
template <typename DType>
__device__ inline DType softrelu(const DType val) {
+ // Avoid overflow of exp for large inputs.
+ // The threshold 20 is chosen such that softrelu(a) = a
+ // for a > 20 using floating precision.
+ if (val > 20) return val;
if (type_util::has_double_or_integral<DType>::value) {
return ::log(1 + ::exp(val));
} else {
@@ -936,6 +940,11 @@ __device__ inline bool_t np_logical_not(const DType val) {
return !static_cast<bool>(val);
}
+template <typename DType>
+__device__ inline bool_t NonZero(const DType val) {
+ return val != 0;
+}
+
#undef DEFINE_UNARY_MATH_FUNC
template <typename DType>
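A quick check of the softrelu threshold added earlier in this file's diff (my arithmetic, not part of the commit): for a > 20 the gap between softrelu(a) and a is already below single-precision resolution, so the early return changes nothing observable while avoiding overflow of exp for large inputs:

    \operatorname{softrelu}(a) - a = \log(1 + e^{a}) - a = \log(1 + e^{-a}) \le e^{-a} < e^{-20} \approx 2.1\times 10^{-9},

which is far smaller than the spacing between consecutive floats near 20 (about 2^{-19} \approx 1.9\times 10^{-6}), so the sum rounds back to a.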
diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h
index 259d0e0..f5b70d8 100644
--- a/src/common/cuda/rtc/reducer-inl.h
+++ b/src/common/cuda/rtc/reducer-inl.h
@@ -27,11 +27,10 @@ namespace common {
namespace cuda {
namespace rtc {
-const char reducer[] = R"code(
+const char reducer[] = R"code(
namespace red {
-/*! \brief sum reducer */
struct sum {
/*! \brief do reduction into dst */
template<typename DType, typename DType2>
@@ -95,103 +94,6 @@ struct sum {
}
};
-/*! \brief maximum reducer */
-struct maximum {
- /*! \brief do reduction into dst */
- template<typename DType, typename DType2>
- __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*)
- if (!util::isnan(dst)) {
- if (!(dst >= src)) dst = src;
- }
- }
- /*! \brief do reduction into dst */
- template<typename DType, typename DType2>
- __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
- volatile DType& none) {
- Reduce(dst, src);
- }
- /*! \brief combine the results of two reducers */
- template<typename DType>
- __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
- Reduce(dst_val, src_val);
- }
- /*! \brief combine the results of two reducers */
- template<typename DType>
- __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
- volatile DType& src_val, volatile DType& src_residual) {
- Reduce(dst_val, src_val);
- }
- /*! \brief finalize reduction result */
- template<typename DType>
- __device__ inline static void Finalize(volatile DType& dst) {}
- /*! \brief finalize reduction result */
- template<typename DType>
- __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
- /*!
- *\brief set the initial value during reduction
- */
- template<typename DType>
- __device__ inline static void SetInitValue(DType &initv) {
- initv = -2*DBL_MAX;
- }
- /*!
- *\brief set the initial value during reduction
- */
- template<typename DType>
- __device__ inline static void SetInitValue(DType &initv, DType &none) {
- SetInitValue(initv);
- }
-};
-
-/*! \brief minimum reducer */
-struct minimum {
- /*! \brief do reduction into dst */
- template<typename DType, typename DType2>
- __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
- if (!util::isnan(dst)) {
- if (!(dst <= src)) dst = src;
- }
- }
- /*! \brief do reduction into dst */
- template<typename DType, typename DType2>
- __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
- volatile DType& none) {
- Reduce(dst, src);
- }
- /*! \brief combine the results of two reducers */
- template<typename DType>
- __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
- Reduce(dst_val, src_val);
- }
- /*! \brief combine the results of two reducers */
- template<typename DType>
- __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
- volatile DType& src_val, volatile DType& src_residual) {
- Reduce(dst_val, src_val);
- }
- /*! \brief finalize reduction result */
- template<typename DType>
- __device__ inline static void Finalize(volatile DType& dst) {}
- /*! \brief finalize reduction result */
- template<typename DType>
- __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
- /*!
- *\brief set the initial value during reduction
- */
- template<typename DType>
- __device__ inline static void SetInitValue(DType &initv) {
- initv = 2*DBL_MAX;
- }
- /*!
- *\brief set the initial value during reduction
- */
- template<typename DType>
- __device__ inline static void SetInitValue(DType &initv, DType &none) {
- SetInitValue(initv);
- }
-};
-
-/*! \brief product reducer */
struct product {
/*! \brief do reduction into dst */
template<typename DType, typename DType2>
@@ -237,7 +139,6 @@ struct product {
}
};
-/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
/*! \brief do reduction into dst */
template<typename DType, typename DType2>
@@ -293,7 +194,6 @@ struct nansum {
}
};
-/*! \brief product reducer that ignores NaN values in the input */
struct nanprod {
/*! \brief do reduction into dst */
template<typename DType, typename DType2>
@@ -493,10 +393,222 @@ struct nrmlp {
scale = 0;
}
};
-} // namespace red
+} // namespace red
)code";
+const char logic_reducer[] = R"code(
+namespace red {
+
+struct maximum {
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) { // NOLINT(*)
+ if (!util::isnan(dst)) {
+ if (!(dst >= src)) dst = src;
+ }
+ }
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
+ volatile DType& none) {
+ Reduce(dst, src);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+ volatile DType& src_val, volatile DType& src_residual) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv) {
+ initv = limits::NegInfValue<DType>();
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &none) {
+ SetInitValue(initv);
+ }
+};
+
+struct minimum {
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src) {
+ if (!util::isnan(dst)) {
+ if (!(dst <= src)) dst = src;
+ }
+ }
+ /*! \brief do reduction into dst */
+ template<typename DType, typename DType2>
+ __device__ inline static void Reduce(volatile DType& dst, volatile DType2 src,
+ volatile DType& none) {
+ Reduce(dst, src);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
+ volatile DType& src_val, volatile DType& src_residual) {
+ Reduce(dst_val, src_val);
+ }
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction result */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv) {
+ initv = limits::PosInfValue<DType>();
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &none) {
+ SetInitValue(initv);
+ }
+};
+
+struct argmax {
+ /*! \brief do reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline static void Reduce(volatile AType& dst, volatile DType src) {
+ if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief do stable reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline static void Reduce(volatile AType& dst, volatile DType src,
+ volatile DType&) {
+ if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
+ if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst, volatile DType&,
+ volatile DType& src, volatile DType&) {
+ if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief finalize reduction */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType&) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv) {
+ initv.num = limits::NegInfValue<decltype(initv.num)>();
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &) {
+ initv.num = limits::NegInfValue<decltype(initv.num)>();
+ }
+};
+
+struct argmin {
+ /*! \brief do reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline static void Reduce(volatile AType& dst, volatile DType src) {
+ if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief do stable reduction into dst */
+ template<typename AType, typename DType>
+ __device__ inline static void Reduce(volatile AType& dst, volatile DType src,
+ volatile DType& residual) {
+ if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst, volatile DType& src) {
+ if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief combine the results of two reducers */
+ template<typename DType>
+ __device__ inline static void Merge(volatile DType& dst, volatile DType&,
+ volatile DType& src, volatile DType&) {
+ if (dst.num > src.num || (dst.num == src.num && dst.idx > src.idx)) {
+ dst.num = src.num;
+ dst.idx = src.idx;
+ }
+ }
+ /*! \brief finalize reduction */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst) {}
+ /*! \brief finalize reduction */
+ template<typename DType>
+ __device__ inline static void Finalize(volatile DType& dst, volatile DType& residual) {}
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv) {
+ initv.num = limits::PosInfValue<decltype(initv.num)>();
+ }
+ /*!
+ *\brief set the initial value during reduction
+ */
+ template<typename DType>
+ __device__ inline static void SetInitValue(DType &initv, DType &residual) {
+ initv.num = limits::PosInfValue<decltype(initv.num)>();
+ }
+};
+} // namespace red
+)code";
} // namespace rtc
} // namespace cuda
} // namespace common
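All of the reducer structs above (and the ones left in reducer) share the same static protocol: SetInitValue establishes the identity element, Reduce folds one input into the accumulator, Merge combines partial results from different threads or blocks, and Finalize post-processes the accumulator. A minimal host-side sketch of how such a struct is driven, using a made-up maximum-style reducer rather than the RTC strings themselves:

    #include <cstdio>
    #include <limits>

    // Illustrative reducer following the Reduce/Merge/SetInitValue/Finalize protocol.
    struct max_reducer {
      template <typename T>
      static void SetInitValue(T& initv) { initv = std::numeric_limits<T>::lowest(); }
      template <typename T>
      static void Reduce(T& dst, T src) { if (!(dst >= src)) dst = src; }
      template <typename T>
      static void Merge(T& dst, T& src) { Reduce(dst, src); }
      template <typename T>
      static void Finalize(T&) {}
    };

    int main() {
      const float data[] = {1.f, 7.f, -3.f, 4.f};
      float partial[2];
      // Two "threads", each reducing half of the input, then merged.
      for (int t = 0; t < 2; ++t) {
        max_reducer::SetInitValue(partial[t]);
        for (int i = 2 * t; i < 2 * t + 2; ++i)
          max_reducer::Reduce(partial[t], data[i]);
      }
      max_reducer::Merge(partial[0], partial[1]);
      max_reducer::Finalize(partial[0]);
      printf("max = %g\n", partial[0]);  // prints 7
      return 0;
    }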
diff --git a/src/common/cuda/rtc/special_functions-inl.h b/src/common/cuda/rtc/special_functions-inl.h
index d64afb5..7110e47 100644
--- a/src/common/cuda/rtc/special_functions-inl.h
+++ b/src/common/cuda/rtc/special_functions-inl.h
@@ -50,11 +50,6 @@ namespace rtc {
// Direct inquiries to 30 Frost Street, Cambridge, MA 02140
//
const char special_functions_definitions[] = R"code(
-constexpr double DBL_MAX = 1.7976931348623157081e+308;
-constexpr float FLT_MAX = 3.4028234663852885981e+38;
-#define inf ((float)1e50)
-#define nan (inf - inf)
-
namespace op {
namespace special_functions {
diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h
index b426603..bafa8cf 100644
--- a/src/common/cuda/rtc/util-inl.h
+++ b/src/common/cuda/rtc/util-inl.h
@@ -446,6 +446,144 @@ __device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const int
} // namespace util
)code";
+
+const char limits[] = R"code(
+constexpr double DBL_MAX = 1.7976931348623157081e+308;
+constexpr float FLT_MAX = 3.4028234663852885981e+38;
+#define inf ((float)1e50)
+#define nan (inf - inf)
+
+namespace limits {
+
+template<typename DType>
+__device__ inline DType MinValue(void);
+
+template<>
+__device__ inline float MinValue<float>(void) {
+ return -FLT_MAX;
+}
+/*! \brief minimum value of double */
+template<>
+__device__ inline double MinValue<double>(void) {
+ return -DBL_MAX;
+}
+/*! \brief minimum value of uint8 */
+template<>
+__device__ inline uint8 MinValue<uint8>(void) {
+ return 0;
+}
+/*! \brief minimum value of int8_t */
+template<>
+__device__ inline int8 MinValue<int8>(void) {
+ return -128;
+}
+/*! \brief minimum value of int32 */
+template<>
+__device__ inline int32 MinValue<int32>(void) {
+ return -2147483648;
+}
+/*! \brief minimum value of int64_t */
+template<>
+__device__ inline int64 MinValue<int64>(void) {
+ return -9223372036854775808LL;
+}
+/*! \brief minimum value of bool */
+template<>
+__device__ inline bool MinValue<bool>(void) {
+ return false;
+}
+/*! \brief minimum value of bool_t */
+template<>
+__device__ inline bool_t MinValue<bool_t>(void) {
+ return MinValue<index_t>();
+}
+
+/*!
+ * \brief negative infinity of certain types
+ * \tparam DType data type
+ */
+template<typename DType>
+__device__ inline DType NegInfValue(void) {
+ return MinValue<DType>();
+}
+/*! \brief negative infinity value of float */
+template<>
+__device__ inline float NegInfValue<float>(void) {
+ return -inf;
+}
+/*! \brief negative infinity value of double */
+template<>
+__device__ inline double NegInfValue<double>(void) {
+ return -inf;
+}
+
+/*!
+ * \brief maximum value of certain types
+ * \tparam DType data type
+ */
+template<typename DType>
+__device__ inline DType MaxValue(void);
+/*! \brief maximum value of float */
+template<>
+__device__ inline float MaxValue<float>(void) {
+ return FLT_MAX;
+}
+/*! \brief maximum value of double */
+template<>
+__device__ inline double MaxValue<double>(void) {
+ return DBL_MAX;
+}
+/*! \brief maximum value of uint8 */
+template<>
+__device__ inline uint8 MaxValue<uint8>(void) {
+ return 255;
+}
+/*! \brief maximum value of int8 */
+template<>
+__device__ inline int8 MaxValue<int8>(void) {
+ return 127;
+}
+/*! \brief maximum value of int32 */
+template<>
+__device__ inline int32 MaxValue<int32>(void) {
+ return 2147483647;
+}
+/*! \brief maximum value of int64 */
+template<>
+__device__ inline int64 MaxValue<int64>(void) {
+ return 9223372036854775807LL;
+}
+/*! \brief maximum value of bool */
+template<>
+__device__ inline bool MaxValue<bool>(void) {
+ return true;
+}
+/*! \brief maximum value of bool_t */
+template<>
+__device__ inline bool_t MaxValue<bool_t>(void) {
+ return MaxValue<index_t>();
+}
+/*!
+ * \brief positive infinity of certain types
+ * \tparam DType data type
+ */
+template<typename DType>
+__device__ inline DType PosInfValue(void) {
+ return MaxValue<DType>();
+}
+/*! \brief positive infinity value of float */
+template<>
+__device__ inline float PosInfValue<float>(void) {
+ return inf;
+}
+/*! \brief positive infinity value of double */
+template<>
+__device__ inline double PosInfValue<double>(void) {
+ return inf;
+}
+
+} // namespace limits
+)code";
} // namespace rtc
} // namespace cuda
} // namespace common
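The new limits helpers mirror what std::numeric_limits provides on the host: ±infinity as the reduction identity for floating-point types, falling back to the type's minimum/maximum for integral and bool accumulators. That keeps SetInitValue well defined for every DType, where the old initv = -2*DBL_MAX initializer only made sense for floating-point accumulators (my reading of the change; the commit does not spell this out). A host-side analogue of the values involved:

    #include <cstdio>
    #include <limits>

    int main() {
      // Identity elements for a max-style reduction, per accumulator type.
      printf("float  : %g\n", -std::numeric_limits<float>::infinity());    // NegInfValue<float>
      printf("double : %g\n", -std::numeric_limits<double>::infinity());   // NegInfValue<double>
      printf("int32  : %d\n",  std::numeric_limits<int>::lowest());        // MinValue<int32>
      printf("int64  : %lld\n",
             static_cast<long long>(std::numeric_limits<long long>::lowest()));
      printf("uint8  : %u\n",
             static_cast<unsigned>(std::numeric_limits<unsigned char>::lowest()));  // 0
      return 0;
    }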
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index c33dad4..be49f0f 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -940,7 +940,7 @@ template<>
MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map<mshadow::half::half_t>
(mshadow::half::half_t a,
mshadow::half::half_t b) {
- return mshadow::half::half_t(-::floorf(static_cast<float>(a/b)));
+ return mshadow::half::half_t(-::floorf(static_cast<float>(a)/static_cast<float>(b)));
}
struct rmod : public mxnet_op::tunable {
@@ -1573,7 +1573,7 @@ struct argmax {
/*! \brief do reduction into dst */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
- if (dst.num < src.num) {
+ if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
dst.num = src.num;
dst.idx = src.idx;
}
@@ -1581,7 +1581,7 @@ struct argmax {
/*! \brief do stable reduction into dst */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
- if (dst.num < src.num) {
+ if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
dst.num = src.num;
dst.idx = src.idx;
}
@@ -1589,7 +1589,7 @@ struct argmax {
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
- if (dst_val.num < src_val.num) {
+ if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
dst_val.num = src_val.num;
dst_val.idx = src_val.idx;
}
@@ -1597,7 +1597,7 @@ struct argmax {
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
- if (dst_val.num < src_val.num) {
+ if (dst_val.num < src_val.num || (dst_val.num == src_val.num && dst_val.idx > src_val.idx)) {
dst_val.num = src_val.num;
dst_val.idx = src_val.idx;
}
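The argmax/argmin change above adds a tie-break on the smaller index: with the old condition dst.num < src.num, the winning index among equal values depended on traversal and merge order, so a parallel reduction could return different indices from run to run. A small illustration of the new rule, using a hypothetical value/index pair rather than mshadow's:

    #include <cstdio>

    struct ValueIndex { float num; long idx; };

    // New rule: keep the larger value; on ties, keep the smaller index.
    void reduce_argmax(ValueIndex& dst, const ValueIndex& src) {
      if (dst.num < src.num || (dst.num == src.num && dst.idx > src.idx)) {
        dst.num = src.num;
        dst.idx = src.idx;
      }
    }

    int main() {
      const float data[] = {5.f, 9.f, 9.f, 2.f};
      ValueIndex best{data[3], 3};      // deliberately start from the last element
      for (long i = 2; i >= 0; --i)     // and traverse in reverse
        reduce_argmax(best, ValueIndex{data[i], i});
      printf("argmax = %ld (value %g)\n", best.idx, best.num);  // prints 1, not 2
      return 0;
    }

With the old rule the same reverse traversal would have kept index 2.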
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
index da30192..0df0db2 100644
--- a/src/operator/nn/group_norm-inl.h
+++ b/src/operator/nn/group_norm-inl.h
@@ -113,24 +113,29 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, char> workspace;
- size_t workspace_size = 0;
- MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
- workspace_size =
- broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0],
- red_src_shape, sizeof(DType));
- });
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0],
+ red_src_shape);
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
// Calculate mean
+#if !defined(__CUDACC__)
MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
s, mean_, req[0], workspace, data_);
- Tensor<xpu, 1, DType> mean_data_tensor = mean_.FlatTo1D<xpu, DType>(s);
- mean_data_tensor /= scalar<DType>(channel_size);
});
});
+#else
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, mean_, req[0], workspace,
+ data_, "red::sum{}", NDim, "identity");
+ });
+#endif // !defined(__CUDACC__)
+ MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+ Tensor<xpu, 1, DType> mean_data_tensor = mean_.FlatTo1D<xpu, DType>(s);
+ mean_data_tensor /= scalar<DType>(channel_size);
+ });
TBlob data_grp = data.reshape(temp_data_shape);
const TBlob& mean_grp = mean.reshape(moments_shape);
@@ -150,15 +155,25 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
// Calculate std
const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
+#if !defined(__CUDACC__)
MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
s, std_, req[0], workspace, centered_out);
- Tensor<xpu, 1, DType> std_data_tensor = std_.FlatTo1D<xpu, DType>(s);
- std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
- + scalar<DType>(param.eps));
});
});
+#else
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, std_, req[0],
+ workspace, centered_out,
+ "red::sum{}", NDim, "square");
+ });
+#endif
+ MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
+ Tensor<xpu, 1, DType> std_data_tensor = std_.FlatTo1D<xpu, DType>(s);
+ std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+ + scalar<DType>(param.eps));
+ });
// Calculate data = data / std
#if !defined(__CUDACC__)
@@ -263,26 +278,17 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
// Initialize the workspace + Construct the temporary TBlobs
Tensor<xpu, 1, char> workspace;
- size_t reduce_workspace_size = 0;
- size_t data_size = 0;
- size_t red_out_size = 0;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- data_size = sizeof(DType) * data.Size();
- red_out_size = sizeof(DType) * mean.Size();
- // There are two types of reduction workloads: reduce over axis and reduce exclude axis
- // We take the maximum of the workspace sizes required by these workloads.
- // Also, we explicitly set the req_type=kAddto in case we want to use it.
- reduce_workspace_size =
- std::max(reduce_workspace_size,
- broadcast::ReduceWorkspaceSize(s, red_dst_shape,
- kAddTo, red_src_shape,
- sizeof(DType)));
- reduce_workspace_size =
- std::max(reduce_workspace_size,
- broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo,
- red_exclude_src_shape,
- sizeof(DType)));
- });
+ size_t dtype_size = common::mshadow_type_info(outputs[0].type_flag_).size;
+ size_t data_size = data.Size() * dtype_size;
+ size_t red_out_size = mean.Size() * dtype_size;
+ // There are two types of reduction workloads: reduce over axis and reduce exclude axis
+ // We take the maximum of the workspace sizes required by these workloads.
+ // Also, we explicitly set the req_type=kAddto in case we want to use it.
+ size_t reduce_workspace_size =
+ std::max(broadcast::ReduceWorkspaceSize(s, red_dst_shape,
+ kAddTo, red_src_shape),
+ broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo,
+ red_exclude_src_shape));
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
const TBlob normalized_data =
@@ -300,14 +306,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{normalized_data, std_},
{kWriteTo}, {normalized_data});
-#else
- BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
- {data_, mean_},
- {kWriteTo}, {normalized_data});
- BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
- {normalized_data, std_},
- {kWriteTo}, {normalized_data});
-#endif // !defined(__CUDACC__)
// Calculate grad_beta
if (req[2] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
@@ -319,13 +317,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
});
}
// Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
-#if !defined(__CUDACC__)
ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
{kWriteTo}, {ograd_mult});
-#else
- ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
- {kWriteTo}, {ograd_mult});
-#endif // !defined(__CUDACC__)
if (req[1] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
@@ -335,6 +328,32 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
});
});
}
+#else
+ BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+ {data_, mean_},
+ {kWriteTo}, {normalized_data});
+ BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+ {normalized_data, std_},
+ {kWriteTo}, {normalized_data});
+ // Calculate grad_beta
+ if (req[2] != kNullOp) {
+ BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape),
+ req[2], workspace, ograd.reshape(red_exclude_src_shape),
+ "red::sum{}", NDim, "identity");
+ });
+ }
+ // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
+ ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
+ {kWriteTo}, {ograd_mult});
+ if (req[1] != kNullOp) {
+ BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape),
+ req[1], workspace, ograd_mult.reshape(red_exclude_src_shape),
+ "red::sum{}", NDim, "identity");
+ });
+ }
+#endif // !defined(__CUDACC__)
// Calculate grad_data:
// ograd_mult = ograd * gamma / std
@@ -350,15 +369,6 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{ograd_mult, std_},
{kWriteTo}, {ograd_mult});
-#else
- BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
- {inputs[0], gamma},
- {kWriteTo},
- {ograd_mult.reshape(data.shape_)});
- BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
- {ograd_mult, std_},
- {kWriteTo}, {ograd_mult});
-#endif // !defined(__CUDACC__)
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
@@ -368,19 +378,11 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
red_out_tensor /= scalar<DType>(N);
});
-#if !defined(__CUDACC__)
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{ograd_mult, red_out},
{req[0]}, {output_});
ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
{kWriteTo}, {ograd_mult});
-#else
- BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
- {ograd_mult, red_out},
- {req[0]}, {output_});
- ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
- {kWriteTo}, {ograd_mult});
-#endif // !defined(__CUDACC__)
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
@@ -390,12 +392,39 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
red_out_tensor /= scalar<DType>(-N);
});
-#if !defined(__CUDACC__)
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
{normalized_data, red_out},
{kAddTo}, {output_});
#else
BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
+ {inputs[0], gamma},
+ {kWriteTo},
+ {ograd_mult.reshape(data.shape_)});
+ BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+ {ograd_mult, std_},
+ {kWriteTo}, {ograd_mult});
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+ });
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+ red_out_tensor /= scalar<DType>(N);
+ });
+ BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+ {ograd_mult, red_out},
+ {req[0]}, {output_});
+ ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
+ {kWriteTo}, {ograd_mult});
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+ });
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+ red_out_tensor /= scalar<DType>(-N);
+ });
+ BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
{normalized_data, red_out},
{kAddTo}, {output_});
#endif // !defined(__CUDACC__)
diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index d8c8dbc..79d0906 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -38,6 +38,7 @@
#include "../operator_common.h"
#include "../mxnet_op.h"
#include "../tensor/broadcast_reduce_op.h"
+#include "mxnet/tuple.h"
namespace mxnet {
namespace op {
@@ -115,14 +116,11 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
int channel_size = red_src_shape.Size() / red_dst_shape.Size();
// Initialize the workspace
Tensor<xpu, 1, char> workspace;
- size_t workspace_size = 0;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- workspace_size =
- broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0],
- in_data.shape_, sizeof(DType));
- });
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(s, mean_data.shape_, req[0],
+ in_data.shape_);
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for float16 inputs for LayerNorm. "
@@ -145,15 +143,9 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
});
});
// Calculate data = data - mean
-#if !defined(__CUDACC__)
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{inputs[0], outputs[layernorm::kMean]},
{kWriteTo}, {outputs[0]});
-#else
- BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
- {inputs[0], outputs[layernorm::kMean]},
- {kWriteTo}, {outputs[0]});
-#endif // !defined(__CUDACC__)
// Calculate std
const TBlob centered_out = outputs[0].reshape(red_src_shape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
@@ -170,7 +162,6 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
+ scalar<DType>(param.eps));
});
});
-#if !defined(__CUDACC__)
// Calculate data = data / std
BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
{outputs[0], outputs[layernorm::kStd]},
@@ -184,6 +175,30 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
{outputs[0], beta},
{kWriteTo}, {outputs[0]});
#else
+ // Calculate mean
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, mean_data, req[0], workspace, in_data,
+ "red::sum{}", NDim, "identity");
+ });
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
+ mean_data_tensor /= scalar<DType>(channel_size);
+ });
+ // Calculate data = data - mean
+ BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+ {inputs[0], outputs[layernorm::kMean]},
+ {kWriteTo}, {outputs[0]});
+ // Calculate std
+ const TBlob centered_out = outputs[0].reshape(red_src_shape);
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, std_data, req[0], workspace, centered_out,
+ "red::sum{}", NDim, "square");
+ });
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
+ std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+ + scalar<DType>(param.eps));
+ });
// Calculate data = data / std
BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
{outputs[0], outputs[layernorm::kStd]},
@@ -196,7 +211,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
BinaryBroadcastRTCCompute {"add"}(attrs, ctx,
{outputs[0], beta},
{kWriteTo}, {outputs[0]});
-#endif // !defined(__CUDACC__)
+#endif
}
template<typename xpu>
@@ -205,6 +220,26 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);
+template<typename xpu>
+void LayerNormGradComputeGeneralImpl(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const TBlob& ograd,
+ const TBlob& data,
+ const TBlob& gamma,
+ const TBlob& mean,
+ const TBlob& std,
+ const TBlob& normalized_data,
+ const TBlob& ograd_mult,
+ const TBlob& red_out,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs,
+ const mshadow::Tensor<xpu, 1, char>& workspace,
+ const mxnet::TShape& red_dst_shape,
+ const mxnet::TShape& red_src_shape,
+ const mxnet::TShape& red_exclude_dst_shape,
+ const mxnet::TShape& red_exclude_src_shape,
+ const int channel_size);
+
/*
Calculate the gradient of layer normalization.
We have the following gradient for gamma, beta and x:
@@ -250,26 +285,17 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
int channel_size = red_src_shape.Size() / red_dst_shape.Size();
// Initialize the workspace + Construct the temporary TBlobs
Tensor<xpu, 1, char> workspace;
- size_t reduce_workspace_size = 0;
- size_t data_size = 0;
- size_t red_out_size = 0;
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- data_size = sizeof(DType) * data.Size();
- red_out_size = sizeof(DType) * mean.Size();
- // There are two types of reduction workloads: reduce over axis and reduce exclude axis
- // We take the maximum of the workspace sizes required by these workloads.
- // Also, we explicitly set the req_type=kAddto in case we want to use it.
- reduce_workspace_size =
- std::max(reduce_workspace_size,
- broadcast::ReduceWorkspaceSize(s, red_dst_shape,
- kAddTo, red_src_shape,
- sizeof(DType)));
- reduce_workspace_size =
- std::max(reduce_workspace_size,
+ size_t dtype_size = common::mshadow_type_info(outputs[0].type_flag_).size;
+ size_t data_size = data.Size() * dtype_size;
+ size_t red_out_size = mean.Size() * dtype_size;
+ // There are two types of reduction workloads: reduce over axis and reduce exclude axis
+ // We take the maximum of the workspace sizes required by these workloads.
+ // Also, we explicitly set the req_type=kAddto in case we want to use it.
+ size_t reduce_workspace_size =
+ std::max(broadcast::ReduceWorkspaceSize(s, red_dst_shape,
+ kAddTo, red_src_shape),
broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo,
- red_exclude_src_shape,
- sizeof(DType)));
- });
+ red_exclude_src_shape));
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
const TBlob normalized_data = TBlob(workspace.dptr_ + reduce_workspace_size,
@@ -278,135 +304,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
ograd.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id());
const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
- // Compute normalized_data = (data - mean) / std
-#if !defined(__CUDACC__)
- BinaryBroadcastCompute<xpu, mshadow_op::minus>(attrs, ctx,
- {data, mean},
- {kWriteTo}, {normalized_data});
- BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
- {normalized_data, std},
- {kWriteTo}, {normalized_data});
-#else
- BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
- {data, mean},
- {kWriteTo}, {normalized_data});
- BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
- {normalized_data, std},
- {kWriteTo}, {normalized_data});
-#endif // !defined(__CUDACC__)
- // Calculate grad_beta
- bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
- if (req[2] != kNullOp) {
- MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
- BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
- if (!safe_acc) {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
- s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
- ograd.reshape(red_exclude_src_shape));
- } else {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
- s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
- ograd.reshape(red_exclude_src_shape));
- }
- });
- });
- }
- // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
-#if !defined(__CUDACC__)
- ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
- {kWriteTo}, {ograd_mult});
-#else
- ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
- {kWriteTo}, {ograd_mult});
-#endif // !defined(__CUDACC__)
- if (req[1] != kNullOp) {
- MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
- BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
- if (!safe_acc) {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
- s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
- ograd_mult.reshape(red_exclude_src_shape));
- } else {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
- s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
- ograd_mult.reshape(red_exclude_src_shape));
- }
- });
- });
- }
- // Calculate grad_data:
- // ograd_mult = ograd * gamma / std
- // grad_data = ograd_mult - mean(ograd_mult, axis)
- // + normalized_data * (-mean(normalized_data * ograd_mult, axis))
- if (req[0] != kNullOp) {
-#if !defined(__CUDACC__)
- BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
- {ograd, gamma},
- {kWriteTo}, {ograd_mult});
- BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
- {ograd_mult, std},
- {kWriteTo}, {ograd_mult});
-#else
- BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
- {ograd, gamma},
- {kWriteTo}, {ograd_mult});
- BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
- {ograd_mult, std},
- {kWriteTo}, {ograd_mult});
-#endif // !defined(__CUDACC__)
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
- if (!safe_acc) {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
- s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
- ograd_mult.reshape(red_src_shape));
- } else {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
- s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
- ograd_mult.reshape(red_src_shape));
- }
- });
- Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
- red_out_tensor /= scalar<DType>(channel_size);
- });
-#if !defined(__CUDACC__)
- BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
- {ograd_mult, red_out},
- {req[0]}, {outputs[0]});
- ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
- {kWriteTo}, {ograd_mult});
-#else
- BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
- {ograd_mult, red_out},
- {req[0]}, {outputs[0]});
- ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
- {kWriteTo}, {ograd_mult});
-#endif // !defined(__CUDACC__)
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
- if (!safe_acc) {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
- s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
- ograd_mult.reshape(red_src_shape));
- } else {
- broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
- s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
- ograd_mult.reshape(red_src_shape));
- }
- });
- Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
- red_out_tensor /= scalar<DType>(- channel_size);
- });
-#if !defined(__CUDACC__)
- BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
- {normalized_data, red_out},
- {kAddTo}, {outputs[0]});
-#else
- BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
- {normalized_data, red_out},
- {kAddTo}, {outputs[0]});
-#endif // !defined(__CUDACC__)
- }
+
+ LayerNormGradComputeGeneralImpl(attrs, ctx, ograd, data, gamma, mean, std, normalized_data,
+ ograd_mult, red_out, req, outputs, workspace, red_dst_shape,
+ red_src_shape, red_exclude_dst_shape, red_exclude_src_shape,
+ channel_size);
}
} // namespace op
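The layer_norm refactor above pulls the shared gradient body into LayerNormGradComputeGeneralImpl, declared once in the header and specialized per backend (the CPU definition lands in layer_norm.cc below, the GPU one in layer_norm.cu). A minimal, self-contained sketch of that declare-then-specialize pattern, with invented names rather than MXNet's:

    #include <cstdio>

    struct cpu {};
    struct gpu {};

    // Declared in the "header"; no generic definition is provided.
    template <typename xpu>
    void GradImpl(float* grad, const float* ograd, int n);

    // Would live in the .cc translation unit.
    template <>
    void GradImpl<cpu>(float* grad, const float* ograd, int n) {
      for (int i = 0; i < n; ++i) grad[i] = ograd[i];
    }

    // Would live in the .cu translation unit (here just a stub).
    template <>
    void GradImpl<gpu>(float* grad, const float* ograd, int n) {
      printf("launch RTC-compiled kernels here\n");
    }

    int main() {
      float og[2] = {1.f, 2.f}, g[2];
      GradImpl<cpu>(g, og, 2);
      printf("%g %g\n", g[0], g[1]);  // 1 2
      return 0;
    }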
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 0884720..1a040fa 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -268,6 +268,122 @@ void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
LayerNormComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}
+template <>
+void LayerNormGradComputeGeneralImpl<cpu>(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const TBlob& ograd,
+ const TBlob& data,
+ const TBlob& gamma,
+ const TBlob& mean,
+ const TBlob& std,
+ const TBlob& normalized_data,
+ const TBlob& ograd_mult,
+ const TBlob& red_out,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs,
+ const mshadow::Tensor<cpu, 1, char>& workspace,
+ const mxnet::TShape& red_dst_shape,
+ const mxnet::TShape& red_src_shape,
+ const mxnet::TShape& red_exclude_dst_shape,
+ const mxnet::TShape& red_exclude_src_shape,
+ const int channel_size) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ Stream<cpu> *s = ctx.get_stream<cpu>();
+ // Compute normalized_data = (data - mean) / std
+ BinaryBroadcastCompute<cpu, mshadow_op::minus>(attrs, ctx,
+ {data, mean},
+ {kWriteTo}, {normalized_data});
+ BinaryBroadcastCompute<cpu, mshadow_op::div>(attrs, ctx,
+ {normalized_data, std},
+ {kWriteTo}, {normalized_data});
+ // Calculate grad_beta
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+ if (req[2] != kNullOp) {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+ BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+ if (!safe_acc) {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+ s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+ ograd.reshape(red_exclude_src_shape));
+ } else {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+ s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+ ograd.reshape(red_exclude_src_shape));
+ }
+ });
+ });
+ }
+ // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
+ ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
+ {kWriteTo}, {ograd_mult});
+ if (req[1] != kNullOp) {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+ BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+ if (!safe_acc) {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+ s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+ ograd_mult.reshape(red_exclude_src_shape));
+ } else {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+ s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+ ograd_mult.reshape(red_exclude_src_shape));
+ }
+ });
+ });
+ }
+ // Calculate grad_data:
+ // ograd_mult = ograd * gamma / std
+ // grad_data = ograd_mult - mean(ograd_mult, axis)
+ // + normalized_data * (-mean(normalized_data * ograd_mult, axis))
+ if (req[0] != kNullOp) {
+ BinaryBroadcastCompute<cpu, op::mshadow_op::mul>(attrs, ctx,
+ {ograd, gamma},
+ {kWriteTo}, {ograd_mult});
+ BinaryBroadcastCompute<cpu, op::mshadow_op::div>(attrs, ctx,
+ {ograd_mult, std},
+ {kWriteTo}, {ograd_mult});
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ if (!safe_acc) {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+ s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape));
+ } else {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+ s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape));
+ }
+ });
+ Tensor<cpu, 1, DType> red_out_tensor = red_out.FlatTo1D<cpu, DType>(s);
+ red_out_tensor /= scalar<DType>(channel_size);
+ });
+ BinaryBroadcastCompute<cpu, op::mshadow_op::minus>(attrs, ctx,
+ {ograd_mult, red_out},
+ {req[0]}, {outputs[0]});
+ ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
+ {kWriteTo}, {ograd_mult});
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ if (!safe_acc) {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+ s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape));
+ } else {
+ broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+ s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape));
+ }
+ });
+ Tensor<cpu, 1, DType> red_out_tensor = red_out.FlatTo1D<cpu, DType>(s);
+ red_out_tensor /= scalar<DType>(- channel_size);
+ });
+ BinaryBroadcastCompute<cpu, mshadow_op::mul>(attrs, ctx,
+ {normalized_data, red_out},
+ {kAddTo}, {outputs[0]});
+ }
+}
+
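For reference, the comments in the body above correspond to the following expressions, with normalized data \hat{x} = (x - \mu)/\sigma and upstream gradient g; the mean runs over the normalized axis (hence the divisions by channel_size and -channel_size), while the sums for grad_gamma and grad_beta run over the remaining axes:

    \frac{\partial L}{\partial \beta}  = \sum g
    \frac{\partial L}{\partial \gamma} = \sum g \, \hat{x}
    \frac{\partial L}{\partial x}      = \frac{g\gamma}{\sigma}
                                         - \mathrm{mean}\!\left(\frac{g\gamma}{\sigma}\right)
                                         - \hat{x}\,\mathrm{mean}\!\left(\hat{x}\,\frac{g\gamma}{\sigma}\right)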
template<>
void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu
index a60df41..9a33e06 100644
--- a/src/operator/nn/layer_norm.cu
+++ b/src/operator/nn/layer_norm.cu
@@ -29,6 +29,89 @@ using namespace mshadow::cuda;
namespace mxnet {
namespace op {
+template <>
+void LayerNormGradComputeGeneralImpl<gpu>(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const TBlob& ograd,
+ const TBlob& data,
+ const TBlob& gamma,
+ const TBlob& mean,
+ const TBlob& std,
+ const TBlob& normalized_data,
+ const TBlob& ograd_mult,
+ const TBlob& red_out,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs,
+ const mshadow::Tensor<gpu, 1, char>& workspace,
+ const mxnet::TShape& red_dst_shape,
+ const mxnet::TShape& red_src_shape,
+ const mxnet::TShape& red_exclude_dst_shape,
+ const mxnet::TShape& red_exclude_src_shape,
+ const int channel_size) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ Stream<gpu> *s = ctx.get_stream<gpu>();
+ // Compute normalized_data = (data - mean) / std
+ BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+ {data, mean},
+ {kWriteTo}, {normalized_data});
+ BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+ {normalized_data, std},
+ {kWriteTo}, {normalized_data});
+ // Calculate grad_beta
+ if (req[2] != kNullOp) {
+ BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+ ograd.reshape(red_exclude_src_shape), "red::sum{}", NDim, "identity");
+ });
+ }
+ // Calculate grad_gamma, which is sum(ograd * normalized_data, exclude_axis)
+ ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {normalized_data, ograd},
+ {kWriteTo}, {ograd_mult});
+ if (req[1] != kNullOp) {
+ BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+ ograd_mult.reshape(red_exclude_src_shape), "red::sum{}", NDim,
+ "identity");
+ });
+ }
+ // Calculate grad_data:
+ // ograd_mult = ograd * gamma / std
+ // grad_data = ograd_mult - mean(ograd_mult, axis)
+ // + normalized_data * (-mean(normalized_data * ograd_mult, axis))
+ if (req[0] != kNullOp) {
+ BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
+ {ograd, gamma},
+ {kWriteTo}, {ograd_mult});
+ BinaryBroadcastRTCCompute {"div"}(attrs, ctx,
+ {ograd_mult, std},
+ {kWriteTo}, {ograd_mult});
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+ });
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ Tensor<gpu, 1, DType> red_out_tensor = red_out.FlatTo1D<gpu, DType>(s);
+ red_out_tensor /= scalar<DType>(channel_size);
+ });
+ BinaryBroadcastRTCCompute {"sub"}(attrs, ctx,
+ {ograd_mult, red_out},
+ {req[0]}, {outputs[0]});
+ ElemwiseBinaryRTCCompute {"mul"}(attrs, ctx, {ograd_mult, normalized_data},
+ {kWriteTo}, {ograd_mult});
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+ ograd_mult.reshape(red_src_shape), "red::sum{}", NDim, "identity");
+ });
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ Tensor<gpu, 1, DType> red_out_tensor = red_out.FlatTo1D<gpu, DType>(s);
+ red_out_tensor /= scalar<DType>(- channel_size);
+ });
+ BinaryBroadcastRTCCompute {"mul"}(attrs, ctx,
+ {normalized_data, red_out},
+ {kAddTo}, {outputs[0]});
+ }
+}
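The GPU specialization above names the reducer and the elementwise op as strings ("red::sum{}", "identity") instead of template parameters; the kernel is assembled and compiled at run time (the real code generation lives under src/common/cuda/rtc*). A toy, self-contained sketch of that idea, with made-up names:

    #include <iostream>
    #include <string>

    // Assemble a kernel source string from reducer/op names; illustrative only.
    std::string BuildReduceSource(const std::string& reducer, const std::string& op) {
      return "// generated kernel: reduce with " + reducer + " over " + op + "(x)\n";
    }

    int main() {
      std::cout << BuildReduceSource("red::sum{}", "identity");
    }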
template <typename DType>
__device__ __forceinline__ DType warp_shfl(DType value, int src_lane,
int width = 32, unsigned int mask = 0xffffffff) {
diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h
index ca78b65..78c7e4a 100644
--- a/src/operator/nn/moments-inl.h
+++ b/src/operator/nn/moments-inl.h
@@ -126,7 +126,12 @@ inline void MomentsForwardImpl(const OpContext& ctx,
small = ReduceAxesShapeImpl(inputs[0].shape_, axes, true, false);
}
+#if !defined(__CUDACC__)
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(ctx, {data}, {req[0]}, {mean}, small);
+#else
+ ReduceAxesRTCComputeImpl(ctx, {data}, {req[0]}, {mean}, small, "red::sum{}", nullptr, true);
+#endif
+ TBlob temp;
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
Shape<6> data_shape, mean_shape;
for (int i = 0; i < 6; ++i) {
@@ -137,9 +142,15 @@ inline void MomentsForwardImpl(const OpContext& ctx,
ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(data.shape_.Size()), s);;
Kernel<VarBroadcastKernel, xpu>::Launch(s, data.shape_.Size(), temp_data.dptr_,
data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
- ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small);
+ temp = TBlob(temp_data);
});
+#if !defined(__CUDACC__)
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+ ctx, {temp.reshape(data.shape_)}, {kWriteTo}, {var}, small);
+#else
+ ReduceAxesRTCComputeImpl(ctx, {temp.reshape(data.shape_)},
+ {kWriteTo}, {var}, small, "red::sum{}", nullptr, true);
+#endif
}
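MomentsForwardImpl above is compiled by both the host compiler and nvcc, so __CUDACC__ selects the reduction entry point. A minimal sketch of that guard; the messages are placeholders, not the actual calls:

    #include <iostream>

    void ReduceSumDispatchSketch() {
    #if !defined(__CUDACC__)
      std::cout << "host build: templated ReduceAxesComputeImpl\n";
    #else
      std::cout << "nvcc build: ReduceAxesRTCComputeImpl with \"red::sum{}\"\n";
    #endif
    }

    int main() { ReduceSumDispatchSketch(); }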
template<typename xpu>
diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh
deleted file mode 100644
index d4374ed..0000000
--- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.cuh
+++ /dev/null
@@ -1,415 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015-2020 by Contributors
- * \file broadcast_reduce_customized-inl.cuh
- * \brief Customized CUDA implementations for binary broadcast and reduce
- * \author MXNet contributors
-*/
-#ifndef MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_
-#define MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_
-
-#include "../../tensor/broadcast_reduce-inl.cuh"
-
-using namespace mshadow::cuda;
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP, int unroll>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel_wr(const int N, const int M, const bool addto,
- const DType* __restrict big, OType *small,
- const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
- const Shape<ndim> big_shape, const Shape<ndim> big_stride,
- const int Mnext, const bool do_transpose,
- Reducer* reducer) {
- extern __shared__ char shTileChar[];
- AType* shTile = (AType*)(shTileChar);
- const int tid = threadIdx.x + threadIdx.y*blockDim.x;
- const int bx = (do_transpose) ? blockDim.y : blockDim.x;
- const int by = (do_transpose) ? blockDim.x : blockDim.y;
- const int tidx = (do_transpose) ? tid / by : threadIdx.x;
- const int tidy = (do_transpose) ? tid % by : threadIdx.y;
- // bool need_clean = !reducer;
- // reducer = reducer ? reducer : new Reducer();
- for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
- // This TB handles M range [Mstart, ...., Mend - 1]
- const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
- const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
- for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
- int idx = idx0 + tidx;
- Shape<ndim> coord = unravel(idx, small_shape);
- int idx_big0 = ravel(coord, big_shape0);
-
- AType val, residual;
- reducer->SetInitValue(val, residual);
- if (idx < N) {
- for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
- int idx_big[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride);
- }
- DType tmp[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) {
- tmp[u] = OP::Map(big[idx_big[u]]);
- }
- }
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) reducer->Reduce(val, AType(tmp[u]), residual);
- }
- }
- }
-
- // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
- if (by > 1) {
- // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
- const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
- const int it0 = tidx + tidy*fbx;
- shTile[it0 * 2] = val;
- shTile[it0 * 2 + 1] = residual;
- __syncthreads();
- for (int t=1;t < by;t <<= 1) {
- AType tmp, tmp_residual;
- reducer->SetInitValue(tmp, tmp_residual);
- if (tidy + t < by) {
- tmp = shTile[(it0 + t*fbx) * 2];
- tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
- }
- __syncthreads();
- reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
- __syncthreads();
- }
- if (idx < N && tidy == 0) {
- reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
- assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2]));
- }
- } else {
- if (idx < N) {
- reducer->Finalize(val, residual);
- assign(&small[idx + m0*N], addto, OType(val));
- }
- }
- }
- }
- // if (need_clean) {
- // delete reducer;
- // }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2, int unroll>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel_wr(const int N, const int M, const bool addto,
- const DType* __restrict big, const DType* __restrict lhs,
- const DType* __restrict rhs, DType *small,
- const Shape<ndim> big_shape0, const Shape<ndim> lhs_shape0,
- const Shape<ndim> rhs_shape0, const Shape<ndim> small_shape,
- const Shape<ndim> big_shape, const Shape<ndim> lhs_shape,
- const Shape<ndim> rhs_shape, const Shape<ndim> big_stride,
- const Shape<ndim> lhs_stride, const Shape<ndim> rhs_stride,
- const int Mnext, const bool do_transpose,
- Reducer* reducer) {
- extern __shared__ char shTileChar[];
- DType* shTile = (DType*)(shTileChar);
- const int tid = threadIdx.x + threadIdx.y*blockDim.x;
- const int bx = (do_transpose) ? blockDim.y : blockDim.x;
- const int by = (do_transpose) ? blockDim.x : blockDim.y;
- const int tidx = (do_transpose) ? tid / by : threadIdx.x;
- const int tidy = (do_transpose) ? tid % by : threadIdx.y;
- // bool need_clean = !reducer;
- // reducer = reducer ? reducer : new Reducer();
- for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
- // This TB handles M range [Mstart, ...., Mend - 1]
- const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
- const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
- for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
- int idx = idx0 + tidx;
- Shape<ndim> coord = unravel(idx, small_shape);
- int idx_big0 = ravel(coord, big_shape0);
- int idx_lhs0 = ravel(coord, lhs_shape0);
- int idx_rhs0 = ravel(coord, rhs_shape0);
-
- DType val, residual;
- reducer->SetInitValue(val, residual);
- if (idx < N) {
- for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
- int idx_big[unroll];
- int idx_lhs[unroll];
- int idx_rhs[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- idx_big[u] = idx_big0 + unravel_dot(k + u*by, big_shape, big_stride);
- idx_lhs[u] = idx_lhs0 + unravel_dot(k + u*by, lhs_shape, lhs_stride);
- idx_rhs[u] = idx_rhs0 + unravel_dot(k + u*by, rhs_shape, rhs_stride);
- }
- DType tmp[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) {
- tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]]));
- }
- }
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) reducer->Reduce(val, tmp[u], residual);
- }
- }
- }
-
- // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
- if (by > 1) {
- // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
- const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
- const int it0 = tidx + tidy*fbx;
- shTile[it0 * 2] = val;
- shTile[it0 * 2 + 1] = residual;
- __syncthreads();
- for (int t=1;t < by;t <<= 1) {
- DType tmp, tmp_residual;
- reducer->SetInitValue(tmp, tmp_residual);
- if (tidy + t < by) {
- tmp = shTile[(it0 + t*fbx) * 2];
- tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
- }
- __syncthreads();
- reducer->Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
- __syncthreads();
- }
- if (idx < N && tidy == 0) {
- reducer->Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
- assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
- }
- } else {
- if (idx < N) {
- reducer->Finalize(val, residual);
- assign(&small[idx + m0*N], addto, val);
- }
- }
- }
- }
- // if (need_clean) {
- // delete reducer;
- // }
-}
-
-// Simple reduction of lines when M is small
-template<typename Reducer, typename DType>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_lines_kernel_wr(const int N, const int M, const bool addto,
- const int small_in_stride, const DType* __restrict small_in, DType *small_out,
- Reducer* reducer) {
- for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-
- DType val, residual;
- reducer->SetInitValue(val, residual);
- for (int k = 0; k < M; k++) {
- reducer->Reduce(val, small_in[idx + k*small_in_stride], residual);
- }
-
- if (idx < N) {
- reducer->Finalize(val, residual);
- assign(&small_out[idx], addto, val);
- }
-
- }
-}
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1_wr(const int N, const bool addto,
- const DType* __restrict big, OType *small, const Shape<ndim> bshape,
- const Shape<ndim> sshape, Reducer* reducer) {
- for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
- Shape<ndim> coord = unravel(idx, sshape);
- int j = ravel(coord, bshape);
- AType val, residual;
- reducer->SetInitValue(val, residual);
- reducer->Reduce(val, AType(OP::Map(big[j])), residual);
- reducer->Finalize(val, residual);
- assign(&small[idx], addto, OType(val));
- }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1_wr(const int N, const bool addto,
- const DType* __restrict big,
- const DType* __restrict lhs,
- const DType* __restrict rhs,
- DType *small,
- const Shape<ndim> big_shape,
- const Shape<ndim> lhs_shape,
- const Shape<ndim> rhs_shape,
- const Shape<ndim> small_shape,
- Reducer* reducer) {
- for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
- Shape<ndim> coord = unravel(idx, small_shape);
- int idx_big = ravel(coord, big_shape);
- int idx_lhs = ravel(coord, lhs_shape);
- int idx_rhs = ravel(coord, rhs_shape);
- DType val, residual;
- reducer->SetInitValue(val, residual);
- reducer->Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
- reducer->Finalize(val, residual);
- assign(&small[idx], addto, val);
- }
-}
-
-#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
- if (do_unroll) { \
- const int unrollVar = unrollAmount; \
- {__VA_ARGS__} \
- } else { \
- const int unrollVar = 1; \
- {__VA_ARGS__} \
- }
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
-void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const OpReqType req,
- const TBlob& big, const Tensor<gpu, 1, char>& workspace,
- const ReduceImplConfig& config,
- Reducer* reducer = nullptr) {
- bool need_clean = !reducer;
- reducer = reducer ? reducer : new Reducer();
- if (config.M == 1) {
- reduce_kernel_M1_wr<Reducer, ndim, AType, DType, OType, OP>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
- config.N, req == kAddTo, big.dptr<DType>(), small.dptr<OType>(), big.shape_.get<ndim>(),
- small.shape_.get<ndim>(), reducer);
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr);
- } else {
- OType* small_dptr = small.dptr<OType>();
- bool addto = (req == kAddTo);
- if (config.Mnext > 1) {
- // small_dptr[] is N*Mnext*sizeof(DType) bytes
- small_dptr = reinterpret_cast<OType*>(workspace.dptr_);
- addto = false;
- // Check that the workspace is contiguous
- CHECK_EQ(workspace.CheckContiguous(), true);
- // Check that we have enough storage
- CHECK_GE(workspace.size(0), config.workspace_size);
- }
-
- const int by = (config.kernel_1.do_transpose) ?
- config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
- const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
- KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
- reduce_kernel_wr<Reducer, ndim, AType, DType, OType, OP, UNROLL>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
- config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
- small.shape_.get<ndim>(), config.rshape.get<ndim>(), config.rstride.get<ndim>(),
- config.Mnext, config.kernel_1.do_transpose, reducer);
- });
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr);
-
- if (config.Mnext > 1) {
- reduce_lines_kernel_wr<Reducer, OType>
- <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
- (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<OType>(), reducer);
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr);
- }
- }
- if (need_clean) {
- delete reducer;
- }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-void ReduceImplWithReducer(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs,
- const OpReqType req, const TBlob& big, const Tensor<gpu, 1, char>& workspace,
- const ReduceImplConfig& config, Reducer* reducer = nullptr) {
- bool need_clean = !reducer;
- reducer = reducer ? reducer : new Reducer();
- if (config.M == 1) {
- reduce_kernel_M1_wr<Reducer, ndim, DType, OP1, OP2>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
- config.N, req == kAddTo, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
- small.dptr<DType>(), big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
- rhs.shape_.get<ndim>(), small.shape_.get<ndim>());
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1_wr);
- } else {
- DType* small_dptr = small.dptr<DType>();
- bool addto = (req == kAddTo);
- if (config.Mnext > 1) {
- // small_dptr[] is N*Mnext*sizeof(DType) bytes
- small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
- addto = false;
- // Check that the workspace is contiguous
- CHECK_EQ(workspace.CheckContiguous(), true);
- // Check that we have enough storage
- CHECK_GE(workspace.size(0), config.workspace_size);
- }
-
- const int by = (config.kernel_1.do_transpose) ?
- config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
- const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
- KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
- reduce_kernel_wr<Reducer, ndim, DType, OP1, OP2, UNROLL>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
- config.N, config.M, addto, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
- small_dptr, big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
- rhs.shape_.get<ndim>(), small.shape_.get<ndim>(), config.rshape, config.lhs_shape,
- config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext,
- config.kernel_1.do_transpose, reducer);
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_wr);
- });
-
- if (config.Mnext > 1) {
- reduce_lines_kernel_wr<Reducer, DType>
- <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
- (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>(), reducer);
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel_wr);
- }
- }
- if (need_clean) {
- delete reducer;
- }
-}
-
-#undef KERNEL_UNROLL_SWITCH
-
-template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
-void ReduceWithReducer(Stream<gpu> *s, const TBlob& small, const OpReqType req,
- const Tensor<gpu, 1, char>& workspace, const TBlob& big, Reducer* reducer = nullptr) {
- if (req == kNullOp) return;
- cudaStream_t stream = Stream<gpu>::GetStream(s);
- bool need_clean = !reducer;
- reducer = reducer ? reducer : new Reducer();
- ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType));
- if (safe_acc) {
- MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
- typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
- MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
- typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
- config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr, sizeof(AccType));
- ReduceImplWithReducer<Reducer, ndim, AccType, DataType, OutType, OP>(
- stream, small, req, big, workspace, config, reducer);
- });
- });
- } else {
- ReduceImplWithReducer<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config, reducer);
- }
- if (need_clean) {
- delete reducer;
- }
-}
-
-#endif // MXNET_OPERATOR_NUMPY_LINALG_BROADCAST_REDUCE_INL_CUSTOMIZED_CUH_
diff --git a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h
index 0226df4..2941d54 100644
--- a/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h
+++ b/src/operator/numpy/linalg/broadcast_reduce_customized-inl.h
@@ -54,12 +54,6 @@ MSHADOW_XINLINE void seq_reduce_assign_wr(const index_t idx, const size_t M, con
assign(&small[idx], addto, OType(val));
}
-#ifdef __CUDACC__
-#include "broadcast_reduce_customized-inl.cuh"
-#include "../../tensor/broadcast_reduce-inl.cuh"
-
-#else
-
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
void seq_reduce_compute_wr(const size_t N, const size_t M, const bool addto,
const DType *big, OType *small, const Shape<ndim> bshape,
@@ -177,7 +171,6 @@ void ReduceWithReducer(Stream<cpu> *s, const TBlob& small, const OpReqType req,
reducer);
}
-#endif
} // namespace broadcast
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h
index 8e1c0b3..976991f 100644
--- a/src/operator/numpy/linalg/broadcast_reduce_op_customized.h
+++ b/src/operator/numpy/linalg/broadcast_reduce_op_customized.h
@@ -46,19 +46,17 @@ void ReduceAxesComputeImplWithReducer(const OpContext& ctx,
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
- MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
- MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
- const TBlob in_data = inputs[0].reshape(src_shape);
- const TBlob out_data = outputs[0].reshape(dst_shape);
- BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
- size_t workspace_size = broadcast::ReduceWorkspaceSize(
- s, out_data.shape_, req[0], in_data.shape_, sizeof(OType));
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- broadcast::ReduceWithReducer<Reducer, NDim, OType, OP, safe_acc>(
- s, out_data, req[0], workspace, in_data, reducer);
- // no normalization
- });
+ MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
+ const TBlob in_data = inputs[0].reshape(src_shape);
+ const TBlob out_data = outputs[0].reshape(dst_shape);
+ BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(
+ s, out_data.shape_, req[0], in_data.shape_);
+ Tensor<xpu, 1, char> workspace =
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+ broadcast::ReduceWithReducer<Reducer, NDim, OType, OP, safe_acc>(
+ s, out_data, req[0], workspace, in_data, reducer);
+ // no normalization
});
});
}
diff --git a/src/operator/numpy/linalg/np_norm-inl.h b/src/operator/numpy/linalg/np_norm-inl.h
index b26e680..60dee6a 100644
--- a/src/operator/numpy/linalg/np_norm-inl.h
+++ b/src/operator/numpy/linalg/np_norm-inl.h
@@ -285,18 +285,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs,
} else if (param.ord == std::numeric_limits<double>::infinity()) { // inf norm
LOG(FATAL) << "inf norm handled in front-end.";
} else {
+#ifndef __CUDACC__
mshadow_op::nrmlp host_reducer(param.ord);
mshadow_op::nrmlp *reducer_instance = nullptr;
-#ifdef __CUDACC__
- Stream<xpu> *s = ctx.get_stream<xpu>();
- cudaStream_t copy_stream = mshadow::Stream<gpu>::GetStream(s);
- cudaMalloc(reinterpret_cast<void**>(&reducer_instance), sizeof(mshadow_op::nrmlp));
- cudaMemcpyAsync(reducer_instance, &host_reducer, sizeof(mshadow_op::nrmlp),
- cudaMemcpyHostToDevice, copy_stream);
- cudaStreamSynchronize(copy_stream);
-#else
reducer_instance = &host_reducer;
-#endif
if (safe_acc) {
ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrmlp, true, mshadow_op::abs>(
ctx, inputs, req, outputs, small, reducer_instance);
@@ -304,8 +296,10 @@ void NumpyLpNormCompute(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrmlp, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small, reducer_instance);
}
-#ifdef __CUDACC__
- cudaFree(reducer_instance);
+#else
+ ReduceAxesRTCComputeImpl(
+ ctx, inputs, req, outputs, small, "red::nrmlp{" + std::to_string(param.ord) + "}",
+ nullptr, false, "abs");
#endif
}
}
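In the Lp-norm branch above, the norm order travels inside the reducer string rather than in a host reducer instance copied to the device with cudaMemcpyAsync. A small self-contained sketch of how such a string is formed (the order value is an assumption):

    #include <iostream>
    #include <string>

    int main() {
      double ord = 3.0;  // assumed norm order
      std::string reducer = "red::nrmlp{" + std::to_string(ord) + "}";
      std::cout << reducer << "\n";  // prints red::nrmlp{3.000000}
    }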
@@ -443,8 +437,13 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
}
if (param.flag == 1) { // Frobenius norm
- ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
+#if !defined(__CUDACC__)
+ ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, false, mshadow_op::identity>(
ctx, inputs, req, outputs, reduced_shape);
+#else
+ ReduceAxesRTCComputeImpl(
+ ctx, inputs, req, outputs, reduced_shape, "red::nrm2{}", nullptr, false, "identity");
+#endif
return;
}
@@ -453,19 +452,29 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
if (param.ord != 2 && param.ord != -2) { // row norm or col norm
TShape sum_shape = inputs[0].shape_;
sum_shape[mat_axis[!(param.ord == 1 || param.ord == -1)]] = 1;
- MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- TBlob temp = outputs[1].reshape(sum_shape);
- std::vector<TBlob> sum_output({temp});
- ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::abs>(
- ctx, inputs, req, sum_output, sum_shape);
- if (param.ord > 0) {
- ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::identity>(
- ctx, sum_output, req, outputs, reduced_shape);
- } else {
- ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::identity>(
- ctx, sum_output, req, outputs, reduced_shape);
- }
- });
+ TBlob temp = outputs[1].reshape(sum_shape);
+ std::vector<TBlob> sum_output({temp});
+#if !defined(__CUDACC__)
+ ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::abs>(
+ ctx, inputs, req, sum_output, sum_shape);
+ if (param.ord > 0) {
+ ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::identity>(
+ ctx, sum_output, req, outputs, reduced_shape);
+ } else {
+ ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::identity>(
+ ctx, sum_output, req, outputs, reduced_shape);
+ }
+#else
+ ReduceAxesRTCComputeImpl(ctx, inputs, req, sum_output, sum_shape,
+ "red::sum{}", nullptr, false, "abs");
+ if (param.ord > 0) {
+ ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape,
+ "red::maximum{}", nullptr, false);
+ } else {
+ ReduceAxesRTCComputeImpl(ctx, sum_output, req, outputs, reduced_shape,
+ "red::minimum{}", nullptr, false);
+ }
+#endif
return;
}
@@ -500,6 +509,7 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
L_trans[mat_axis[1]] = 1;
}
+ std::vector<TBlob> eigen;
MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 3, DType> UT =
outputs[1].get_with_shape<xpu, 3, DType>(Shape3(batch_dim, row_dim, row_dim), s);
@@ -523,32 +533,46 @@ void NumpyMatrixNormCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 3, DType> svd_input =
workspace.get_with_shape<xpu, 3, DType>(Shape3(batch_dim, row_dim, col_dim), s);
gesvd::op(svd_input, UT, L, V, ctx, attrs, &svd_workspace);
-
TBlob workspace0(reinterpret_cast<DType*>(temp.dptr_), L_trans,
temp.dev_mask(), temp.dev_id());
TransposeImpl<xpu>(ctx.run_ctx, TBlob(L).reshape(L_shape), workspace0, reduce_axes);
- std::vector<TBlob> eigen({ workspace0 });
- if (param.flag == 2) { // nuclear norm
- ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::identity>(
+ eigen.emplace_back(workspace0);
+ });
+
+#if !defined(__CUDACC__)
+ if (param.flag == 2) { // nuclear norm
+ ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false, mshadow_op::identity>(
+ ctx, eigen, req, outputs, reduced_shape);
+ } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) {
+ if (ord == 2) {
+ ReduceAxesComputeImpl<xpu, mshadow::red::maximum, true, false, mshadow_op::abs>(
+ ctx, eigen, req, outputs, reduced_shape);
+ } else if (ord == -2) {
+ ReduceAxesComputeImpl<xpu, mshadow::red::minimum, true, false, mshadow_op::abs>(
ctx, eigen, req, outputs, reduced_shape);
- } else if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true)) {
- if (ord == 2) {
- ReduceAxesComputeImpl<xpu, mshadow::red::maximum, true, false, mshadow_op::abs>(
- ctx, eigen, req, outputs, reduced_shape);
- } else if (ord == -2) {
- ReduceAxesComputeImpl<xpu, mshadow::red::minimum, true, false, mshadow_op::abs>(
- ctx, eigen, req, outputs, reduced_shape);
- }
- } else {
- if (ord == 2) {
- ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::abs>(
- ctx, eigen, req, outputs, reduced_shape);
- } else if (ord == -2) {
- ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::abs>(
- ctx, eigen, req, outputs, reduced_shape);
- }
}
- });
+ } else {
+ if (ord == 2) {
+ ReduceAxesComputeImpl<xpu, mshadow::red::maximum, false, false, mshadow_op::abs>(
+ ctx, eigen, req, outputs, reduced_shape);
+ } else if (ord == -2) {
+ ReduceAxesComputeImpl<xpu, mshadow::red::minimum, false, false, mshadow_op::abs>(
+ ctx, eigen, req, outputs, reduced_shape);
+ }
+ }
+#else
+ if (param.flag == 2) { // nuclear norm
+ ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape, "red::sum{}", nullptr, false);
+ } else {
+ if (ord == 2) {
+ ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape,
+ "red::maximum{}", nullptr, false, "abs");
+ } else if (ord == -2) {
+ ReduceAxesRTCComputeImpl(ctx, eigen, req, outputs, reduced_shape,
+ "red::minimum{}", nullptr, false, "abs");
+ }
+ }
+#endif
}
template<typename xpu>
@@ -784,8 +808,13 @@ void NumpyNormComputeForward(const nnvm::NodeAttrs& attrs,
std::vector<TBlob> flat_outputs({
outputs[0].reshape(TShape(1, 1))
});
- ReduceAxesComputeImplWithReducer<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
+#if !defined(__CUDACC__)
+ ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, false, mshadow_op::identity>(
ctx, flat_inputs, req, flat_outputs, TShape(1, 1));
+#else
+ ReduceAxesRTCComputeImpl(
+ ctx, flat_inputs, req, flat_outputs, TShape(1, 1), "red::nrm2{}", nullptr, false, "identity");
+#endif
return;
}
diff --git a/src/operator/numpy/np_broadcast_reduce_op.cc b/src/operator/numpy/np_broadcast_reduce_op.cc
new file mode 100644
index 0000000..4b64a1a
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file np_broadcast_reduce_op.cc
+ * \brief Function definitions of NumPy-compatible
+ * broadcast and reduce operators
+ */
+
+#include "np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+#if MXNET_USE_CUDA
+
+void NumpyArgMinMaxRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ if (req[0] == kNullOp) return;
+ // parse param
+ const auto& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
+ mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+ TBlob out = outputs[0];
+ TBlob in = inputs[0];
+ // do some shape checks
+ if (in.shape_.ndim() != 0) {
+ if (param.axis.has_value()) {
+ // cannot do argmax in an empty dimension
+ int axis = param.axis.value();
+ axis = CheckAxis(axis, in.shape_.ndim());
+ CHECK_NE(in.shape_[axis], 0)
+ << "searching input tensor of shape " << inputs[0].shape_
+ << " along axis = " << axis << " of zero dim-size is not allowed";
+ } else {
+ // cannot do argmax on an empty array
+ CHECK_NE(in.shape_.Size(), 0U) << "attempt to search an empty sequence";
+ }
+ }
+ if (in.shape_.Size() == 0U) return; // zero-size tensor
+ // prepare shape
+ dmlc::optional<mxnet::Tuple<int>> axes;
+ if (param.axis.has_value()) {
+ mxnet::Tuple<int> t({param.axis.value()});
+ axes = dmlc::optional<mxnet::Tuple<int>>(t);
+ }
+ TShape small;
+ small = NumpyReduceAxesShapeImpl(in.shape_, axes, true);
+ mxnet::TShape src_shape, dst_shape;
+ BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape);
+ const TBlob in_data = in.reshape(src_shape);
+ // request a work space
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape);
+ Tensor<gpu, 1, char> workspace =
+ ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_size), s);
+ BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, outputs[0].reshape(dst_shape), req[0], workspace, in_data,
+ reducer, NDim, "identity", true);
+ });
+}
+
+#endif // MXNET_USE_CUDA
+
+} // namespace op
+} // namespace mxnet
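The shape checks in the RTC argmin/argmax path above reject reductions along a zero-sized axis (and over an entirely empty tensor when no axis is given). A quick self-contained illustration of the rule, with a made-up shape and axis:

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int> shape = {3, 0, 4};  // hypothetical input shape
      int axis = 1;
      if (shape[axis] == 0)
        std::cout << "error: argmax along a zero-sized axis is not allowed\n";
    }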
diff --git a/src/operator/numpy/np_broadcast_reduce_op.cuh b/src/operator/numpy/np_broadcast_reduce_op.cuh
deleted file mode 100644
index f97aa78..0000000
--- a/src/operator/numpy/np_broadcast_reduce_op.cuh
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015-2020 by Contributors
- * \file np_broadcast_reduce-inl.cuh
- * \brief GPU implementations for numpy binary broadcast ops
- * \author Zhaoqi Zhu
-*/
-#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
-#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
-
-using namespace mshadow::cuda;
-using namespace mshadow;
-using namespace broadcast;
-
-template<typename Reducer, int NDim, typename DType, typename OType>
-void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,
- const Tensor<gpu, 1, char>& workspace) {
- cudaStream_t stream = Stream<gpu>::GetStream(s);
- ReduceImplConfig config(out_data.shape_, in_data.shape_, nullptr, nullptr, sizeof(OType));
-
- ReduceImpl<Reducer, NDim, OType, DType, OType, mxnet::op::mshadow_op::identity,
- mxnet::op::mshadow_op::arg_min_max_set_index<OType, int>>
- (stream, out_data, kWriteTo, in_data, workspace, config);
-}
-
-#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 80714fb6..9ce3967 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -298,7 +298,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
if (req[0] == kNullOp) return;
- const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+ const auto& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
}
@@ -494,10 +494,6 @@ void NumpyArgMinMaxReduce(mshadow::Stream<cpu> *s, const TBlob& in_data, const T
in_data.shape_.get<NDim>(), out_data.shape_.get<NDim>(), rshape, rstride);
}
-#ifdef __CUDACC__
-#include "np_broadcast_reduce_op.cuh"
-#endif
-
template<typename Reducer, typename xpu, typename IType>
void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -508,7 +504,7 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs,
using namespace mshadow::expr;
if (req[0] == kNullOp) return;
// parse param
- const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
+ const auto& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
TBlob out = outputs[0];
TBlob in = inputs[0];
@@ -537,36 +533,52 @@ void NumpyArgMinMaxCompute(const nnvm::NodeAttrs& attrs,
small = NumpyReduceAxesShapeImpl(in.shape_, axes, true);
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(in.shape_, small, &src_shape, &dst_shape);
+ const TBlob in_data = in.reshape(src_shape);
+ // request a work space
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape);
MSHADOW_TYPE_SWITCH_WITH_BOOL(in.type_flag_, DType, {
// define OType
typedef mxnet::op::mshadow_op::IndexedNum<IType, DType> OType;
- // request a work space
- size_t workspace_size = sizeof(OType) * out.shape_.Size();
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- // set up intermediate output
- TBlob intermediate = out;
- intermediate.dptr_ = reinterpret_cast<int64_t*>(workspace.dptr_);
- // reshape the input and intermediate output tensor
- const TBlob in_data = in.reshape(src_shape);
- const TBlob intermediate_out_data = intermediate.reshape(dst_shape);
// switch dim
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
- size_t workspace_size = broadcast::ReduceWorkspaceSize(
- s, intermediate_out_data.shape_, req[0], in_data.shape_, sizeof(OType));
+ constexpr size_t align_size = 1024;
+ const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size)
+ * align_size;
+ workspace_size = aligned_first_workspace_size +
+ sizeof(OType) * out.shape_.Size();
Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+ // set up intermediate output
+ TBlob intermediate = out;
+ intermediate.dptr_ = reinterpret_cast<int64_t*>(workspace.dptr_ +
+ aligned_first_workspace_size);
+ // reshape the input and intermediate output tensor
+ const TBlob intermediate_out_data = intermediate.reshape(dst_shape);
NumpyArgMinMaxReduce<Reducer, NDim, DType, OType>(s, in_data,
intermediate_out_data, workspace);
+ // parse the indices from the intermediate tensor back to the actual output tensor
+ using namespace mxnet_op;
+ Kernel<arg_min_max_parse, xpu>::Launch(
+ s, out.shape_.Size(), outputs[0].dptr<int64_t>(),
+ static_cast<OType*>(intermediate_out_data.dptr_));
});
- // parse the indices from the intermediate tensor back to the actual output tensor
- using namespace mxnet_op;
- Kernel<arg_min_max_parse, xpu>::Launch(
- s, out.shape_.Size(), outputs[0].dptr<int64_t>(),
- static_cast<OType*>(intermediate_out_data.dptr_));
});
}
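The single allocation in the function above packs the reduce workspace first and the intermediate (value, index) buffer second, rounding the boundary up to 1024 bytes. A worked, self-contained example of that rounding; the 3000-byte size is made up:

    #include <cstddef>
    #include <iostream>

    int main() {
      constexpr std::size_t align_size = 1024;
      std::size_t workspace_size = 3000;  // hypothetical reduce-workspace size
      std::size_t aligned = ((workspace_size + align_size - 1) / align_size) * align_size;
      std::cout << aligned << "\n";  // 3072: the intermediate buffer starts at this offset
    }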
+#if MXNET_USE_CUDA
+
+struct NumpyArgMinMaxRTCCompute {
+ std::string reducer;
+
+ void operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs);
+};
+
+#endif
+
template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -663,36 +675,6 @@ struct NumpyMomentsParam : public dmlc::Parameter<NumpyMomentsParam> {
}
};
-template<typename xpu, typename reducer, bool safe_acc, bool normalize = false,
- typename OP = op::mshadow_op::identity>
-void ReduceAxesComputeWithWorkspaceImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mshadow::Tensor<xpu, 1, char>& workspace,
- const mxnet::TShape& src_shape,
- const mxnet::TShape& dst_shape,
- const int ddof = 0) {
- using namespace mshadow;
- using namespace mshadow::expr;
-
- Stream<xpu> *s = ctx.get_stream<xpu>();
- MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
- MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
- const TBlob in_data = inputs[0].reshape(src_shape);
- const TBlob out_data = outputs[0].reshape(dst_shape);
- BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
- broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
- s, out_data, req[0], workspace, in_data);
- if (normalize) {
- auto out = out_data.FlatTo2D<xpu, OType>(s);
- out /= scalar<OType>(src_shape.Size()/dst_shape.Size() - ddof);
- }
- });
- });
- });
-}
-
struct NumpyWeightedAverageParam : public dmlc::Parameter<NumpyWeightedAverageParam> {
dmlc::optional<mxnet::Tuple<int>> axis;
bool returned;
@@ -871,13 +853,6 @@ struct avg_grad_w_1D_kernel {
}
};
-// Windows has issues with #ifdefs inside MSHADOW_TYPE_SWITCH
-#ifndef __CUDACC__
-#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastCompute<xpu, mshadow_op::OP>
-#else
-#define NP_BROADCAST_REDUCE_OP_BROADCAST(OP) BinaryBroadcastRTCCompute {#OP}
-#endif
-
template<typename xpu, bool back = false>
void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -914,6 +889,9 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
weights = weights.reshape(new_w_shape);
small2 = TShape(new_w_shape.ndim(), 1);
}
+ TBlob wa;
+ TBlob sum_of_wa;
+ Tensor<xpu, 1, char> workspace;
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
// Get temp space
size_t temp_data_size = data.shape_.Size() * sizeof(DType);
@@ -922,38 +900,53 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
BroadcastReduceShapeCompact(data.shape_, small1, &src_shape, &dst_shape);
size_t workspace_size = 0;
workspace_size = broadcast::ReduceWorkspaceSize(
- s, dst_shape, {kWriteTo}, src_shape, sizeof(DType));
+ s, dst_shape, {kWriteTo}, src_shape);
size_t temp_mem_size = temp_data_size + temp_sum_size + workspace_size;
Tensor<xpu, 1, char> temp_mem =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
- DType *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
- DType *temp_sum_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + temp_data_size);
+ auto *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
+ auto *temp_sum_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + temp_data_size);
char *workspace_ptr = temp_mem.dptr_ + temp_data_size + temp_sum_size;
- Tensor<xpu, 1, char> workspace(workspace_ptr, Shape1(workspace_size), s);
+ workspace = Tensor<xpu, 1, char>(workspace_ptr, Shape1(workspace_size), s);
// Compute weighted data
- TBlob wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask);
- NP_BROADCAST_REDUCE_OP_BROADCAST(mul)(
- attrs, ctx, {data, weights}, {kWriteTo}, {wa});
-
- // Compute sum of weighted data
- TBlob sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask);
- ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true>(
- ctx, {wa}, {kWriteTo}, {sum_of_wa}, workspace, src_shape, dst_shape);
- if (!back) {
- const TBlob& avg = outputs[0];
- const TBlob& sum_of_weights = outputs[1];
- TShape w_src_shape, w_dst_shape;
- BroadcastReduceShapeCompact(weights.shape_, small2, &w_src_shape, &w_dst_shape);
- // Compute sum of weight
- TBlob scl = sum_of_weights.reshape(small2);
- ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true>(
- ctx, {weights}, {kWriteTo}, {scl}, workspace, w_src_shape, w_dst_shape);
-
- // Compute avg and assign output
- NP_BROADCAST_REDUCE_OP_BROADCAST(div)(
- attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)});
- } else {
+ wa = TBlob(temp_data_ptr, data.shape_, xpu::kDevMask);
+ sum_of_wa = TBlob(temp_sum_ptr, small1, xpu::kDevMask);
+ });
+#if !defined(__CUDACC__)
+ BinaryBroadcastCompute<xpu, mshadow_op::mul>(
+ attrs, ctx, {data, weights}, {kWriteTo}, {wa});
+
+ // Compute sum of weighted data
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
+ ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, &workspace);
+#else
+ BinaryBroadcastRTCCompute {"mul"}(attrs, ctx, {data, weights}, {kWriteTo}, {wa});
+
+ // Compute sum of weighted data
+ ReduceAxesRTCComputeImpl(ctx, {wa}, {kWriteTo}, {sum_of_wa}, small1, "red::sum{}",
+ &workspace, false, "identity");
+#endif
+ if (!back) {
+ const TBlob& avg = outputs[0];
+ const TBlob& sum_of_weights = outputs[1];
+ // Compute sum of weight
+ TBlob scl = sum_of_weights.reshape(small2);
+#if !defined(__CUDACC__)
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
+ ctx, {weights}, {kWriteTo}, {scl}, small2, &workspace);
+ // Compute avg and assign output
+ BinaryBroadcastCompute<xpu, mshadow_op::div>(
+ attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)});
+#else
+ ReduceAxesRTCComputeImpl(ctx, {weights}, {kWriteTo}, {scl}, small2, "red::sum{}",
+ &workspace, false, "identity");
+ // Compute avg and assign output
+ BinaryBroadcastRTCCompute {"div"}(
+ attrs, ctx, {sum_of_wa, scl}, req, {avg.reshape(small1)});
+#endif
+ } else {
+ MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
// Compute and assign the derivatives of a and weights
const TBlob& igrad_a = outputs[0];
const TBlob& igrad_w = outputs[1];
@@ -992,12 +985,10 @@ void NumpyWeightedAverageComputeImpl(const nnvm::NodeAttrs& attrs,
}
});
})
- }
- });
+ });
+ }
}
-#undef NP_BROADCAST_REDUCE_OP_BROADCAST
-
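In the notation of the forward code above (wa holds data * weights, scl the reduced weights), the result assigned to avg is simply

    \mathrm{avg} = \frac{\sum_i w_i x_i}{\sum_i w_i}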
template<typename xpu>
void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -1010,21 +1001,29 @@ void NumpyWeightedAverageForward(const nnvm::NodeAttrs& attrs,
CHECK_NE(req[0], kWriteInplace) << "Average does not support write in-place";
const auto& param = nnvm::get<NumpyWeightedAverageParam>(attrs.parsed);
const TBlob& data = inputs[0];
+ TShape small;
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
if (!param.weighted) {
- TShape small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true);
+ small = NumpyReduceAxesShapeImpl(data.shape_, param.axis, true);
 // Compute sum of weights, which equals the product of the sizes of the reduced axes
Stream<xpu>* s = ctx.get_stream<xpu>();
auto ret = outputs[1].FlatTo1D<xpu, DType>(s);
ret = scalar<DType>(data.shape_.Size()/small.Size());
- // Compute mean
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
- ctx, inputs, req, {outputs[0]}, small);
- } else {
- NumpyWeightedAverageComputeImpl<xpu>(
- attrs, ctx, inputs, req, outputs, param.axis);
}
});
+ if (!param.weighted) {
+ // Compute mean
+#if !defined(__CUDACC__)
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+ ctx, inputs, req, {outputs[0]}, small);
+#else
+ ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0]}, small,
+ "red::sum{}", nullptr, true);
+#endif
+ } else {
+ NumpyWeightedAverageComputeImpl<xpu>(
+ attrs, ctx, inputs, req, outputs, param.axis);
+ }
}
template<typename xpu>
@@ -1090,40 +1089,58 @@ void NumpyMomentsForward(const nnvm::NodeAttrs& attrs,
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(data.shape_, small, &src_shape, &dst_shape);
+ // Get workspace and temp space for data - mean
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape, req[0], src_shape);
+ size_t temp_data_size = data.shape_.Size() * common::mshadow_type_info(inputs[0].type_flag_).size;
+ size_t temp_mem_size = temp_data_size + workspace_size;
+ Tensor<xpu, 1, char> temp_mem =
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
+ char *workspace_ptr = temp_mem.dptr_ + temp_data_size;
+ Tensor<xpu, 1, char> workspace(workspace_ptr, Shape1(workspace_size), s);
+ // Compute mean
+#if !defined(__CUDACC__)
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+ ctx, inputs, {kWriteTo}, {mean}, small, &workspace);
+#else
+ ReduceAxesRTCComputeImpl(ctx, inputs, {kWriteTo}, {mean}, small, "red::sum{}",
+ &workspace, true, "identity");
+#endif
+ // Compute data - mean
+ Shape<6> data_shape, mean_shape;
+ for (int i = 0; i < 6; ++i) {
+ data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1;
+ mean_shape[i] = (i < small.ndim()) ? small[i] : 1;
+ }
+#if !defined(__CUDACC__)
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
- // Get workspace and temp space for data - mean
- size_t workspace_size = 0;
- workspace_size = broadcast::ReduceWorkspaceSize(
- s, dst_shape, req[0], src_shape, sizeof(DType));
- size_t temp_data_size = data.shape_.Size() * sizeof(DType);
- size_t temp_mem_size = temp_data_size + workspace_size;
- Tensor<xpu, 1, char> temp_mem =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
DType *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
- char *workspace_ptr = temp_mem.dptr_ + temp_data_size;
- Tensor<xpu, 1, char> workspace(workspace_ptr, Shape1(workspace_size), s);
- // Compute mean
- ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true, true>(
- ctx, inputs, {kWriteTo}, {mean}, workspace, src_shape, dst_shape);
- // Compute data - mean
- Shape<6> data_shape, mean_shape;
- for (int i = 0; i < 6; ++i) {
- data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1;
- mean_shape[i] = (i < small.ndim()) ? small[i] : 1;
- }
Kernel<VarBroadcastKernel, xpu>::Launch(s, data_shape.Size(), temp_data_ptr,
data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
Tensor<xpu, 1, DType> temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s);
TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_);
- ReduceAxesComputeWithWorkspaceImpl<xpu, mshadow_op::sum, true, true>(
- ctx, {temp_data_blob}, {req[0]}, {moment}, workspace, src_shape, dst_shape, param.ddof);
- if (sqrt) {
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
+ ctx, {temp_data_blob}, {req[0]}, {moment}, small, &workspace, param.ddof);
+ if (sqrt && req[0] != kNullOp) {
Tensor<xpu, 1, OType> moment_tensor = moment.FlatTo1D<xpu, OType>(s);
moment_tensor = F<mshadow_op::square_root>(moment_tensor);
}
});
});
+#else
+ MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ DType *temp_data_ptr = reinterpret_cast<DType*>(temp_mem.dptr_);
+ Kernel<VarBroadcastKernel, xpu>::Launch(s, data_shape.Size(), temp_data_ptr,
+ data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
+ Tensor<xpu, 1, DType> temp_data_tensor(temp_data_ptr, Shape1(data.shape_.Size()), s);
+ TBlob temp_data_blob = TBlob(temp_data_tensor).reshape(data.shape_);
+ ReduceAxesRTCComputeImpl(ctx, {temp_data_blob}, {req[0]}, {moment}, small,
+ "red::sum{}", &workspace, true, "identity", param.ddof);
+ if (sqrt && req[0] != kNullOp) {
+ UnaryRTCCompute {"sqrt"}({}, ctx, {moment}, {kWriteInplace}, {moment});
+ }
+ });
+#endif
}
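NumpyMomentsForward above reduces the squared deviations from the mean and normalizes by the reduced size minus ddof; with sqrt set it returns the standard deviation:

    m = \frac{1}{N - \mathrm{ddof}} \sum_i (x_i - \bar{x})^2, \qquad
    \mathrm{std} = \sqrt{m}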
template<typename xpu>
@@ -1159,6 +1176,7 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < igrad_shape.ndim(); ++i) {
expanded_igrad_shape[i + ndim_delta] = igrad_shape[i];
}
+#if !defined(__CUDACC__)
if (NeedSafeAcc<true>(inputs[0].type_flag_, outputs[0].type_flag_)) {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
@@ -1166,6 +1184,10 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(
ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape);
}
+#else
+ ReduceAxesRTCComputeImpl(ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)},
+ expanded_igrad_shape, "red::sum{}", nullptr, false);
+#endif
}
template<typename xpu, typename OP>
diff --git a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
index d3247b7..405ae4b 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_boolean.cu
@@ -29,12 +29,12 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(_npi_any)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBoolCompute<gpu,
- mshadow_op::sum, mshadow_op::NonZero, 0>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 0>
+ {"NonZero", "red::sum{}", false});
NNVM_REGISTER_OP(_npi_all)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBoolCompute<gpu,
- mshadow_op::product, mshadow_op::NonZero, 1>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 1>
+ {"NonZero", "red::product{}", false});
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu
index 892d046..eb6086c 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_index.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu
@@ -28,10 +28,10 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(_npi_argmax)
-.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxCompute<mshadow_op::argmax, gpu, int>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxRTCCompute{"red::argmax{}"});
NNVM_REGISTER_OP(_npi_argmin)
-.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxCompute<mshadow_op::argmin, gpu, int>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxRTCCompute{"red::argmin{}"});
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index 422097d..6020573 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -27,25 +27,30 @@
namespace mxnet {
namespace op {
NNVM_REGISTER_OP(_npi_sum)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true>);
+.set_attr<FCompute>("FCompute<gpu>",
+ ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>{"identity", "red::sum{}", false});
NNVM_REGISTER_OP(_backward_npi_sum)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);
NNVM_REGISTER_OP(_npi_max)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesNoDTypeParam, 0>
+ {"identity", "red::maximum{}", false});
NNVM_REGISTER_OP(_backward_npi_max)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
NNVM_REGISTER_OP(_npi_min)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::minimum>);
+.set_attr<FCompute>("FCompute<gpu>",
+ ReduceAxesRTCCompute<NumpyReduceAxesNoDTypeParam, 0>{"identity",
+ "red::minimum{}", false});
NNVM_REGISTER_OP(_backward_npi_min)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);
NNVM_REGISTER_OP(_npi_prod)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::product, true>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesParam, 1>{"identity",
+ "red::product{}", false});
NNVM_REGISTER_OP(_backward_npi_prod)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseInOut<gpu, mshadow_op::rdiv>);
@@ -57,7 +62,8 @@ NNVM_REGISTER_OP(_backward_np_average)
.set_attr<FCompute>("FCompute<gpu>", NumpyWeightedAverageBackward<gpu>);
NNVM_REGISTER_OP(_npi_mean)
-.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true, true>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>{"identity",
+ "red::sum{}", true});
NNVM_REGISTER_OP(_backward_np_mean)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);
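[Editor's sketch] For the simple reductions the change is confined to the GPU registration: the compile-time reducer template is replaced by a small functor carrying the mapper name, the reducer name and a flag that, judging from _npi_mean versus _npi_sum above, requests mean-style normalization. A schematic pair of registrations for a hypothetical _npi_mysum operator (the op name is invented; the CPU side keeps the pre-existing template form) would be:

    // In the .cc file (CPU): reducer fixed at compile time.
    NNVM_REGISTER_OP(_npi_mysum)
    .set_attr<FCompute>("FCompute<cpu>",
                        NumpyReduceAxesCompute<cpu, mshadow_op::sum, true>);

    // In the .cu file (GPU): reducer chosen by name, JIT-compiled via RTC.
    NNVM_REGISTER_OP(_npi_mysum)
    .set_attr<FCompute>("FCompute<gpu>",
                        ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>{
                            "identity", "red::sum{}", false});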
diff --git a/src/operator/numpy/np_constraint_check.h b/src/operator/numpy/np_constraint_check.h
index 80beaa3..01c54b6 100644
--- a/src/operator/numpy/np_constraint_check.h
+++ b/src/operator/numpy/np_constraint_check.h
@@ -56,9 +56,14 @@ void ConstraintCheckForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
CHECK_EQ(outputs.size(), 1U);
const ConstraintCheckParam& param =
nnvm::get<ConstraintCheckParam>(attrs.parsed);
+#if !defined(__CUDACC__)
ReduceAxesComputeImpl<xpu, mshadow_op::product, false, false,
op::mshadow_op::identity>(ctx, inputs, req, outputs,
outputs[0].shape_);
+#else
+ ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs,
+ outputs[0].shape_, "red::product{}");
+#endif
std::string msg = param.msg;
bool red_output = true;
GetReduceOutput(ctx.get_stream<xpu>(), outputs[0], &red_output);
diff --git a/src/operator/numpy/np_cross-inl.h b/src/operator/numpy/np_cross-inl.h
index ab64564..7d5dd30 100644
--- a/src/operator/numpy/np_cross-inl.h
+++ b/src/operator/numpy/np_cross-inl.h
@@ -677,8 +677,7 @@ struct ReduceImplWrap {
std::vector<int> reduce_axis = GetReduceAxis(out_move_shape, in_move_shape);
if (reduce_axis.empty() || req == kNullOp) { return 0U; }
ws_reduce = broadcast::ReduceWorkspaceSize(ctx.get_stream<xpu>(),
- out_shape, req, in_shape,
- sizeof(DType));
+ out_shape, req, in_shape);
return ws_reduce;
}
@@ -690,10 +689,17 @@ struct ReduceImplWrap {
const Tensor<xpu, 1, char> workspace_tensor) {
Stream<xpu> *s = ctx.get_stream<xpu>();
// Reduce work_in to work_out.
+#if !defined(__CUDACC__)
SUM_NDIM_SWITCH(work_out.ndim(), NDim, {
op::broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, false>(
s, work_out, kWriteTo, workspace_tensor, work_in);
});
+#else
+ SUM_NDIM_SWITCH(work_out.ndim(), NDim, {
+ op::broadcast::RTCReduce(ctx, work_out, kWriteTo, workspace_tensor, work_in,
+ "red::sum{}", NDim, "identity");
+ });
+#endif
// Copy work_out to out_data.
MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, {
mxnet_op::Kernel<ResAssign<req_type>, xpu>::Launch(
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index 1fa5890..be19b38 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -413,9 +413,9 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
if (need_bc) {
workspace_size_l = ReduceWorkspaceSize(
- s, new_lshape, req[0], new_oshape, new_lshape, new_rshape, sizeof(OType));
+ s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
workspace_size_r = ReduceWorkspaceSize(
- s, new_rshape, req[1], new_oshape, new_lshape, new_rshape, sizeof(OType));
+ s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
}
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
size_t cast_tensor_size = tensor_size * sizeof(OType);
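[Editor's sketch] The two-line change above is representative of many call sites in this patch: ReduceWorkspaceSize no longer takes the element size of the output type, because the workspace is now sized for the largest accumulation type an RTC-selected reducer may use. The before/after shape of a typical call (names are illustrative) is:

    // Before: size depended on the concrete output type.
    //   broadcast::ReduceWorkspaceSize(s, small_shape, req, big_shape, sizeof(OType));
    // After: sized for the worst-case accumulator, so one allocation works for
    // whichever reducer string is chosen at runtime.
    size_t ws = broadcast::ReduceWorkspaceSize(s, small_shape, req, big_shape);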
diff --git a/src/operator/numpy/np_kron-inl.h b/src/operator/numpy/np_kron-inl.h
index f357cec..bf0983f 100644
--- a/src/operator/numpy/np_kron-inl.h
+++ b/src/operator/numpy/np_kron-inl.h
@@ -188,6 +188,14 @@ void KronOpForwardImpl(const OpContext& ctx,
});
}
+#if !defined(__CUDACC__)
+#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, safe_acc>(__VA_ARGS__, &workspace)
+#else
+#define NP_KRON_REDUCE_AXES(safe_acc, workspace, ...) \
+ ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}", &workspace)
+#endif
+
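[Editor's note] Where a file needs the same reduction in several places, the patch wraps the dispatch in a variadic macro instead of repeating the #if/#else block: the macro forwards the common arguments and appends the path-specific tail (the safe-accumulation template flag and workspace on the host side, the reducer name plus workspace on the RTC side). A call later in this file then reads:

    // Expands to ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(..., &workspace)
    // under the host compiler, and to
    // ReduceAxesRTCComputeImpl(..., "red::sum{}", &workspace) under nvcc.
    NP_KRON_REDUCE_AXES(true, workspace, ctx, {TBlob(temp)}, {scalar_req},
                        {TBlob(scalar_grad_)}, scalar_grad_.shape_);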
template<typename xpu>
void KronOpBackwardImpl(const OpContext& ctx,
const std::vector<OpReqType>& req,
@@ -226,12 +234,23 @@ void KronOpBackwardImpl(const OpContext& ctx,
const OpReqType& scalar_req = (ashape.ndim() == 0) ? req[0] : req[1];
ASSIGN_DISPATCH(tensor_grad_, tensor_req,
broadcast_scalar(scalar_, tensor_grad_.shape_) * ograd_);
- Tensor<xpu, 1, DType> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(ograd.shape_.Size()), s);
- ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * ograd_);
-
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
- ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
+ TShape src_shape, dst_shape;
+ BroadcastReduceShapeCompact(ograd.shape_, scalar_grad_.shape_, &src_shape, &dst_shape);
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(s, dst_shape,
+ {scalar_req}, src_shape);
+ constexpr size_t align_size = 1024;
+ const size_t aligned_first_workspace_size = ((workspace_size + align_size - 1) / align_size)
+ * align_size;
+ workspace_size = aligned_first_workspace_size + ograd.shape_.Size() * sizeof(DType);
+ Tensor<xpu, 1, char> workspace =
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+ Tensor<xpu, 1, DType> temp(reinterpret_cast<DType*>(workspace.dptr_ +
+ aligned_first_workspace_size),
+ Shape1(ograd.shape_.Size()), s);
+ ASSIGN_DISPATCH(temp, kWriteTo, tensor_ * ograd_);
+
+ NP_KRON_REDUCE_AXES(true, workspace, ctx, {TBlob(temp)}, {scalar_req},
+ {TBlob(scalar_grad_)}, scalar_grad_.shape_);
} else {
MXNET_NDIM_SWITCH(oshape.ndim(), ndim, {
Shape<ndim> ashape_ = oshape.get<ndim>();
@@ -276,6 +295,8 @@ void KronOpBackwardImpl(const OpContext& ctx,
});
}
+#undef NP_KRON_REDUCE_AXES
+
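[Editor's note] The reworked scalar branch above packs the reduction scratch space and the temporary DType buffer into one char workspace, rounding the split point up to a 1024-byte boundary. As a purely illustrative example, a reduction workspace of 5000 bytes is padded to ((5000 + 1024 - 1) / 1024) * 1024 = 5120 bytes, so the temporary tensor starts 5120 bytes into the allocation and remains suitably aligned for any element type.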
template<typename xpu>
inline void KronOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h
index 1bfc6d1..47e42d0 100644
--- a/src/operator/numpy/np_tensordot_op-inl.h
+++ b/src/operator/numpy/np_tensordot_op-inl.h
@@ -370,6 +370,14 @@ inline mxnet::TShape GetReverseShape(const mxnet::Tuple<int>& shape) {
return shape2;
}
+#if !defined(__CUDACC__)
+#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, safe_acc>(__VA_ARGS__)
+#else
+#define NP_TENSORDOT_REDUCE_AXES(safe_acc, ...) \
+ ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}")
+#endif
+
/**
* calculates tensordot derivative.
*/
@@ -424,8 +432,8 @@ void TensordotBackwardImpl(const Tuple<int>& a_axes_summed,
workspace.stream_);
ASSIGN_DISPATCH(dtypespace, kWriteTo, tensor_ * out_grad_);
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
- ctx, {TBlob(dtypespace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
+ NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(dtypespace)}, {scalar_req},
+ {TBlob(scalar_grad_)}, scalar_grad_.shape_);
} else {
// Two tensors of at least 1 dimensions.
Tuple<int> a_axes_remained;
@@ -734,8 +742,8 @@ void TensordotIntAxesBackwardImpl(const int axes,
ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(out_grad.shape_.Size()), s);
ASSIGN_DISPATCH(workspace, kWriteTo, tensor_ * out_grad_);
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(
- ctx, {TBlob(workspace)}, {scalar_req}, {TBlob(scalar_grad_)}, scalar_grad_.shape_);
+ NP_TENSORDOT_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {scalar_req},
+ {TBlob(scalar_grad_)}, scalar_grad_.shape_);
} else {
// Two tensors of at least 1 dimensions.
Tuple<int> a_axes_summed;
@@ -759,6 +767,8 @@ void TensordotIntAxesBackwardImpl(const int axes,
});
}
+#undef NP_TENSORDOT_REDUCE_AXES
+
/**
* backward function.
*/
diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h
index 10ec081..43af21d 100644
--- a/src/operator/numpy/np_where_op-inl.h
+++ b/src/operator/numpy/np_where_op-inl.h
@@ -175,6 +175,13 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
});
}
+#if !defined(__CUDACC__)
+#define NP_WHERE_REDUCE_AXES(safe_acc, ...) \
+ ReduceAxesComputeImpl<xpu, mshadow_op::sum, safe_acc>(__VA_ARGS__)
+#else
+#define NP_WHERE_REDUCE_AXES(safe_acc, ...) ReduceAxesRTCComputeImpl(__VA_ARGS__, "red::sum{}")
+#endif
+
template<typename xpu>
inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -226,9 +233,9 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
size_t ws_size = 0;
if (ograd.shape_ != dx.shape_ || ograd.shape_ != dy.shape_) {
size_t ws_size1 = broadcast::ReduceWorkspaceSize(
- s, expanded_lshape, req[0], expanded_oshape, sizeof(DType));
+ s, expanded_lshape, req[0], expanded_oshape);
size_t ws_size2 = broadcast::ReduceWorkspaceSize(
- s, expanded_rshape, req[1], expanded_oshape, sizeof(DType));
+ s, expanded_rshape, req[1], expanded_oshape);
ws_size = std::max(ws_size1, ws_size2);
}
// process left output
@@ -246,10 +253,10 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
s, ograd.Size(), req[0], cstride, oshape,
cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(dx.type_flag_, dx.type_flag_)) {
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[0]},
+ NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
} else {
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[0]},
+ NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
}
}
@@ -268,10 +275,10 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
s, ograd.Size(), req[1], cstride, oshape,
cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(dy.type_flag_, dy.type_flag_)) {
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[1]},
+ NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[1]},
{dy.reshape(expanded_rshape)}, expanded_rshape);
} else {
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[1]},
+ NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[1]},
{dy.reshape(expanded_rshape)}, expanded_rshape);
}
}
@@ -367,7 +374,7 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
size_t ws_size = 0;
if (ograd.shape_ != dx.shape_) {
ws_size = broadcast::ReduceWorkspaceSize(s, expanded_lshape, req[0],
- expanded_oshape, sizeof(DType));
+ expanded_oshape);
}
// If lscalar, then process right output, `is_left` should be false
if (ograd.shape_ == dx.shape_) {
@@ -384,10 +391,10 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
s, ograd.Size(), req[0], cstride, oshape,
cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(dx.type_flag_, dx.type_flag_)) {
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[0]},
+ NP_WHERE_REDUCE_AXES(true, ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
} else {
- ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[0]},
+ NP_WHERE_REDUCE_AXES(false, ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
}
}
@@ -395,6 +402,8 @@ inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
});
}
+#undef NP_WHERE_REDUCE_AXES
+
template<typename xpu>
inline void NumpyWhereScalar2OpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h
index ab8afe9..1e43134 100644
--- a/src/operator/numpy/random/dist_common.h
+++ b/src/operator/numpy/random/dist_common.h
@@ -277,6 +277,86 @@ inline bool TwoparamsDistOpConcatShape(const nnvm::NodeAttrs &attrs,
return true;
}
+template<typename xpu, int ndim, typename DType>
+inline void CommonReparamBackwardImpl(const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs,
+ const mxnet::TShape& new_lshape,
+ const mxnet::TShape& new_rshape,
+ const mxnet::TShape& new_oshape) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ using namespace broadcast;
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ const TBlob lgrad = outputs[0].reshape(new_lshape);
+ const TBlob rgrad = outputs[1].reshape(new_rshape);
+ const TBlob ograd = inputs[0].reshape(new_oshape);
+ // Mean
+ const TBlob lhs = inputs[2].reshape(new_lshape);
+ // Scale
+ const TBlob rhs = inputs[3].reshape(new_rshape);
+ const TBlob samples = inputs[4].reshape(new_oshape);
+ const TBlob noise = inputs[5].reshape(new_oshape);
+ size_t workspace_size_l = ReduceWorkspaceSize(
+ s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_);
+ size_t workspace_size_r = ReduceWorkspaceSize(
+ s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
+ size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
+ Tensor<xpu, 1, char> workspace =
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
+ Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(
+ s, lgrad, req[0], workspace, ograd);
+ Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
+ s, rgrad, req[1], workspace, ograd, noise, rhs);
+#else
+ RTCReduce(ctx, lgrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity");
+ RTCReduce(ctx, rgrad, req[1], workspace, ograd, noise, rhs, "red::sum{}", ndim, "mul", "left");
+#endif
+}
+
+template<typename xpu, int ndim, typename DType>
+inline void CommonScalarReparamBackwardImpl(const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs,
+ const mxnet::TShape& new_ishape,
+ const mxnet::TShape& new_oshape,
+ const bool loc_is_tensor = false) {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ using namespace broadcast;
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ const TBlob igrad = outputs[0].reshape(new_ishape);
+ // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
+ // samples, noise]
+ const TBlob ograd = inputs[0].reshape(new_oshape);
+ const TBlob itensor = inputs[2].reshape(new_ishape);
+ const TBlob samples = inputs[3].reshape(new_oshape);
+ const TBlob noise = inputs[4].reshape(new_oshape);
+ size_t workspace_size =
+ ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_);
+ Tensor<xpu, 1, char> workspace =
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
+ if (loc_is_tensor) {
+ Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s, igrad, req[0],
+ workspace, ograd);
+ } else {
+ Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
+ s, igrad, req[0], workspace, ograd, noise, noise);
+ }
+#else
+ if (loc_is_tensor) {
+ RTCReduce(ctx, igrad, req[0], workspace, ograd, "red::sum{}", ndim, "identity");
+ } else {
+ RTCReduce(ctx, igrad, req[0], workspace, ograd, noise, noise, "red::sum{}",
+ ndim, "mul", "left");
+ }
+#endif
+}
+
} // namespace op
} // namespace mxnet
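[Editor's sketch] dist_common.h now hosts a single pair of reparameterized-backward helpers that the per-distribution headers below (normal, location-scale, pareto, rayleigh, weibull) delegate to, each again keeping a template Reduce on the host path and RTCReduce on the device path. Note how the RTC variant expresses both the plain form (ograd reduced with "identity") and the combined form (ograd merged with noise and scale via "mul"/"left") purely through strings. The remaining call sites in those headers follow this shape, mirroring the diff further down:

    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
      BROADCAST_NDIM_SWITCH(ndim, NDim, {
        CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
            ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor);
      });
    });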
diff --git a/src/operator/numpy/random/np_exponential_op.h b/src/operator/numpy/random/np_exponential_op.h
index 374b3b4..0f0d462 100644
--- a/src/operator/numpy/random/np_exponential_op.h
+++ b/src/operator/numpy/random/np_exponential_op.h
@@ -171,11 +171,16 @@ inline void ExponentialReparamBackwardImpl(const OpContext& ctx,
const TBlob samples = inputs[3].reshape(new_oshape);
const TBlob noise = inputs[4].reshape(new_oshape);
size_t workspace_size =
- ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
+ ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
s, igrad, req[0], workspace, ograd, noise, noise);
+#else
+ RTCReduce(ctx, igrad, req[0], workspace, ograd, noise, noise,
+ "red::sum{}", ndim, "mul", "left");
+#endif
}
template<typename xpu>
diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h
index a0f3299..55041a9 100644
--- a/src/operator/numpy/random/np_gamma_op.h
+++ b/src/operator/numpy/random/np_gamma_op.h
@@ -420,14 +420,19 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx,
const TBlob alpha = inputs[1].reshape(new_ishape);
TBlob samples = inputs[2].reshape(new_oshape);
size_t workspace_size =
- ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
+ ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_);
// Convert samples to standard gamma
Kernel<op_with_req<mshadow_op::div, kWriteTo>, xpu>::Launch(
s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+#if !defined(__CUDACC__)
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::gamma_implicit_grad>(
s, igrad, req[0], workspace, ograd, alpha, samples);
+#else
+ RTCReduce(ctx, igrad, req[0], workspace, ograd, alpha, samples, "red::sum{}", ndim,
+ "mul", "gamma_implicit_grad");
+#endif
Kernel<op_with_req<mshadow_op::mul, kWriteTo>, xpu>::Launch(
s, igrad.Size(), igrad.dptr<DType>(), igrad.dptr<DType>(), DType(scale));
// Convert samples back, otherwise the output would be corrupted.
diff --git a/src/operator/numpy/random/np_location_scale_op.h b/src/operator/numpy/random/np_location_scale_op.h
index 73403f3..0179a57 100644
--- a/src/operator/numpy/random/np_location_scale_op.h
+++ b/src/operator/numpy/random/np_location_scale_op.h
@@ -275,72 +275,6 @@ void NumpyLocationScaleForward(const nnvm::NodeAttrs &attrs,
}
}
-template<typename xpu, int ndim, typename DType>
-inline void LocationScaleReparamBackwardImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mxnet::TShape& new_lshape,
- const mxnet::TShape& new_rshape,
- const mxnet::TShape& new_oshape) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace broadcast;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob lgrad = outputs[0].reshape(new_lshape);
- const TBlob rgrad = outputs[1].reshape(new_rshape);
- const TBlob ograd = inputs[0].reshape(new_oshape);
- // Mean
- const TBlob lhs = inputs[2].reshape(new_lshape);
- // Scale
- const TBlob rhs = inputs[3].reshape(new_rshape);
- const TBlob samples = inputs[4].reshape(new_oshape);
- const TBlob noise = inputs[5].reshape(new_oshape);
- size_t workspace_size_l = ReduceWorkspaceSize(
- s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
- size_t workspace_size_r = ReduceWorkspaceSize(
- s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
- size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(
- s, lgrad, req[0], workspace, ograd);
- Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
- s, rgrad, req[1], workspace, ograd, noise, rhs);
-}
-
-template<typename xpu, int ndim, typename DType>
-inline void ScalarLocationScaleReparamBackwardImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mxnet::TShape& new_ishape,
- const mxnet::TShape& new_oshape,
- const bool loc_is_tensor) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace broadcast;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob igrad = outputs[0].reshape(new_ishape);
- // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
- // samples, noise]
- const TBlob ograd = inputs[0].reshape(new_oshape);
- const TBlob itensor = inputs[2].reshape(new_ishape);
- const TBlob samples = inputs[3].reshape(new_oshape);
- const TBlob noise = inputs[4].reshape(new_oshape);
- size_t workspace_size =
- ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- if (loc_is_tensor) {
- Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s, igrad, req[0],
- workspace, ograd);
- } else {
- Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
- s, igrad, req[0], workspace, ograd, noise, noise);
- }
-}
-
// Allow logistic and gumbel sampling to be differentiable,
// using reparameterization trick described in:
// Auto-encoding variational bayes.
@@ -359,7 +293,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs,
if (outputs.size() == 0U) {
return;
}
- const NumpyLocationScaleParam &param = nnvm::get<NumpyLocationScaleParam>(attrs.parsed);
+ const auto &param = nnvm::get<NumpyLocationScaleParam>(attrs.parsed);
// [tensor tensor] case
if (inputs.size() == 6U) {
mxnet::TShape new_lshape, new_rshape, new_oshape;
@@ -367,7 +301,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs,
&new_lshape, &new_rshape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- LocationScaleReparamBackwardImpl<xpu, NDim, DType>(
+ CommonReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
});
});
@@ -380,7 +314,7 @@ void LocationScaleReparamBackward(const nnvm::NodeAttrs& attrs,
bool loc_is_tensor = !param.loc.has_value();
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- ScalarLocationScaleReparamBackwardImpl<xpu, NDim, DType>(
+ CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor);
});
});
diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h
index e43d98d..06b5bfa 100644
--- a/src/operator/numpy/random/np_normal_op.h
+++ b/src/operator/numpy/random/np_normal_op.h
@@ -161,7 +161,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mxnet_op;
- const NumpyNormalParam &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
+ const auto &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
// Generate base random number.
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
@@ -240,72 +240,6 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs,
}
}
-template<typename xpu, int ndim, typename DType>
-inline void NormalReparamBackwardImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mxnet::TShape& new_lshape,
- const mxnet::TShape& new_rshape,
- const mxnet::TShape& new_oshape) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace broadcast;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob lgrad = outputs[0].reshape(new_lshape);
- const TBlob rgrad = outputs[1].reshape(new_rshape);
- const TBlob ograd = inputs[0].reshape(new_oshape);
- // Mean
- const TBlob lhs = inputs[2].reshape(new_lshape);
- // Variance
- const TBlob rhs = inputs[3].reshape(new_rshape);
- const TBlob samples = inputs[4].reshape(new_oshape);
- const TBlob noise = inputs[5].reshape(new_oshape);
- size_t workspace_size_l = ReduceWorkspaceSize(
- s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
- size_t workspace_size_r = ReduceWorkspaceSize(
- s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
- size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s,
- lgrad, req[0], workspace, ograd);
- Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
- s, rgrad, req[1], workspace, ograd, noise, rhs);
-}
-
-template<typename xpu, int ndim, typename DType>
-inline void ScalarNormalReparamBackwardImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mxnet::TShape& new_ishape,
- const mxnet::TShape& new_oshape,
- const bool loc_is_tensor) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace broadcast;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob igrad = outputs[0].reshape(new_ishape);
- // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
- // samples, noise]
- const TBlob ograd = inputs[0].reshape(new_oshape);
- const TBlob itensor = inputs[2].reshape(new_ishape);
- const TBlob samples = inputs[3].reshape(new_oshape);
- const TBlob noise = inputs[4].reshape(new_oshape);
- size_t workspace_size =
- ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- if (loc_is_tensor) {
- Reduce<red::sum, ndim, DType, op::mshadow_op::identity>(s, igrad, req[0],
- workspace, ograd);
- } else {
- Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
- s, igrad, req[0], workspace, ograd, noise, noise);
- }
-}
-
// Allow normal sampling to be differentiable,
// using reparameterization trick described in:
// Auto-encoding variational bayes.
@@ -324,7 +258,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs,
if (outputs.size() == 0U) {
return;
}
- const NumpyNormalParam &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
+ const auto &param = nnvm::get<NumpyNormalParam>(attrs.parsed);
// [tensor tensor] case
if (inputs.size() == 6U) {
mxnet::TShape new_lshape, new_rshape, new_oshape;
@@ -332,7 +266,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs,
&new_lshape, &new_rshape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- NormalReparamBackwardImpl<xpu, NDim, DType>(
+ CommonReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_lshape, new_rshape, new_oshape);
});
});
@@ -345,7 +279,7 @@ void NormalReparamBackward(const nnvm::NodeAttrs& attrs,
bool loc_is_tensor = !param.loc.has_value();
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- ScalarNormalReparamBackwardImpl<xpu, NDim, DType>(
+ CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_ishape, new_oshape, loc_is_tensor);
});
});
diff --git a/src/operator/numpy/random/np_pareto_op.h b/src/operator/numpy/random/np_pareto_op.h
index 5e5d26a..16731c1 100644
--- a/src/operator/numpy/random/np_pareto_op.h
+++ b/src/operator/numpy/random/np_pareto_op.h
@@ -155,32 +155,6 @@ void NumpyParetoForward(const nnvm::NodeAttrs &attrs,
}
}
-template<typename xpu, int ndim, typename DType>
-inline void ScalarParetoReparamBackwardImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mxnet::TShape& new_ishape,
- const mxnet::TShape& new_oshape) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace broadcast;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob igrad = outputs[0].reshape(new_ishape);
- // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
- // samples, noise]
- const TBlob ograd = inputs[0].reshape(new_oshape);
- const TBlob itensor = inputs[2].reshape(new_ishape);
- const TBlob samples = inputs[3].reshape(new_oshape);
- const TBlob noise = inputs[4].reshape(new_oshape);
- size_t workspace_size =
- ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
- s, igrad, req[0], workspace, ograd, noise, noise);
- }
-
template<typename xpu>
void ParetoReparamBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -202,7 +176,7 @@ if (inputs.size() == 5U) {
&new_ishape, &new_ishape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- ScalarParetoReparamBackwardImpl<xpu, NDim, DType>(
+ CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, reqs, outputs, new_ishape, new_oshape);
});
});
diff --git a/src/operator/numpy/random/np_rayleigh_op.h b/src/operator/numpy/random/np_rayleigh_op.h
index 0f940e5..75c4784 100644
--- a/src/operator/numpy/random/np_rayleigh_op.h
+++ b/src/operator/numpy/random/np_rayleigh_op.h
@@ -153,32 +153,6 @@ void NumpyRayleighForward(const nnvm::NodeAttrs &attrs,
}
}
-template<typename xpu, int ndim, typename DType>
-inline void ScalarRayleighReparamBackwardImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mxnet::TShape& new_ishape,
- const mxnet::TShape& new_oshape) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace broadcast;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob igrad = outputs[0].reshape(new_ishape);
- // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
- // samples, noise]
- const TBlob ograd = inputs[0].reshape(new_oshape);
- const TBlob itensor = inputs[2].reshape(new_ishape);
- const TBlob samples = inputs[3].reshape(new_oshape);
- const TBlob noise = inputs[4].reshape(new_oshape);
- size_t workspace_size =
- ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
- s, igrad, req[0], workspace, ograd, noise, noise);
-}
-
template<typename xpu>
void RayleighReparamBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -200,7 +174,7 @@ void RayleighReparamBackward(const nnvm::NodeAttrs& attrs,
&new_ishape, &new_ishape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- ScalarRayleighReparamBackwardImpl<xpu, NDim, DType>(
+ CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_ishape, new_oshape);
});
});
diff --git a/src/operator/numpy/random/np_weibull_op.h b/src/operator/numpy/random/np_weibull_op.h
index 970dc85..a7d6d5d 100644
--- a/src/operator/numpy/random/np_weibull_op.h
+++ b/src/operator/numpy/random/np_weibull_op.h
@@ -155,32 +155,6 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
}
}
-template<typename xpu, int ndim, typename DType>
-inline void ScalarWeibullReparamBackwardImpl(const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs,
- const mxnet::TShape& new_ishape,
- const mxnet::TShape& new_oshape) {
- using namespace mshadow;
- using namespace mshadow::expr;
- using namespace broadcast;
- Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob igrad = outputs[0].reshape(new_ishape);
- // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
- // samples, noise]
- const TBlob ograd = inputs[0].reshape(new_oshape);
- const TBlob itensor = inputs[2].reshape(new_ishape);
- const TBlob samples = inputs[3].reshape(new_oshape);
- const TBlob noise = inputs[4].reshape(new_oshape);
- size_t workspace_size =
- ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_, sizeof(DType));
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
- Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
- s, igrad, req[0], workspace, ograd, noise, noise);
- }
-
template<typename xpu>
void WeibullReparamBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -202,7 +176,7 @@ if (inputs.size() == 5U) {
&new_ishape, &new_ishape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
- ScalarWeibullReparamBackwardImpl<xpu, NDim, DType>(
+ CommonScalarReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, reqs, outputs, new_ishape, new_oshape);
});
});
diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h
index 2c5c1eb..0d89570 100644
--- a/src/operator/quantization/quantization_utils.h
+++ b/src/operator/quantization/quantization_utils.h
@@ -184,7 +184,7 @@ inline size_t ConfigReduce(mshadow::Stream<xpu>* s,
CHECK_EQ(src_shape->ndim(), NDim);
CHECK_EQ(dst_shape->ndim(), NDim);
- return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape, sizeof(DType));
+ return broadcast::ReduceWorkspaceSize(s, *dst_shape, kWriteTo, *src_shape);
}
enum QuantizeOutType { kAuto = 0, kInt8, kUint8 };
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index d8814cc..cfbdb7f 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -205,10 +205,17 @@ class QuantizeV2Operator {
dev_id);
Tensor<xpu, 1, char> workspace(temp_space.dptr_ + 2 * actual_float_size,
Shape1(temp_reduce_size), s);
+#if !defined(__CUDACC__)
broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+#else
+ broadcast::RTCReduce(ctx, in_min_t.reshape(dst_shape), kWriteTo, workspace,
+ inputs[0].reshape(src_shape), "red::minimum{}", 2, "identity");
+ broadcast::RTCReduce(ctx, in_max_t.reshape(dst_shape), kWriteTo, workspace,
+ inputs[0].reshape(src_shape), "red::maximum{}", 2, "identity");
+#endif
if (out_type == mshadow::kUint8) {
Kernel<quantize_v2_unsigned, xpu>::Launch(
s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h
index 2bdc3a7..5668670 100644
--- a/src/operator/quantization/requantize-inl.h
+++ b/src/operator/quantization/requantize-inl.h
@@ -148,6 +148,7 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs,
temp_space.dptr_ + 8) + 1, Shape1(1), xpu::kDevMask, dev_id);
Tensor<xpu, 1, char> workspace(
temp_space.dptr_+2*actual_float_size+2*actual_quantized_size, Shape1(temp_reduce_size), s);
+#if !defined(__CUDACC__)
broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
s, actual_min_quantized.reshape(dst_shape),
kWriteTo, workspace, inputs[0].reshape(src_shape));
@@ -158,6 +159,18 @@ void RequantizeForward(const nnvm::NodeAttrs& attrs,
broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
s, actual_max_quantized.reshape(dst_shape),
kWriteTo, workspace, inputs[0].reshape(src_shape));
+#else
+ broadcast::RTCReduce(ctx, actual_min_quantized.reshape(dst_shape),
+ kWriteTo, workspace, inputs[0].reshape(src_shape),
+ "red::minimum{}", 2, "identity");
+ Kernel<QuantizedToFloatStruct, xpu>::Launch(s, 1,
+ actual_min_float.dptr_, actual_min_quantized.dptr<SrcDType>(),
+ inputs[1].dptr<float>(), inputs[2].dptr<float>());
+
+ broadcast::RTCReduce(ctx, actual_max_quantized.reshape(dst_shape),
+ kWriteTo, workspace, inputs[0].reshape(src_shape),
+ "red::maximum{}", 2, "identity");
+#endif
Kernel<QuantizedToFloatStruct, xpu>::Launch(s, 1,
actual_max_float.dptr_, actual_max_quantized.dptr<SrcDType>(),
inputs[1].dptr<float>(), inputs[2].dptr<float>());
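[Editor's sketch] The quantization paths switch their min/max reductions the same way; the one signature difference worth noting is that RTCReduce takes the OpContext rather than the raw Stream, presumably so the RTC launcher can reach the full run context. A minimal GPU-side call (blob names are placeholders) looks like:

    // nvcc-only path: the reducer is named, not templated.
    broadcast::RTCReduce(ctx, out_min.reshape(dst_shape), kWriteTo, workspace,
                         data.reshape(src_shape), "red::minimum{}", /*ndim=*/2,
                         "identity");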
diff --git a/src/operator/random/pdf_op.h b/src/operator/random/pdf_op.h
index f6dc777..f53d3a6 100644
--- a/src/operator/random/pdf_op.h
+++ b/src/operator/random/pdf_op.h
@@ -592,10 +592,11 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs,
const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
const size_t N(outputs[1].Size());
const TShape src_shape(Shape2(N, outputs[0].Size() / N)), dst_shape(Shape2(N, 1));
+ const size_t red_work_size(broadcast::ReduceWorkspaceSize(
+ s, dst_shape, kAddTo, src_shape));
+#if !defined(__CUDACC__)
// Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf.
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- const size_t red_work_size(broadcast::ReduceWorkspaceSize(
- s, dst_shape, kAddTo, src_shape, sizeof(DType)));
const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size);
Tensor<xpu, 1, char> tmp_space =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tmp_size), s);
@@ -620,6 +621,35 @@ void PdfOpBackward(const nnvm::NodeAttrs& attrs,
s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape));
}
});
+#else
+ // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf.
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size);
+ Tensor<xpu, 1, char> tmp_space =
+ ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tmp_size), s);
+ std::vector<TBlob> grads = {outputs[0]};
+ grads.push_back(TBlob(tmp_space.dptr_, outputs[0].shape_,
+ outputs[1].dev_mask(), outputs[1].type_flag_, outputs[1].dev_id()));
+ if (pnum == 2) {
+ grads.push_back(TBlob(tmp_space.dptr_ + outputs[0].Size() * sizeof(DType), outputs[0].shape_,
+ outputs[2].dev_mask(), outputs[2].type_flag_, outputs[2].dev_id()));
+ }
+ if (param.is_log) {
+ PdfGradCaller<xpu, DType, pdfgrad<true>, pnum, vparm>::op(inputs, req, grads, s);
+ } else {
+ PdfGradCaller<xpu, DType, pdfgrad<false>, pnum, vparm>::op(inputs, req, grads, s);
+ }
+ Tensor<xpu, 1, char> red_work(
+ tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s);
+ broadcast::RTCReduce(ctx, outputs[1].reshape(dst_shape), req[1], red_work,
+ grads[1].reshape(src_shape), "red::sum{}", 2, "identity");
+ if (pnum == 2) {
+ broadcast::RTCReduce(ctx, outputs[2].reshape(dst_shape), req[2], red_work,
+ grads[2].reshape(src_shape), "red::sum{}", 2, "identity");
+ }
+ });
+
+#endif
}
} // namespace op
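[Editor's note] In PdfOpBackward the workspace-size computation moves out of the type switch, which is possible precisely because ReduceWorkspaceSize no longer needs sizeof(DType); the CUDA branch then mirrors the host branch line for line, swapping the two broadcast::Reduce calls for RTCReduce with "red::sum{}". The hoisted computation is simply:

    // Computed once, before MSHADOW_REAL_TYPE_SWITCH, since the size is now
    // independent of the concrete DType.
    const size_t red_work_size(broadcast::ReduceWorkspaceSize(
        s, dst_shape, kAddTo, src_shape));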
diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh
deleted file mode 100644
index 53b0ad8..0000000
--- a/src/operator/tensor/broadcast_reduce-inl.cuh
+++ /dev/null
@@ -1,414 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015-2020 by Contributors
- * \file broadcast_reduce-inl.cuh
- * \brief CUDA implementations for binary broadcast and reduce
- * \author Antti-Pekka Hynninen, Przemyslaw Tredak
-*/
-#ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
-#define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
-
-using namespace mshadow::cuda;
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP, int unroll,
- typename IndexOP = mxnet::op::mshadow_op::set_index_no_op<AType, int>>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel(const int N, const int M, const bool addto,
- const DType* __restrict big, OType *small,
- const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
- const Shape<ndim> big_shape, const Shape<ndim> big_stride,
- const int Mnext, const bool do_transpose) {
- extern __shared__ char shTileChar[];
- AType* shTile = (AType*)(shTileChar);
- const int tid = threadIdx.x + threadIdx.y*blockDim.x;
- const int bx = (do_transpose) ? blockDim.y : blockDim.x;
- const int by = (do_transpose) ? blockDim.x : blockDim.y;
- const int tidx = (do_transpose) ? tid / by : threadIdx.x;
- const int tidy = (do_transpose) ? tid % by : threadIdx.y;
- for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
- // This TB handles M range [Mstart, ...., Mend - 1]
- const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
- const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
- for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
- int idx = idx0 + tidx;
- Shape<ndim> coord = mxnet_op::unravel(idx, small_shape);
- int idx_big0 = mxnet_op::ravel(coord, big_shape0);
-
- AType val, residual;
- Reducer::SetInitValue(val, residual);
- if (idx < N) {
- for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
- int idx_big[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
- }
- AType tmp[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) {
- tmp[u] = OP::Map(big[idx_big[u]]);
- // argmin/max, set IndexedNum.idx
- if (IndexOP::do_op)
- IndexOP::Op(&tmp[u], k+u*by);
- }
- }
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
- }
- }
- }
-
- // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
- if (by > 1) {
- // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
- const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
- const int it0 = tidx + tidy*fbx;
- shTile[it0 * 2] = val;
- shTile[it0 * 2 + 1] = residual;
- __syncthreads();
- for (int t=1;t < by;t <<= 1) {
- AType tmp, tmp_residual;
- Reducer::SetInitValue(tmp, tmp_residual);
- if (tidy + t < by) {
- tmp = shTile[(it0 + t*fbx) * 2];
- tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
- }
- __syncthreads();
- Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
- __syncthreads();
- }
- if (idx < N && tidy == 0) {
- Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
- assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2]));
- }
- } else {
- if (idx < N) {
- Reducer::Finalize(val, residual);
- assign(&small[idx + m0*N], addto, OType(val));
- }
- }
- }
- }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2, int unroll>
-__launch_bounds__(nthread_reduce)
-__global__ void reduce_kernel(const int N, const int M, const bool addto,
- const DType* __restrict big, const DType* __restrict lhs,
- const DType* __restrict rhs, DType *small,
- const Shape<ndim> big_shape0, const Shape<ndim> lhs_shape0,
- const Shape<ndim> rhs_shape0, const Shape<ndim> small_shape,
- const Shape<ndim> big_shape, const Shape<ndim> lhs_shape,
- const Shape<ndim> rhs_shape, const Shape<ndim> big_stride,
- const Shape<ndim> lhs_stride, const Shape<ndim> rhs_stride,
- const int Mnext, const bool do_transpose) {
- extern __shared__ char shTileChar[];
- DType* shTile = (DType*)(shTileChar);
- const int tid = threadIdx.x + threadIdx.y*blockDim.x;
- const int bx = (do_transpose) ? blockDim.y : blockDim.x;
- const int by = (do_transpose) ? blockDim.x : blockDim.y;
- const int tidx = (do_transpose) ? tid / by : threadIdx.x;
- const int tidy = (do_transpose) ? tid % by : threadIdx.y;
- for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
- // This TB handles M range [Mstart, ...., Mend - 1]
- const int Mstart = (int)((uint64_t)M*(uint64_t)m0/(uint64_t)Mnext);
- const int Mend = (int)((uint64_t)M*(uint64_t)(m0 + 1)/(uint64_t)Mnext);
- for (int idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
- int idx = idx0 + tidx;
- Shape<ndim> coord = mxnet_op::unravel(idx, small_shape);
- int idx_big0 = mxnet_op::ravel(coord, big_shape0);
- int idx_lhs0 = mxnet_op::ravel(coord, lhs_shape0);
- int idx_rhs0 = mxnet_op::ravel(coord, rhs_shape0);
-
- DType val, residual;
- Reducer::SetInitValue(val, residual);
- if (idx < N) {
- for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
- int idx_big[unroll];
- int idx_lhs[unroll];
- int idx_rhs[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
- idx_lhs[u] = idx_lhs0 + mxnet_op::unravel_dot(k + u*by, lhs_shape, lhs_stride);
- idx_rhs[u] = idx_rhs0 + mxnet_op::unravel_dot(k + u*by, rhs_shape, rhs_stride);
- }
- DType tmp[unroll];
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) {
- tmp[u] = OP1::Map(big[idx_big[u]], OP2::Map(lhs[idx_lhs[u]], rhs[idx_rhs[u]]));
- }
- }
- #pragma unroll
- for (int u=0;u < unroll;u++) {
- if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
- }
- }
- }
-
- // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
- if (by > 1) {
- // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
- const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
- const int it0 = tidx + tidy*fbx;
- shTile[it0 * 2] = val;
- shTile[it0 * 2 + 1] = residual;
- __syncthreads();
- for (int t=1;t < by;t <<= 1) {
- DType tmp, tmp_residual;
- Reducer::SetInitValue(tmp, tmp_residual);
- if (tidy + t < by) {
- tmp = shTile[(it0 + t*fbx) * 2];
- tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
- }
- __syncthreads();
- Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
- __syncthreads();
- }
- if (idx < N && tidy == 0) {
- Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
- assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
- }
- } else {
- if (idx < N) {
- Reducer::Finalize(val, residual);
- assign(&small[idx + m0*N], addto, val);
- }
- }
- }
- }
-}
-
-// Simple reduction of lines when M is small
-template<typename Reducer, typename DType>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_lines_kernel(const int N, const int M, const bool addto,
- const int small_in_stride, const DType* __restrict small_in, DType *small_out) {
- for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
-
- DType val, residual;
- Reducer::SetInitValue(val, residual);
- for (int k = 0; k < M; k++) {
- Reducer::Reduce(val, small_in[idx + k*small_in_stride], residual);
- }
-
- if (idx < N) {
- Reducer::Finalize(val, residual);
- assign(&small_out[idx], addto, val);
- }
-
- }
-}
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1(const int N, const bool addto,
- const DType* __restrict big, OType *small, const Shape<ndim> bshape,
- const Shape<ndim> sshape) {
- for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
- Shape<ndim> coord = mxnet_op::unravel(idx, sshape);
- int j = mxnet_op::ravel(coord, bshape);
- AType val, residual, temp = OP::Map(big[j]);
- Reducer::SetInitValue(val, residual);
- Reducer::Reduce(val, temp, residual);
- Reducer::Finalize(val, residual);
- assign(&small[idx], addto, OType(val));
- }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-__launch_bounds__(kMaxThreadsPerBlock)
-__global__ void reduce_kernel_M1(const int N, const bool addto,
- const DType* __restrict big,
- const DType* __restrict lhs,
- const DType* __restrict rhs,
- DType *small,
- const Shape<ndim> big_shape,
- const Shape<ndim> lhs_shape,
- const Shape<ndim> rhs_shape,
- const Shape<ndim> small_shape) {
- for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
- Shape<ndim> coord = mxnet_op::unravel(idx, small_shape);
- int idx_big = mxnet_op::ravel(coord, big_shape);
- int idx_lhs = mxnet_op::ravel(coord, lhs_shape);
- int idx_rhs = mxnet_op::ravel(coord, rhs_shape);
- DType val, residual;
- Reducer::SetInitValue(val, residual);
- Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
- Reducer::Finalize(val, residual);
- assign(&small[idx], addto, val);
- }
-}
-
-#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
- if (do_unroll) { \
- const int unrollVar = unrollAmount; \
- {__VA_ARGS__} \
- } else { \
- const int unrollVar = 1; \
- {__VA_ARGS__} \
- }
-
-template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP,
- typename IndexOP = mxnet::op::mshadow_op::set_index_no_op<AType, int>>
-void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
- const TBlob& big, const Tensor<gpu, 1, char>& workspace,
- const ReduceImplConfig& config) {
- if (config.M == 1) {
- reduce_kernel_M1<Reducer, ndim, AType, DType, OType, OP>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
- config.N, req == kAddTo, big.dptr<DType>(), reinterpret_cast<OType*>(small.dptr_),
- big.shape_.get<ndim>(), small.shape_.get<ndim>());
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
- } else {
- OType* small_dptr = reinterpret_cast<OType*>(small.dptr_);
- bool addto = (req == kAddTo);
- if (config.Mnext > 1) {
- // small_dptr[] is N*Mnext*sizeof(DType) bytes
- small_dptr = reinterpret_cast<OType*>(workspace.dptr_);
- addto = false;
- // Check that the workspace is contiguous
- CHECK_EQ(workspace.CheckContiguous(), true);
- // Check that we have enough storage
- CHECK_GE(workspace.size(0), config.workspace_size);
- }
-
- const int by = (config.kernel_1.do_transpose) ?
- config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
- const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
- KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
- reduce_kernel<Reducer, ndim, AType, DType, OType, OP, UNROLL, IndexOP>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
- config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
- small.shape_.get<ndim>(), config.rshape.get<ndim>(), config.rstride.get<ndim>(),
- config.Mnext, config.kernel_1.do_transpose);
- });
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
-
- if (config.Mnext > 1) {
- reduce_lines_kernel<Reducer, OType>
- <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
- (config.N, config.Mnext, req == kAddTo, config.N, small_dptr,
- reinterpret_cast<OType*>(small.dptr_));
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
- }
- }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const TBlob& rhs,
- const OpReqType req, const TBlob& big, const Tensor<gpu, 1, char>& workspace,
- const ReduceImplConfig& config) {
- if (config.M == 1) {
- reduce_kernel_M1<Reducer, ndim, DType, OP1, OP2>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
- config.N, req == kAddTo, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
- small.dptr<DType>(), big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
- rhs.shape_.get<ndim>(), small.shape_.get<ndim>());
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
- } else {
- DType* small_dptr = small.dptr<DType>();
- bool addto = (req == kAddTo);
- if (config.Mnext > 1) {
- // small_dptr[] is N*Mnext*sizeof(DType) bytes
- small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
- addto = false;
- // Check that the workspace is contiguous
- CHECK_EQ(workspace.CheckContiguous(), true);
- // Check that we have enough storage
- CHECK_GE(workspace.size(0), config.workspace_size);
- }
-
- const int by = (config.kernel_1.do_transpose) ?
- config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
- const bool do_unroll = ( config.M / (by*config.Mnext) >= unroll_reduce );
- KERNEL_UNROLL_SWITCH(do_unroll, unroll_reduce, UNROLL, {
- reduce_kernel<Reducer, ndim, DType, OP1, OP2, UNROLL>
- <<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
- config.N, config.M, addto, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
- small_dptr, big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
- rhs.shape_.get<ndim>(), small.shape_.get<ndim>(), config.rshape.get<ndim>(),
- config.lhs_shape.get<ndim>(), config.rhs_shape.get<ndim>(), config.rstride.get<ndim>(),
- config.lhs_stride.get<ndim>(), config.rhs_stride.get<ndim>(), config.Mnext,
- config.kernel_1.do_transpose);
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
- });
-
- if (config.Mnext > 1) {
- reduce_lines_kernel<Reducer, DType>
- <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
- (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
- MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
- }
- }
-}
-
-#undef KERNEL_UNROLL_SWITCH
-
-template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
-void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
- const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
- if (req == kNullOp) return;
- cudaStream_t stream = Stream<gpu>::GetStream(s);
- ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType));
- if (safe_acc) {
- MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
- typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
- MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
- typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
- config = ReduceImplConfig(small.shape_, big.shape_, nullptr, nullptr,
- sizeof(AccType));
- ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>(
- stream, small, req, big, workspace, config);
- });
- });
- } else {
- ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config);
- }
-}
-
-template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
-void ReduceBool(Stream<gpu> *s, const TBlob& small, const OpReqType req,
- const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
- if (req == kNullOp) return;
- cudaStream_t stream = Stream<gpu>::GetStream(s);
- ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, sizeof(DType));
- ReduceImpl<Reducer, ndim, bool, DType, bool, OP>(stream, small, req, big, workspace, config);
-}
-
-template <typename Reducer, int ndim, typename DType, typename OP>
-void ReduceWithExtraMem(Stream<gpu>* s, const TBlob& small, const OpReqType req,
- const Tensor<gpu, 1, char>& workspace, const TBlob& big) {};
-
-template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
- const Tensor<gpu, 1, char>& workspace, const TBlob& big,
- const TBlob& lhs, const TBlob& rhs) {
- if (req == kNullOp) return;
- cudaStream_t stream = Stream<gpu>::GetStream(s);
- ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, sizeof(DType));
- ReduceImpl<Reducer, ndim, DType, OP1, OP2>(stream, small, lhs, rhs, req, big, workspace, config);
-}
-
-#endif //MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_INL_CUH_
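
For readers following the removal above: the deleted ReduceImpl used a two-phase strategy whenever Mnext > 1. The first kernel writes N*Mnext partial results into the caller-provided workspace (forcing addto off), and reduce_lines_kernel then folds the Mnext partials per output element, applying kAddTo only in that final pass. The RTC versions added later in this commit keep the same structure. A minimal host-side C++ sketch of the pattern, assuming a contiguous N x M input, a plain sum reducer, and illustrative names (not MXNet APIs):

  #include <cstddef>
  #include <vector>

  // Phase 1: each of the Mnext slices of the reduction axis produces one
  // partial result per output element; partials land in an N * Mnext workspace.
  void reduce_phase1(const std::vector<float>& big, std::vector<float>* workspace,
                     std::size_t N, std::size_t M, std::size_t Mnext) {
    workspace->assign(N * Mnext, 0.0f);
    for (std::size_t m0 = 0; m0 < Mnext; ++m0) {
      const std::size_t Mstart = M * m0 / Mnext;
      const std::size_t Mend   = M * (m0 + 1) / Mnext;
      for (std::size_t n = 0; n < N; ++n)
        for (std::size_t m = Mstart; m < Mend; ++m)
          (*workspace)[n + m0 * N] += big[n * M + m];
    }
  }

  // Phase 2 ("reduce lines"): fold the Mnext partials per output element,
  // honouring addto only here, just like reduce_lines_kernel above.
  // small must already hold N values when addto is true.
  void reduce_phase2(const std::vector<float>& workspace, std::vector<float>* small,
                     std::size_t N, std::size_t Mnext, bool addto) {
    if (!addto) small->assign(N, 0.0f);
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t m0 = 0; m0 < Mnext; ++m0)
        (*small)[n] += workspace[n + m0 * N];
  }
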
diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h
index a7d5d67..1907c02 100644
--- a/src/operator/tensor/broadcast_reduce-inl.h
+++ b/src/operator/tensor/broadcast_reduce-inl.h
@@ -445,13 +445,13 @@ void ReduceWithExtraMem(Stream<cpu>* s, const TBlob& small, const OpReqType req,
}
inline size_t ReduceWorkspaceSize(Stream<cpu> *s, const mxnet::TShape& small, const OpReqType req,
- const mxnet::TShape& big, const int type_size) {
+ const mxnet::TShape& big) {
return 0;
}
inline size_t ReduceWorkspaceSize(Stream<cpu> *s, const mxnet::TShape& small, const OpReqType req,
const mxnet::TShape& big, const mxnet::TShape& lhs,
- const mxnet::TShape& rhs, const int type_size) {
+ const mxnet::TShape& rhs) {
return 0;
}
@@ -539,11 +539,13 @@ struct ReduceImplConfig {
inline ReduceImplConfig(const ::mxnet::TShape& small, const ::mxnet::TShape& big,
const ::mxnet::TShape* lhs,
- const ::mxnet::TShape* rhs,
- const size_t type_size) :
+ const ::mxnet::TShape* rhs) :
rshape(small.ndim(), 1), rstride(small.ndim(), 1),
lhs_shape(small.ndim(), 1), lhs_stride(small.ndim(), 1),
rhs_shape(small.ndim(), 1), rhs_stride(small.ndim(), 1) {
+ // The largest reduction type currently is (index_t, double) struct
+ // aligned to 16B
+ constexpr size_t max_type_size = 2 * sizeof(double);
constexpr int maxLoopPerTB = 64;
int ndim = small.ndim();
@@ -646,7 +648,7 @@ struct ReduceImplConfig {
by++;
}
kernel_1.shMemSize = (kernel_1.blockDim.x > 1) ?
- kernel_1.blockDim.x*by*type_size * 2 : 0;
+ kernel_1.blockDim.x*by*max_type_size * 2 : 0;
// Maximum number of times we want TB to loop in M
// Max size of M-block each TB can handle
int maxMblock = kernel_1.blockDim.x*maxLoopPerTB;
@@ -657,7 +659,7 @@ struct ReduceImplConfig {
ceil_idiv<unsigned int>(N, kernel_1.blockDim.x));
kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext);
kernel_1.shMemSize = (kernel_1.blockDim.y > 1) ?
- kernel_1.blockDim.x*kernel_1.blockDim.y*type_size * 2 : 0;
+ kernel_1.blockDim.x*kernel_1.blockDim.y*max_type_size * 2 : 0;
// Maximum number of times we want TB to loop in M
// Max size of M-block each TB can handle
int maxMblock = kernel_1.blockDim.y*maxLoopPerTB;
@@ -666,7 +668,7 @@ struct ReduceImplConfig {
if (Mnext > 1) {
// small_dptr[] is N*Mnext*type_size bytes
- workspace_size += N*Mnext*sizeof(double);
+ workspace_size += N * Mnext * max_type_size;
// Set gridDim.y to Mnext
kernel_1.gridDim.y = std::min(kBaseGridNum, Mnext);
}
@@ -681,24 +683,20 @@ struct ReduceImplConfig {
};
inline size_t ReduceWorkspaceSize(Stream<gpu> *s, const ::mxnet::TShape& small, const OpReqType req,
- const ::mxnet::TShape& big, const int type_size) {
+ const ::mxnet::TShape& big) {
if (req == kNullOp) return 0;
- ReduceImplConfig config(small, big, nullptr, nullptr, type_size);
+ ReduceImplConfig config(small, big, nullptr, nullptr);
return config.workspace_size;
}
inline size_t ReduceWorkspaceSize(Stream<gpu> *s, const ::mxnet::TShape& small, const OpReqType req,
const ::mxnet::TShape& big, const ::mxnet::TShape& lhs,
- const ::mxnet::TShape& rhs, const int type_size) {
+ const ::mxnet::TShape& rhs) {
if (req == kNullOp) return 0;
- ReduceImplConfig config(small, big, &lhs, &rhs, type_size);
+ ReduceImplConfig config(small, big, &lhs, &rhs);
return config.workspace_size;
}
-#ifdef __CUDACC__
-#include "broadcast_reduce-inl.cuh"
-#endif
-
#endif // MXNET_USE_CUDA
template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
@@ -784,7 +782,8 @@ void RTCReduce(const OpContext& ctx,
const TBlob& big,
const std::string& reducer,
int ndim,
- const std::string& OP);
+ const std::string& OP,
+ const bool use_index = false);
void RTCReduce(const OpContext& ctx,
const TBlob& small,
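
Note on the signature change above: ReduceWorkspaceSize and ReduceImplConfig no longer take the element type size. Since the reducer is only known when the RTC kernel is compiled, shared memory and workspace are now sized for the worst-case accumulator, the (index_t, value) pair used by index-returning reducers, which the new max_type_size constant pins at 2 * sizeof(double) = 16 bytes. A small compile-time check of that assumption (the struct name is illustrative):

  #include <cstdint>

  // Mirrors the largest accumulation type the config must accommodate:
  // a value/index pair as used by argmin/argmax-style reducers.
  struct IndexValuePair {
    std::int64_t idx;   // index_t
    double       num;   // widest value type
  };

  static_assert(sizeof(IndexValuePair) == 2 * sizeof(double),
                "worst-case accumulator fits in max_type_size (16 bytes)");
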
diff --git a/src/operator/tensor/broadcast_reduce_minmax_value.cu b/src/operator/tensor/broadcast_reduce_minmax_value.cu
index baf79fe..c8cb757 100644
--- a/src/operator/tensor/broadcast_reduce_minmax_value.cu
+++ b/src/operator/tensor/broadcast_reduce_minmax_value.cu
@@ -28,13 +28,15 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(max)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+ {"identity", "red::maximum{}", false});
NNVM_REGISTER_OP(_backward_max)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::eq>);
NNVM_REGISTER_OP(min)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::minimum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+ {"identity", "red::minimum{}", false});
NNVM_REGISTER_OP(_backward_min)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::eq>);
diff --git a/src/operator/tensor/broadcast_reduce_op.cc b/src/operator/tensor/broadcast_reduce_op.cc
new file mode 100644
index 0000000..483787e
--- /dev/null
+++ b/src/operator/tensor/broadcast_reduce_op.cc
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "broadcast_reduce_op.h"
+#include <limits>
+#include "../numpy/np_broadcast_reduce_op.h"
+#include "elemwise_binary_scalar_op.h"
+#include "mxnet/tuple.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUDA
+
+void ReduceAxesRTCComputeImpl(const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs,
+ const mxnet::TShape& small,
+ const std::string& reducer,
+ const mshadow::Tensor<gpu, 1, char>* workspace,
+ const bool normalize,
+ const std::string& OP,
+ const int ddof) {
+ using namespace mshadow;
+
+ mxnet::TShape src_shape, dst_shape;
+ BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
+ Stream<gpu>* s = ctx.get_stream<gpu>();
+ Tensor<gpu, 1, char> w;
+ if (workspace == nullptr) {
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(
+ s, dst_shape, req[0], src_shape);
+ w = ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_size), s);
+ workspace = &w;
+ }
+ const TBlob in_data = inputs[0].reshape(src_shape);
+ const TBlob out_data = outputs[0].reshape(dst_shape);
+ BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
+ broadcast::RTCReduce(ctx, out_data, req[0], *workspace, in_data, reducer, NDim, OP);
+ });
+ if (normalize) {
+ NumpyBinaryScalarParam p{};
+ p.scalar = static_cast<double>(src_shape.Size()/dst_shape.Size() - ddof);
+ NodeAttrs a;
+ a.parsed = p;
+ BinaryScalarRTCCompute {"div"}(a, ctx, {out_data}, {kWriteInplace}, {out_data});
+ }
+}
+
+namespace {
+template <typename Param>
+void PrepareReduce(const Param& param,
+ const std::vector<TBlob>& inputs,
+ const std::vector<TBlob>& outputs,
+ mxnet::TShape* shape, int* ddof);
+
+template <>
+void PrepareReduce<ReduceAxesParam>(const ReduceAxesParam& param,
+ const std::vector<TBlob>& inputs,
+ const std::vector<TBlob>& outputs,
+ mxnet::TShape* small, int* ddof) {
+ if (param.keepdims) {
+ *small = outputs[0].shape_;
+ } else {
+ *small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude);
+ }
+
+ *ddof = 0;
+}
+
+template <>
+void PrepareReduce<NumpyReduceAxesNoDTypeParam>(const NumpyReduceAxesNoDTypeParam& param,
+ const std::vector<TBlob>& inputs,
+ const std::vector<TBlob>& outputs,
+ mxnet::TShape* small, int* ddof) {
+ if (param.initial.has_value()) {
+ LOG(FATAL) << "initial is not supported yet";
+ }
+ if (param.keepdims) {
+ *small = outputs[0].shape_;
+ } else {
+ *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+ }
+
+ *ddof = 0;
+}
+
+template <>
+void PrepareReduce<NumpyReduceAxesParam>(const NumpyReduceAxesParam& param,
+ const std::vector<TBlob>& inputs,
+ const std::vector<TBlob>& outputs,
+ mxnet::TShape* small, int* ddof) {
+ if (param.initial.has_value()) {
+ LOG(FATAL) << "initial is not supported yet";
+ }
+ if (param.keepdims) {
+ *small = outputs[0].shape_;
+ } else {
+ *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+ }
+
+ *ddof = 0;
+}
+
+template <>
+void PrepareReduce<NumpyReduceAxesBoolParam>(const NumpyReduceAxesBoolParam& param,
+ const std::vector<TBlob>& inputs,
+ const std::vector<TBlob>& outputs,
+ mxnet::TShape* small, int* ddof) {
+ if (param.keepdims) {
+ *small = outputs[0].shape_;
+ } else {
+ *small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+ }
+
+ *ddof = 0;
+}
+
+} // namespace
+
+template <typename Param, int init>
+void ReduceAxesRTCCompute<Param, init>::operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ if (req[0] == kNullOp) return;
+ mxnet::TShape small;
+ int ddof;
+ const auto& param = nnvm::get<Param>(attrs.parsed);
+ CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place";
+ PrepareReduce(param, inputs, outputs, &small, &ddof);
+ if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
+ if (inputs[0].shape_.Size() == 0) {
+ if (normalize && mxnet::common::is_float(outputs[0].type_flag_)) {
+ LOG(WARNING) << "WARNING: Mean of empty slice.";
+ NumpyBinaryScalarParam p{};
+ p.scalar = std::numeric_limits<float>::quiet_NaN();
+ NodeAttrs a;
+ a.parsed = p;
+ BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {kWriteTo}, outputs);
+ } else {
+ if (normalize) {
+ LOG(WARNING) << "WARNING: nan is outside the range of " <<
+ "representable values of type 'int'";
+ }
+ if (init == 0 && req[0] == kAddTo) return;
+ NumpyBinaryScalarParam p{};
+ p.scalar = init;
+ NodeAttrs a;
+ a.parsed = p;
+ BinaryScalarRTCCompute {"right"} (a, ctx, outputs, {req[0]}, outputs);
+ }
+ return;
+ }
+
+ ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, reducer, nullptr, normalize, OP, ddof);
+}
+
+template struct ReduceAxesRTCCompute<ReduceAxesParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesParam, 1>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesNoDTypeParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 0>;
+template struct ReduceAxesRTCCompute<NumpyReduceAxesBoolParam, 1>;
+
+#endif
+
+} // namespace op
+} // namespace mxnet
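
The normalize path in ReduceAxesRTCComputeImpl above divides the reduced output by the number of input elements folded into each output value, minus ddof (delta degrees of freedom, 0 for a plain mean). A quick host-side illustration of that divisor arithmetic, under assumed example shapes:

  #include <cassert>
  #include <cstddef>

  int main() {
    // e.g. input of shape (2, 3, 4) reduced over axes 1 and 2 -> output shape (2, 1, 1)
    const std::size_t src_size = 2 * 3 * 4;  // src_shape.Size()
    const std::size_t dst_size = 2;          // dst_shape.Size()
    const int ddof = 0;                      // 0 for mean; larger for variance-style ops
    const double divisor = static_cast<double>(src_size / dst_size - ddof);
    assert(divisor == 12.0);                 // each output averages 12 input elements
    return 0;
  }
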
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 872a898..7baf255 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -658,7 +658,9 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
- const mxnet::TShape& small) {
+ const mxnet::TShape& small,
+ const mshadow::Tensor<xpu, 1, char>* workspace = nullptr,
+ const int ddof = 0) {
using namespace mshadow;
using namespace mshadow::expr;
@@ -670,15 +672,18 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
const TBlob in_data = inputs[0].reshape(src_shape);
const TBlob out_data = outputs[0].reshape(dst_shape);
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
- size_t workspace_size = broadcast::ReduceWorkspaceSize(
- s, out_data.shape_, req[0], in_data.shape_, sizeof(OType));
- Tensor<xpu, 1, char> workspace =
- ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+ Tensor<xpu, 1, char> w;
+ if (workspace == nullptr) {
+ size_t workspace_size = broadcast::ReduceWorkspaceSize(
+ s, out_data.shape_, req[0], in_data.shape_);
+ w = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+ workspace = &w;
+ }
broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
- s, out_data, req[0], workspace, in_data);
+ s, out_data, req[0], *workspace, in_data);
if (normalize) {
auto out = out_data.FlatTo2D<xpu, OType>(s);
- out /= scalar<OType>(src_shape.Size()/dst_shape.Size());
+ out /= scalar<OType>(src_shape.Size()/dst_shape.Size() - ddof);
}
});
});
@@ -704,7 +709,7 @@ void ReduceAxesComputeBoolImpl(const OpContext& ctx,
const TBlob out_data = outputs[0].reshape(dst_shape);
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
size_t workspace_size = broadcast::ReduceWorkspaceSize(
- s, out_data.shape_, req[0], in_data.shape_, sizeof(OType));
+ s, out_data.shape_, req[0], in_data.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
broadcast::ReduceBool<reducer, NDim, DType, OP>(
@@ -736,6 +741,35 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImpl<xpu, reducer, false, normalize, OP>(ctx, inputs, req, outputs, small);
}
+#if MXNET_USE_CUDA
+
+template <typename Param, int init>
+struct ReduceAxesRTCCompute {
+ std::string OP;
+ std::string reducer;
+ bool normalize;
+
+ void operator()(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs);
+};
+
+void ReduceAxesRTCComputeImpl(const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs,
+ const mxnet::TShape& small,
+ const std::string& reducer,
+ const mshadow::Tensor<gpu, 1, char>* workspace = nullptr,
+ const bool normalize = false,
+ const std::string& OP = "identity",
+ const int ddof = 0);
+
+#endif
+
template <typename red_op, int req, int axis>
struct ReduceCsrKernel;
@@ -1516,7 +1550,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
} else {
small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
}
- bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
+#if !defined(__CUDACC__)
+ bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LpNorm with float16 inputs. "
"See https://mxnet.apache.org/api/faq/env_var "
@@ -1539,6 +1574,15 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
ctx, inputs, req, outputs, small);
}
}
+#else
+ const std::string &red = param.ord == 1
+ ? "red::sum{}"
+ : "red::nrm2{}";
+ const std::string &op = param.ord == 1
+ ? "abs"
+ : "identity";
+ ReduceAxesRTCComputeImpl(ctx, inputs, req, outputs, small, red, nullptr, false, op);
+#endif
}
template<int req>
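
In the CUDA branch of LpNormCompute above, the norm is expressed as a (per-element op, reducer) pair: ord == 1 maps to op "abs" with "red::sum{}", and ord == 2 maps to "identity" with "red::nrm2{}". A small host-side sketch of the same mapping, using plain loops instead of the RTC reducers:

  #include <cassert>
  #include <cmath>

  int main() {
    const double x[] = {3.0, -4.0};
    double l1 = 0.0, sumsq = 0.0;
    for (double v : x) {
      l1 += std::abs(v);   // op = "abs", reducer = sum
      sumsq += v * v;      // nrm2 accumulates squares ...
    }
    const double l2 = std::sqrt(sumsq);  // ... and takes the square root when finalizing
    assert(l1 == 7.0 && l2 == 5.0);
    return 0;
  }
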
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index 35b3c02..f7c2834 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -37,7 +37,8 @@ NNVM_REGISTER_OP(broadcast_like)
.set_attr<FCompute>("FCompute<gpu>", BroadcastCompute<gpu>);
NNVM_REGISTER_OP(_broadcast_backward)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+ "red::sum{}", false});
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/broadcast_reduce_prod_value.cu b/src/operator/tensor/broadcast_reduce_prod_value.cu
index 5731de3..7e7a95b 100644
--- a/src/operator/tensor/broadcast_reduce_prod_value.cu
+++ b/src/operator/tensor/broadcast_reduce_prod_value.cu
@@ -28,13 +28,15 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(prod)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow_op::product>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+ {"identity", "red::product{}", false});
NNVM_REGISTER_OP(_backward_prod)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::rdiv>);
NNVM_REGISTER_OP(nanprod)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow_op::nanprod>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>
+ {"identity", "red::nanprod{}", false});
NNVM_REGISTER_OP(_backward_nanprod)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::nanprod_grad>);
diff --git a/src/operator/tensor/broadcast_reduce_sum_value.cu b/src/operator/tensor/broadcast_reduce_sum_value.cu
index 2385d36..40a8ed8 100644
--- a/src/operator/tensor/broadcast_reduce_sum_value.cu
+++ b/src/operator/tensor/broadcast_reduce_sum_value.cu
@@ -28,19 +28,22 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(sum)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+ "red::sum{}", false});
NNVM_REGISTER_OP(_backward_sum)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseNone<gpu>);
NNVM_REGISTER_OP(mean)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum, true>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+ "red::sum{}", true});
NNVM_REGISTER_OP(_backward_mean)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseNone<gpu, true>);
NNVM_REGISTER_OP(nansum)
-.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow_op::nansum>);
+.set_attr<FCompute>("FCompute<gpu>", ReduceAxesRTCCompute<ReduceAxesParam, 0>{"identity",
+ "red::nansum{}", false});
NNVM_REGISTER_OP(_backward_nansum)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesBackwardUseInOut<gpu, mshadow_op::nansum_grad>);
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.cc b/src/operator/tensor/elemwise_binary_broadcast_op.cc
index 2f9832a..9a682dc 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.cc
@@ -376,10 +376,10 @@ void BinaryBroadcastRTCBackwardUseNone::operator()(const nnvm::NodeAttrs& attrs,
if (out.shape_.Size() != 0) {
broadcast::RTCReduce(ctx, lhs, req[0],
workspace, out,
- "red::sum", NDim, LOP);
+ "red::sum{}", NDim, LOP);
broadcast::RTCReduce(ctx, rhs, req[1],
workspace, out,
- "red::sum", NDim, ROP);
+ "red::sum{}", NDim, ROP);
} else {
using namespace common::cuda::rtc::util;
if (lhs.shape_.Size() != 0) {
@@ -425,21 +425,21 @@ void BinaryBroadcastRTCBackwardUseIn::operator()(const nnvm::NodeAttrs& attrs,
const TBlob rhs = inputs[2].reshape(new_rshape);
size_t workspace_size_l = broadcast::ReduceWorkspaceSize(
s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_,
- rhs.shape_, common::mshadow_type_info(outputs[0].type_flag_).size);
+ rhs.shape_);
size_t workspace_size_r = broadcast::ReduceWorkspaceSize(
s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_,
- rhs.shape_, common::mshadow_type_info(outputs[1].type_flag_).size);
+ rhs.shape_);
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_size), s);
if (req[0] != kNullOp) {
broadcast::RTCReduce(ctx, lgrad, req[0], workspace,
- ograd, lhs, rhs, "red::sum", NDim,
+ ograd, lhs, rhs, "red::sum{}", NDim,
"mul", LOP);
}
if (req[1] != kNullOp) {
broadcast::RTCReduce(ctx, rgrad, req[1], workspace,
- ograd, lhs, rhs, "red::sum", NDim,
+ ograd, lhs, rhs, "red::sum{}", NDim,
"mul", ROP);
}
});
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index a5bfdd7..b1700c7 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -629,9 +629,9 @@ inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
const TBlob lhs = inputs[1].reshape(new_lshape);
const TBlob rhs = inputs[2].reshape(new_rshape);
size_t workspace_size_l = ReduceWorkspaceSize(
- s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
+ s, lgrad.shape_, req[0], ograd.shape_, lhs.shape_, rhs.shape_);
size_t workspace_size_r = ReduceWorkspaceSize(
- s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_, sizeof(DType));
+ s, rgrad.shape_, req[1], ograd.shape_, lhs.shape_, rhs.shape_);
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 49a3ed2..24d6ca8 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -2046,8 +2046,13 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs,
inputs[0].type_flag_, inputs[0].dev_id());
std::vector<TBlob> newInputs = {iblob};
+#if !defined(__CUDACC__)
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false>(
ctx, newInputs, req, newOutputs, rshapes.first);
+#else
+ ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first,
+ "red::sum{}", nullptr, false);
+#endif
}
struct TileParam : public dmlc::Parameter<TileParam> {
@@ -2238,8 +2243,13 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs,
inputs[0].type_flag_, inputs[0].dev_id());
std::vector<TBlob> newInputs = {iblob};
+#if !defined(__CUDACC__)
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false>(
ctx, newInputs, req, newOutputs, rshapes.first);
+#else
+ ReduceAxesRTCComputeImpl(ctx, newInputs, req, newOutputs, rshapes.first,
+ "red::sum{}", nullptr, false);
+#endif
}
struct ReverseParam : public dmlc::Parameter<ReverseParam> {
diff --git a/src/operator/tensor/reduce_rtc.cc b/src/operator/tensor/reduce_rtc.cc
index 9e2d6d3..ac39f6f 100644
--- a/src/operator/tensor/reduce_rtc.cc
+++ b/src/operator/tensor/reduce_rtc.cc
@@ -49,12 +49,39 @@ struct reduce_kernel_params {
const char reduce_function_code[] = R"code(
#define FUNC OP(IType0::from(big[idx_big[u]]))
+using AType = typename AccType<InputType0>::type;
)code";
const char reduce_function_use_input_code[] = R"code(
#define FUNC OP1(IType0::from(big[idx_big[u]]), \
OP2(IType1::from(lhs[idx_lhs[u]]), \
IType2::from(rhs[idx_rhs[u]])))
+using AType = typename AccType<InputType0>::type;
+)code";
+
+const char reduce_function_index_code[] = R"code(
+#define FUNC AType(OP(IType0::from(big[idx_big[u]])), index)
+
+template <typename T>
+struct AccIndex {
+ index_t idx;
+ T num;
+
+ __device__ inline AccIndex<T>() {}
+ __device__ inline AccIndex<T>(const T& val, const index_t idx) : num(val), idx(idx) {}
+
+ __device__ inline operator index_t() const volatile {
+ return idx;
+ }
+
+ __device__ inline AccIndex<T>& operator=(const AccIndex<T>& other) {
+ idx = other.idx;
+ num = other.num;
+ return *this;
+ }
+};
+
+using AType = AccIndex<typename AccType<InputType0>::type>;
)code";
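
The AccIndex wrapper added above is what lets argmin/argmax reuse the generic reduction machinery: the accumulator carries both the running value and its index, the reducer compares on the value, and the implicit conversion to index_t hands the winning index to the output write. A host-side illustration of that idea, with illustrative names:

  #include <cassert>
  #include <cstdint>

  struct ValIdx {
    double value;
    std::int64_t idx;
    // Same trick as AccIndex::operator index_t(): converting the accumulator
    // to an integer yields the index of the reduced (here: maximal) element.
    operator std::int64_t() const { return idx; }
  };

  int main() {
    const double data[] = {3.0, 7.5, 1.0, 7.0};
    ValIdx best{data[0], 0};
    for (std::int64_t i = 1; i < 4; ++i)
      if (data[i] > best.value) best = ValIdx{data[i], i};
    const std::int64_t argmax = best;  // implicit conversion extracts the index
    assert(argmax == 1);
    return 0;
  }
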
const char reduce_kernel_code[] = R"code(
@@ -71,21 +98,107 @@ struct reduce_kernel_params {
index_t rhs_shape[util::MAX_DIM];
};
-__launch_bounds__(kRTCMaxThreadsPerBlock)
-__global__ void reduce_kernel(const int N, const int M, const bool addto,
- const InputType0* __restrict big,
- const InputType1* __restrict lhs,
- const InputType2* __restrict rhs,
- OutputType0 *small,
- const reduce_kernel_params params,
- const int Mnext) {
+inline __device__ AType reduce(const index_t idx, const int tidx,
+ const int tidy, const int N,
+ const index_t Mstart, const index_t Mend,
+ const InputType0* __restrict big,
+ const InputType1* __restrict lhs,
+ const InputType2* __restrict rhs,
+ const reduce_kernel_params& params) {
extern __shared__ char shTileChar[];
using IType0 = AccType<InputType0>;
using IType1 = AccType<InputType1>;
using IType2 = AccType<InputType2>;
using OType = AccType<OutputType0>;
- using AType = typename IType0::type;
AType* shTile = (AType*)(shTileChar);
+ const int bx = (do_transpose) ? blockDim.y : blockDim.x;
+ const int by = (do_transpose) ? blockDim.x : blockDim.y;
+ index_t coord[ndim];
+ util::unravel(idx, params.small_shape, coord);
+ index_t idx_big0, idx_lhs0, idx_rhs0;
+ idx_big0 = util::ravel(coord, params.big_shape);
+ if (use_input) {
+ idx_lhs0 = util::ravel(coord, params.lhs_shape0);
+ idx_rhs0 = util::ravel(coord, params.rhs_shape0);
+ }
+
+ AType val, residual;
+ REDUCER.SetInitValue(val, residual);
+ if (idx < N) {
+ for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) {
+ index_t idx_big[UNROLL];
+ index_t idx_lhs[UNROLL];
+ index_t idx_rhs[UNROLL];
+ #pragma unroll
+ for (int u=0;u < UNROLL;u++) {
+ idx_big[u] = idx_big0 + util::unravel_dot<ndim>(k + u*by, params.rshape,
+ params.rstride);
+ if (use_input) {
+ idx_lhs[u] = idx_lhs0 + util::unravel_dot<ndim>(k + u*by, params.lhs_shape,
+ params.lhs_stride);
+ idx_rhs[u] = idx_rhs0 + util::unravel_dot<ndim>(k + u*by, params.rhs_shape,
+ params.rhs_stride);
+ }
+ }
+ AType tmp[UNROLL];
+ #pragma unroll
+ for (int u=0;u < UNROLL;u++) {
+ if (k + u*by < Mend) {
+ const index_t index = k + u*by;
+ tmp[u] = FUNC;
+ }
+ }
+ #pragma unroll
+ for (int u=0;u < UNROLL;u++) {
+ if (k + u*by < Mend) REDUCER.Reduce(val, tmp[u], residual);
+ }
+ }
+ }
+
+ // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
+ if (by > 1) {
+ // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
+ const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
+ const int it0 = tidx + tidy*fbx;
+ shTile[it0 * 2] = val;
+ shTile[it0 * 2 + 1] = residual;
+ __syncthreads();
+ for (int t=1;t < by;t <<= 1) {
+ AType tmp, tmp_residual;
+ REDUCER.SetInitValue(tmp, tmp_residual);
+ if (tidy + t < by) {
+ tmp = shTile[(it0 + t*fbx) * 2];
+ tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
+ }
+ __syncthreads();
+ REDUCER.Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
+ __syncthreads();
+ }
+ if (idx < N && tidy == 0) {
+ REDUCER.Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
+ return shTile[tidx * 2];
+ } else {
+ return AType();
+ }
+ } else {
+ if (idx < N) {
+ REDUCER.Finalize(val, residual);
+ return val;
+ } else {
+ return AType();
+ }
+ }
+}
+
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void reduce_kernel_single(const int N, const int M,
+ const InputType0* __restrict big,
+ const InputType1* __restrict lhs,
+ const InputType2* __restrict rhs,
+ OutputType0 *small,
+ const reduce_kernel_params params,
+ const int Mnext) {
+ using OType = AccType<OutputType0>;
const int tid = threadIdx.x + threadIdx.y*blockDim.x;
const int bx = (do_transpose) ? blockDim.y : blockDim.x;
const int by = (do_transpose) ? blockDim.x : blockDim.y;
@@ -96,117 +209,74 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext);
const index_t Mend = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext);
for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
- int idx = idx0 + tidx;
- index_t coord[ndim];
- util::unravel(idx, params.small_shape, coord);
- index_t idx_big0, idx_lhs0, idx_rhs0;
- idx_big0 = util::ravel(coord, params.big_shape);
- if (use_input) {
- idx_lhs0 = util::ravel(coord, params.lhs_shape0);
- idx_rhs0 = util::ravel(coord, params.rhs_shape0);
- }
-
- AType val, residual;
- REDUCER::SetInitValue(val, residual);
- if (idx < N) {
- for (index_t k = tidy + Mstart; k < Mend; k += by*UNROLL) {
- index_t idx_big[UNROLL];
- index_t idx_lhs[UNROLL];
- index_t idx_rhs[UNROLL];
- #pragma unroll
- for (int u=0;u < UNROLL;u++) {
- idx_big[u] = idx_big0 + util::unravel_dot<ndim>(k + u*by, params.rshape,
- params.rstride);
- if (use_input) {
- idx_lhs[u] = idx_lhs0 + util::unravel_dot<ndim>(k + u*by, params.lhs_shape,
- params.lhs_stride);
- idx_rhs[u] = idx_rhs0 + util::unravel_dot<ndim>(k + u*by, params.rhs_shape,
- params.rhs_stride);
- }
- }
- typename OType::type tmp[UNROLL];
- #pragma unroll
- for (int u=0;u < UNROLL;u++) {
- if (k + u*by < Mend) {
- tmp[u] = FUNC;
- }
- }
- #pragma unroll
- for (int u=0;u < UNROLL;u++) {
- if (k + u*by < Mend) REDUCER::Reduce(val, tmp[u], residual);
- }
+ const index_t idx = idx0 + tidx;
+ AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params);
+ if (idx < N && (by == 1 || tidy == 0)) {
+ if (req == OpReqType::kAddTo) {
+ small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]),
+ static_cast<typename OType::type>(val)));
+ } else {
+ small[idx + m0 * N] = OType::to(val);
}
}
+ }
+ }
+}
- // Shared memory block bx * by. Reduction is along by. Final result is in tidy=0
- if (by > 1) {
- // Fix bx to avoid bank conflicts. Assumes warpSize number of banks
- const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
- const int it0 = tidx + tidy*fbx;
- shTile[it0 * 2] = val;
- shTile[it0 * 2 + 1] = residual;
- __syncthreads();
- for (int t=1;t < by;t <<= 1) {
- AType tmp, tmp_residual;
- REDUCER::SetInitValue(tmp, tmp_residual);
- if (tidy + t < by) {
- tmp = shTile[(it0 + t*fbx) * 2];
- tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
- }
- __syncthreads();
- REDUCER::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
- __syncthreads();
- }
- if (idx < N && tidy == 0) {
- REDUCER::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
- if (addto) {
- small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]),
- shTile[tidx * 2]));
- } else {
- small[idx + m0 * N] = OType::to(shTile[tidx * 2]);
- }
- }
- } else {
- if (idx < N) {
- REDUCER::Finalize(val, residual);
- if (addto) {
- small[idx + m0 * N] = OType::to(op::add(OType::from(small[idx + m0 * N]),
- val));
- } else {
- small[idx + m0 * N] = OType::to(val);
- }
- }
+__launch_bounds__(kRTCMaxThreadsPerBlock)
+__global__ void reduce_kernel_multi(const int N, const int M,
+ const InputType0* __restrict big,
+ const InputType1* __restrict lhs,
+ const InputType2* __restrict rhs,
+ AType *small,
+ const reduce_kernel_params params,
+ const int Mnext) {
+ const int tid = threadIdx.x + threadIdx.y*blockDim.x;
+ const int bx = (do_transpose) ? blockDim.y : blockDim.x;
+ const int by = (do_transpose) ? blockDim.x : blockDim.y;
+ const int tidx = (do_transpose) ? tid / by : threadIdx.x;
+ const int tidy = (do_transpose) ? tid % by : threadIdx.y;
+ for (int m0 = blockIdx.y; m0 < Mnext; m0 += gridDim.y) {
+ // This TB handles M range [Mstart, ...., Mend - 1]
+ const index_t Mstart = (index_t)((int64)M*(int64)m0/(int64)Mnext);
+ const index_t Mend = (index_t)((int64)M*(int64)(m0 + 1)/(int64)Mnext);
+ for (index_t idx0 = blockIdx.x*bx; idx0 < N; idx0 += bx*gridDim.x) {
+ const index_t idx = idx0 + tidx;
+ AType val = reduce(idx, tidx, tidy, N, Mstart, Mend, big, lhs, rhs, params);
+ if (idx < N && (by == 1 || tidy == 0)) {
+ small[idx + m0 * N] = val;
}
}
}
}
+
)code";
const char reduce_lines_kernel_code[] = R"code(
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void reduce_lines_kernel(const index_t N, const index_t M,
const index_t small_in_stride,
- const OutputType0* __restrict small_in,
+ const AType* __restrict small_in,
OutputType0 *small_out) {
using OType = AccType<OutputType0>;
for (index_t idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
- typename OType::type val, residual;
- REDUCER::SetInitValue(val, residual);
+ AType val, residual;
+ REDUCER.SetInitValue(val, residual);
for (int k = 0; k < M; k++) {
- REDUCER::Reduce(val,
- OType::from(reinterpret_cast<const OutputType0*>(small_in)[idx + k*small_in_stride]),
+ REDUCER.Reduce(val,
+ small_in[idx + k*small_in_stride],
residual);
}
if (idx < N) {
- REDUCER::Finalize(val, residual);
+ REDUCER.Finalize(val, residual);
if (req == OpReqType::kAddTo) {
- small_out[idx] = OType::to(op::add(OType::from(small_out[idx]), val));
+ small_out[idx] = OType::to(op::add(OType::from(small_out[idx]),
+ static_cast<typename OType::type>(val)));
} else {
small_out[idx] = OType::to(val);
}
}
-
}
}
)code";
@@ -215,14 +285,13 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
const TBlob& big, const Tensor<gpu, 1, char>& workspace,
const ReduceImplConfig& config, const int ndim,
const std::string &common_code, int dev_id,
- const TBlob *lhs = nullptr, const TBlob *rhs = nullptr) {
+ const TBlob *lhs = nullptr, const TBlob *rhs = nullptr,
+ const bool use_index = false) {
using namespace common::cuda::rtc;
void* small_dptr = small.dptr_;
- bool first_kernel_addto = addto;
if (config.Mnext > 1) {
// small_dptr[] is N*Mnext*sizeof(DType) bytes
small_dptr = workspace.dptr_;
- first_kernel_addto = false;
// Check that the workspace is contiguous
CHECK_EQ(workspace.CheckContiguous(), true);
// Check that we have enough storage
@@ -281,7 +350,6 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
std::vector<const void*> args;
args.emplace_back(&config.N);
args.emplace_back(&config.M);
- args.emplace_back(&first_kernel_addto);
args.emplace_back(&big.dptr_);
if (lhs != nullptr) {
args.emplace_back(&(lhs->dptr_));
@@ -295,10 +363,11 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
args.emplace_back(&config.Mnext);
const auto &function_code = (lhs == nullptr)
- ? reduce_function_code
+ ? (use_index ? reduce_function_index_code : reduce_function_code)
: reduce_function_use_input_code;
+ const auto& kernel_name = (config.Mnext > 1) ? "reduce_kernel_multi" : "reduce_kernel_single";
auto reduce_kernel_func = get_function(code + function_code,
- "reduce_kernel",
+ kernel_name,
reduce_kernel_code,
dev_id);
launch(reduce_kernel_func, config.kernel_1.gridDim,
@@ -313,7 +382,7 @@ void RTCReduceImpl(Stream<gpu> *s, const TBlob& small, const bool addto,
args.emplace_back(&small_dptr);
args.emplace_back(&small.dptr_);
- auto reduce_lines_kernel_func = get_function(code,
+ auto reduce_lines_kernel_func = get_function(code + function_code,
"reduce_lines_kernel",
reduce_lines_kernel_code,
dev_id);
@@ -348,9 +417,11 @@ __global__ void reduce_kernel_M1(const int N,
using IType1 = AccType<InputType1>;
using IType2 = AccType<InputType2>;
using OType = AccType<OutputType0>;
- for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
+ for (index_t index = threadIdx.x + blockIdx.x*blockDim.x;
+ index < N;
+ index += blockDim.x*gridDim.x) {
index_t coord[ndim];
- util::unravel(idx, params.small_shape, coord);
+ util::unravel(index, params.small_shape, coord);
index_t idx_big[1];
idx_big[0] = util::ravel(coord, params.big_shape);
index_t idx_lhs[1], idx_rhs[1];
@@ -358,16 +429,17 @@ __global__ void reduce_kernel_M1(const int N,
idx_lhs[0] = util::ravel(coord, params.lhs_shape);
idx_rhs[0] = util::ravel(coord, params.rhs_shape);
}
- typename OType::type val, residual;
- REDUCER::SetInitValue(val, residual);
+ AType val, residual;
+ REDUCER.SetInitValue(val, residual);
const int u = 0;
- REDUCER::Reduce(val, FUNC, residual);
- REDUCER::Finalize(val, residual);
+ REDUCER.Reduce(val, FUNC, residual);
+ REDUCER.Finalize(val, residual);
if (req == OpReqType::kAddTo) {
- const auto temp = op::add(val, OType::from(small[idx]));
- small[idx] = OType::to(temp);
+ const auto temp = op::add(static_cast<typename OType::type>(val),
+ OType::from(small[index]));
+ small[index] = OType::to(temp);
} else {
- small[idx] = OType::to(val);
+ small[index] = OType::to(static_cast<typename OType::type>(val));
}
}
}
@@ -376,7 +448,8 @@ __global__ void reduce_kernel_M1(const int N,
void RTCReduceM1Impl(Stream<gpu> *s, const TBlob &small, const TBlob &big,
const TBlob *lhs, const TBlob *rhs,
const ReduceImplConfig &config, const int ndim,
- const std::string &common_code, int dev_id) {
+ const std::string &common_code, int dev_id,
+ const bool use_index = false) {
using namespace common::cuda::rtc;
std::string code = common_code +
@@ -427,7 +500,7 @@ void RTCReduceM1Impl(Stream<gpu> *s, const TBlob &small, const TBlob &big,
args.emplace_back(¶m);
const auto &function_code = (lhs == nullptr)
- ? reduce_function_code
+ ? (use_index ? reduce_function_index_code : reduce_function_code)
: reduce_function_use_input_code;
auto reduce_kernel_M1_func = get_function(code + function_code,
"reduce_kernel_M1",
@@ -447,14 +520,12 @@ void RTCReduce(const OpContext& ctx,
const TBlob& big,
const std::string& reducer,
int ndim,
- const std::string& OP) {
+ const std::string& OP,
+ const bool use_index) {
using namespace mxnet::common::cuda::rtc;
if (req == kNullOp) return;
Stream<gpu> *s = ctx.get_stream<gpu>();
- size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size;
- size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size;
- size_t type_size = std::max(big_type_size, small_type_size);
- ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr, type_size);
+ ReduceImplConfig config(small.shape_, big.shape_, nullptr, nullptr);
std::string common_code = std::string("const OpReqType req = ") +
util::to_string(req) +
";\n"
@@ -469,10 +540,12 @@ void RTCReduce(const OpContext& ctx,
";\n";
if (config.M == 1) {
RTCReduceM1Impl(s, small, big, nullptr, nullptr, config,
- ndim, common_code, ctx.run_ctx.ctx.dev_id);
+ ndim, common_code, ctx.run_ctx.ctx.dev_id,
+ use_index);
} else {
RTCReduceImpl(s, small, req == kAddTo, big, workspace, config,
- ndim, common_code, ctx.run_ctx.ctx.dev_id);
+ ndim, common_code, ctx.run_ctx.ctx.dev_id,
+ nullptr, nullptr, use_index);
}
}
@@ -490,10 +563,7 @@ void RTCReduce(const OpContext& ctx,
using namespace mxnet::common::cuda::rtc;
if (req == kNullOp) return;
Stream<gpu> *s = ctx.get_stream<gpu>();
- size_t big_type_size = common::mshadow_type_info(big.type_flag_).acc_size;
- size_t small_type_size = common::mshadow_type_info(small.type_flag_).acc_size;
- size_t type_size = std::max(big_type_size, small_type_size);
- ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_, type_size);
+ ReduceImplConfig config(small.shape_, big.shape_, &lhs.shape_, &rhs.shape_);
std::string common_code = std::string("const OpReqType req = ") +
util::to_string(req) +
";\n"
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index ba8e327..1fc7b8e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3006,8 +3006,10 @@ def test_np_binary_funcs():
if isinstance(dtype, tuple):
assert len(dtype) == 2
ldtype, rdtype = dtype
- np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype)
- np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype)
+ npldtype = ldtype if ldtype != _np.float16 else _np.float32
+ nprdtype = rdtype if rdtype != _np.float16 else _np.float32
+ np_test_x1 = _np.random.uniform(low, high, lshape).astype(ldtype).astype(npldtype)
+ np_test_x2 = _np.random.uniform(low, high, rshape).astype(rdtype).astype(nprdtype)
mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ldtype)
mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rdtype)
for hybridize in [True, False]:
@@ -4372,7 +4374,7 @@ def test_np_argmin_argmax():
((3, 5, 7), 2, False),
((3, 5, 7, 9, 11), -3, False),
]
- dtypes = ['float16', 'float32', 'float64']
+ dtypes = ['float16', 'float32', 'float64', 'bool', 'int32']
ops = ['argmin', 'argmax']
class TestArgExtreme(HybridBlock):
@@ -4387,7 +4389,7 @@ def test_np_argmin_argmax():
for op_name in ops:
for shape, axis, throw_exception in workloads:
for dtype in dtypes:
- a = np.random.uniform(size=shape, dtype=dtype)
+ a = np.random.uniform(low=0, high=100, size=shape).astype(dtype)
if throw_exception:
# Cannot use assert_exception because sometimes the main thread
# proceeds to `assert False` before the exception is thrown