You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/11 17:20:43 UTC

[GitHub] anirudh2290 closed pull request #12386: [MXNET-810] Add support for more req patterns for bilinear sampler backward

anirudh2290 closed pull request #12386: [MXNET-810] Add support for more req patterns for bilinear sampler backward
URL: https://github.com/apache/incubator-mxnet/pull/12386
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it after the merge:

diff --git a/src/operator/bilinear_sampler-inl.h b/src/operator/bilinear_sampler-inl.h
index e0b4db7b367..499d2339620 100644
--- a/src/operator/bilinear_sampler-inl.h
+++ b/src/operator/bilinear_sampler-inl.h
@@ -95,19 +95,16 @@ class BilinearSamplerOp : public Operator {
     Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s);
     Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s);
     Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s);
-    if (req[bs::kData] != kNullOp && req[bs::kGrid] != kNullOp) {
+    if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
+      return;
+    } else {
       if (req[bs::kData] == kWriteTo) {
         gdata = scalar<DType>(0.0f);
       }
       if (req[bs::kGrid] == kWriteTo) {
         ggrid = scalar<DType>(0.0f);
       }
-      BilinearSamplerBackward(gdata, ggrid, grad, data, grid);
-    } else if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
-      return;
-    } else {
-      LOG(FATAL) << "Have not implemented the data req combinations! gdata_req="
-                 << req[bs::kData] << " ggrid_req=" << req[bs::kGrid];
+      BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]);
     }
   }
 
diff --git a/src/operator/bilinear_sampler.cc b/src/operator/bilinear_sampler.cc
index 3365d98bb4d..a3b7d576424 100644
--- a/src/operator/bilinear_sampler.cc
+++ b/src/operator/bilinear_sampler.cc
@@ -78,10 +78,12 @@ inline void BilinearSamplerForward(const Tensor<cpu, 4, DType> &output,
 
 template<typename DType>
 inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
-                                     const Tensor<cpu, 4, DType> &ggrid,
-                                     const Tensor<cpu, 4, DType> &output_grad,
-                                     const Tensor<cpu, 4, DType> &input_data,
-                                     const Tensor<cpu, 4, DType> &grid) {
+                                    const Tensor<cpu, 4, DType> &ggrid,
+                                    const Tensor<cpu, 4, DType> &output_grad,
+                                    const Tensor<cpu, 4, DType> &input_data,
+                                    const Tensor<cpu, 4, DType> &grid,
+                                    const mxnet::OpReqType data_req,
+                                    const mxnet::OpReqType grid_req) {
   DType *g_input = gdata.dptr_;
   DType *grad_grid = ggrid.dptr_;
   const DType *grid_src = grid.dptr_;
@@ -104,8 +106,7 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
           DType top_left_x_w = 1.0 - (x_real - top_left_x);
           for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
             index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
-            int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
-                                  + top_left_x;
+            int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
             // calc 4 vertex value in input data
             DType top_left_v = 0;
             DType top_right_v = 0;
@@ -113,22 +114,30 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
             DType bottom_right_v = 0;
             // calc input grad
             if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-              *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
+              if (data_req != mxnet::kNullOp) {
+                *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
+              }
               top_left_v = *(data + data_index);
             }
             if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-              *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
-                                              * (1.0 - top_left_x_w);
+              if (data_req != mxnet::kNullOp) {
+                *(g_input + data_index + 1) +=
+                  *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w);
+              }
               top_right_v = *(data + data_index + 1);
             }
             if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-              *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
-                                              * top_left_x_w;
+              if (data_req != mxnet::kNullOp) {
+                *(g_input + data_index+ i_w) +=
+                  *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w;
+              }
               bottom_left_v = *(data + data_index + i_w);
             }
             if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-              *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
-                                                  * (1.0 - top_left_x_w);
+              if (data_req != mxnet::kNullOp) {
+                *(g_input + data_index+ i_w + 1) +=
+                  *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
+              }
               bottom_right_v = *(data + data_index + i_w + 1);
             }
             // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
@@ -139,9 +148,11 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
                               (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                               * top_left_y_w);
           }
-          // calc grad of grid
-          *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
-          *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+          if (grid_req != mxnet::kNullOp) {
+            // calc grad of grid
+            *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
+            *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+          }
         }
       }
     }
diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu
index e1f205258a2..2e6be3e1ef3 100644
--- a/src/operator/bilinear_sampler.cu
+++ b/src/operator/bilinear_sampler.cu
@@ -79,7 +79,7 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h,
   }
 }
 
-template<typename DType>
+template<typename DType, int Req1, int Req2>
 __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
                                               const int i_w, const DType* grad,
                                               const DType* data, const int o_n,
@@ -114,22 +114,30 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
       DType bottom_right_v = 0;
       // calc input grad
       if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
+        if (Req1 != mxnet::kNullOp) {
+          atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
+        }
         top_left_v = *(data + data_index);
       }
       if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w
-                                        * (1.0 - top_left_x_w));
+        if (Req1 != mxnet::kNullOp) {
+          atomicAdd(&g_input[data_index + 1],
+                    *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w));
+        }
         top_right_v = *(data + data_index + 1);
       }
       if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w)
-                                        * top_left_x_w);
+        if (Req1 != mxnet::kNullOp) {
+          atomicAdd(&g_input[data_index+ i_w],
+                    *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w);
+        }
         bottom_left_v = *(data + data_index + i_w);
       }
       if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w)
