You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ke...@apache.org on 2019/01/14 22:33:36 UTC

[incubator-mxnet] branch master updated: Fix launch bounds in spatial transformer (#13188)

This is an automated email from the ASF dual-hosted git repository.

kellen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0faa5b7  Fix launch bounds in spatial transformer (#13188)
0faa5b7 is described below

commit 0faa5b72c8a94f53219b3da3144892ac415d5ea4
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Mon Jan 14 14:33:17 2019 -0800

    Fix launch bounds in spatial transformer (#13188)
    
    * Fix launch bounds in spatial transformer
    
    * Adding explanation in comment.
---
 src/operator/spatial_transformer.cu | 43 +++++++++++++++++++++++++++----------
 1 file changed, 32 insertions(+), 11 deletions(-)

diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu
index 33dbe3e..fd330bd 100644
--- a/src/operator/spatial_transformer.cu
+++ b/src/operator/spatial_transformer.cu
@@ -35,12 +35,23 @@ template<typename DType>
 __device__ bool between(DType value, int lowerBound, int upperBound) {
   return (value >= lowerBound && value <= upperBound);
 }
+
 template<typename DType>
-__global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
-                                              const int i_w, const DType* data,
-                                              const DType* grid, const int o_n,
-                                              const int o_c, const int o_h,
-                                              const int o_w, DType* out) {
+__global__ void
+/*
+ * In order not to generate code that uses too many
+ * registers (resulting in a "too many resources
+ * requested" error), we need to tell the compiler that
+ * we will be launching this kernel with
+ * cuda::kMaxThreadsPerBlock threads per block. Setting
+ * __launch_bounds__ ensures that such a configuration
+ * can always be launched.
+ */
+__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
+BilinearSamplingForwardKernel(const int i_c, const int i_h,
+                              const int i_w, const DType* data,
+                              const DType* grid, const int o_n,
+                              const int o_c, const int o_h,
+                              const int o_w, DType* out) {
   for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
        index < o_n * o_c * o_h * o_w;
        index += blockDim.x * gridDim.x * gridDim.y) {
@@ -77,13 +88,23 @@ __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
     }
 }
 
+/*
+ * In order not to generate code that uses too many
+ * registers (resulting in a "too many resources
+ * requested" error), we need to tell the compiler that
+ * we will be launching this kernel with
+ * cuda::kMaxThreadsPerBlock threads per block. Setting
+ * __launch_bounds__ ensures that such a configuration
+ * can always be launched.
+ */
 template<typename DType>
-__global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
-                                              const int i_w, const DType* grad,
-                                              const DType* data, const int o_n,
-                                              const int o_c, const int o_h,
-                                              const int o_w, DType* g_input,
-                                              DType* grid_src) {
+__global__ void
+__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
+BilinearSamplingBackwardKernel(const int i_c, const int i_h,
+                               const int i_w, const DType* grad,
+                               const DType* data, const int o_n,
+                               const int o_c, const int o_h,
+                               const int o_w, DType* g_input,
+                               DType* grid_src) {
   for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
        index < o_n * o_h * o_w;
        index += blockDim.x * gridDim.x * gridDim.y) {