-                                            * (1.0 - top_left_x_w));
+        if (Req1 != mxnet::kNullOp) {
+          atomicAdd(&g_input[data_index+ i_w + 1],
+                    *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w));
+        }
         bottom_right_v = *(data + data_index + i_w + 1);
       }
       // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
@@ -140,9 +148,11 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
                         (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
                         * top_left_y_w);
     }
-    // calc grad of grid
-    *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
-    *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+    if (Req2 != mxnet::kNullOp) {
+      // calc grad of grid
+      *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
+      *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+    }
   }
 }
 }  // namespace cuda
@@ -174,10 +184,13 @@ inline void BilinearSamplerForward(const Tensor<gpu, 4, DType> &output,
 
 template<typename DType>
 inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
-                                     const Tensor<gpu, 4, DType> &ggrid,
-                                     const Tensor<gpu, 4, DType> &output_grad,
-                                     const Tensor<gpu, 4, DType> &input_data,
-                                     const Tensor<gpu, 4, DType> &grid) {
+                                    const Tensor<gpu, 4, DType> &ggrid,
+                                    const Tensor<gpu, 4, DType> &output_grad,
+                                    const Tensor<gpu, 4, DType> &input_data,
+                                    const Tensor<gpu, 4, DType> &grid,
+                                    const mxnet::OpReqType data_req,
+                                    const mxnet::OpReqType grid_req) {
+  using namespace mxnet;
   DType *g_input = input_grad.dptr_;
   DType *grad_grid = ggrid.dptr_;
   const DType *grid_src = grid.dptr_;
@@ -196,8 +209,13 @@ inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
   dim3 threads_per_block(kMaxThreadsPerBlock);
   CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler backward");
   cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_);
-  cuda::BilinearSamplerBackwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >(
-    i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
+  MXNET_REQ_TYPE_SWITCH(data_req, Req1, {
+    MXNET_REQ_TYPE_SWITCH(grid_req, Req2, {
+      cuda::BilinearSamplerBackwardKernel<DType, Req1, Req2>
+      <<<num_blocks, threads_per_block, 0, stream >>>(
+        i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
+    });
+  });
   // post kernel check
   cudaError err = cudaPeekAtLastError();
   CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err);
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index f11a497c564..e77569671eb 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -111,6 +111,33 @@ inline int get_num_threads<cpu>(const int N) {
   }
 
 
+/*! \brief Switch on runtime req, running __VA_ARGS__ with ReqType bound as a const OpReqType */
+#define MXNET_REQ_TYPE_SWITCH(req, ReqType, ...)  \
+  switch (req) {  /* uses unqualified k* names */   \
+  case kNullOp:  /* body still runs */              \
+    {                                               \
+      const OpReqType ReqType = kNullOp;            \
+      {__VA_ARGS__}                                 \
+    }                                               \
+    break;                                          \
+  case kWriteInplace:  /* runs as kWriteTo */       \
+  case kWriteTo:                                    \
+    {                                               \
+      const OpReqType ReqType = kWriteTo;           \
+      {__VA_ARGS__}                                 \
+    }                                               \
+    break;                                          \
+  case kAddTo:                                      \
+    {                                               \
+      const OpReqType ReqType = kAddTo;             \
+      {__VA_ARGS__}                                 \
+    }                                               \
+    break;                                          \
+  default:  /* any other value: body skipped */     \
+    break;                                          \
+  }
+
+
 #define MXNET_NDIM_SWITCH(NDim, ndim, ...)         \
   if (NDim == 0) {                                 \
   } else if (NDim == 1) {                          \
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 1fc2c8e922d..d201a2e09c6 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1945,7 +1945,7 @@ def test_bilinear_sampler_versions():
             exe.arg_dict['data'][:] = test_data
             exe.arg_dict['grid'][:] = test_grid
             exe.forward(is_train=True)
-            assert_almost_equal(exe_list[0].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5)
+            assert_almost_equal(exe_list[ref_idx].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5)
 
         out_grad = np.random.uniform(low=-0.01, high=0.01,size=data_shape[:2] + grid_shape[2:]).astype(np.float32)
         for exe in exe_list:
@@ -1975,6 +1975,22 @@ def test_bilinear_sampler_versions():
         assert_almost_equal(exe_list[ref_idx].grad_dict['data'].asnumpy(), data_grad + data_initial_grad, rtol=1e-3, atol=1e-5)
         assert_almost_equal(exe_list[ref_idx].grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grad, rtol=1e-3, atol=1e-5)
 
+        for req_dict in [{'data' : 'null', 'grid' : 'write'}, {'data' : 'write', 'grid' : 'null'}]:
+            # Mixture of kWriteTo and kNullOp
+            exe_cpu_mix = sym1.simple_bind(data=data_shape, grid=grid_shape, ctx=mx.cpu(), grad_req=req_dict)
+            exe_gpu_mix = sym2.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict)
+            exe_cudnn_mix = sym3.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict)
+            exe_list = [exe_cpu_mix, exe_gpu_mix, exe_cudnn_mix]
+            for exe in exe_list:
+                exe.arg_dict['data'][:] = test_data
+                exe.arg_dict['grid'][:] = test_grid
+                exe.forward(is_train=True)
+                exe.backward(mx.nd.array(out_grad))
+                if req_dict['data'] is 'write':
+                    assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5)
+                if req_dict['grid'] is 'write':
+                    assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5)
+
 
 def test_context_num_gpus():
     # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services