You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2019/07/21 23:26:11 UTC

[incubator-mxnet] branch master updated: Add omp parallel optimization for _contrib_BilinearResize2D (#15584)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a1a102  Add omp parallel optimization for _contrib_BilinearResize2D (#15584)
9a1a102 is described below

commit 9a1a1028529c91c83f275edb74b272f8ad3dcc3d
Author: Wuxun Zhang <wu...@intel.com>
AuthorDate: Mon Jul 22 07:25:30 2019 +0800

    Add omp parallel optimization for _contrib_BilinearResize2D (#15584)
    
    * Add omp parallel optimization for bilinear_resize op
    
    * retrigger CI
    
    * retrigger CI
    
    * trigger CI
---
 src/operator/contrib/bilinear_resize.cc | 146 ++++++++++++++++++--------------
 tests/python/gpu/test_operator_gpu.py   |  16 ++++
 tests/python/unittest/test_operator.py  |   2 +-
 3 files changed, 101 insertions(+), 63 deletions(-)

diff --git a/src/operator/contrib/bilinear_resize.cc b/src/operator/contrib/bilinear_resize.cc
index 441ea53..3463247 100644
--- a/src/operator/contrib/bilinear_resize.cc
+++ b/src/operator/contrib/bilinear_resize.cc
@@ -23,7 +23,6 @@
  * \author Hang Zhang
 */
 #include "bilinear_resize-inl.h"
-// #include "elemwise_op_common.h"
 #include "../elemwise_op_common.h"
 
 namespace mxnet {
@@ -44,56 +43,66 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
   int inputHeight = itensor.size(2);
   int inputWidth = itensor.size(3);
 
+  const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+
   DType *idata = itensor.dptr_;
   DType *odata = otensor.dptr_;
   channels = nbatch * channels;
+  const int input_elems_per_channel = inputWidth * inputHeight;
+  const int output_elems_per_channel = outputWidth * outputHeight;
+
   // special case: just copy
   if (inputHeight == outputHeight && inputWidth == outputWidth) {
-    for (int h2 = 0; h2 < outputHeight; ++h2) {
+#pragma omp parallel for num_threads(nthreads)
+    for (int index = 0; index < output_elems_per_channel; index++) {
+      const int h2 = index / outputWidth;
       const int h1 = h2;
-      for (int w2 = 0; w2 < outputWidth; ++w2) {
-        const int w1 = w2;
-        const DType* pos1 = &idata[h1 * inputWidth + w1];
-        DType* pos2 = &odata[h2 * outputWidth + w2];
-        for (int c = 0; c < channels; ++c) {
-          pos2[0] = pos1[0];
-          pos1 += inputWidth * inputHeight;
-          pos2 += outputWidth * outputHeight;
-        }
+      const int w2 = index % outputWidth;
+      const int w1 = w2;
+      const DType* pos1 = &idata[h1 * inputWidth + w1];
+      DType* pos2 = &odata[index];
+      for (int c = 0; c < channels; ++c) {
+        *pos2 = *pos1;
+        pos1 += input_elems_per_channel;
+        pos2 += output_elems_per_channel;
       }
     }
     return;
   }
+
   const float rheight =(outputHeight > 1) ? static_cast<float>(inputHeight - 1)/
                        (outputHeight - 1) : 0.f;
   const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1) /
                        (outputWidth - 1) : 0.f;
-  for (int h2 = 0; h2 < outputHeight; ++h2) {
+#pragma omp parallel for num_threads(nthreads)
+  for (int index = 0; index < output_elems_per_channel; index++) {
+    const int h2 = index / outputWidth;
+    const int w2 = index % outputWidth;
+
     const float h1r = rheight * h2;
     const int h1 = h1r;
     const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
     const DType h1lambda = h1r - h1;
     const DType h0lambda = (DType)1. - h1lambda;
-    for (int w2 = 0; w2 < outputWidth; ++w2) {
-      const float w1r = rwidth * w2;
-      const int w1 = w1r;
-      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
-      const DType w1lambda = w1r - w1;
-      const DType w0lambda = (DType)1. - w1lambda;
-      const DType* pos1 = &idata[h1 * inputWidth + w1];
-      DType* pos2 = &odata[h2 * outputWidth + w2];
-      for (int c = 0; c < channels; ++c) {
-        pos2[0] = h0lambda * (w0lambda * pos1[0]+ w1lambda * pos1[w1p])
-                  + h1lambda * (w0lambda * pos1[h1p * inputWidth]
-                  + w1lambda * pos1[h1p * inputWidth + w1p]);
-        pos1 += inputWidth * inputHeight;
-        pos2 += outputWidth * outputHeight;
-      }
+
+    const float w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+    const DType w1lambda = w1r - w1;
+    const DType w0lambda = (DType)1. - w1lambda;
+    const DType* pos1 = &idata[h1 * inputWidth + w1];
+    DType* pos2 = &odata[index];
+
+    for (int c = 0; c < channels; ++c) {
+      *pos2 = h0lambda * (w0lambda * (*pos1) + w1lambda * *(pos1 + w1p))
+                  + h1lambda * (w0lambda * *(pos1 + h1p * inputWidth)
+                  + w1lambda * *(pos1 + h1p * inputWidth + w1p));
+      pos1 += input_elems_per_channel;
+      pos2 += output_elems_per_channel;
     }
   }
 }
 
-
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                                               const std::vector<TBlob> &input,
@@ -109,23 +118,28 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
   int inputHeight = gradInput.size(2);
   int inputWidth = gradInput.size(3);
 
+  const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+
   DType *dataInput = gradInput.dptr_;
   DType *dataOutput = gradOutput.dptr_;
   channels = nbatch * channels;
+  const int input_elems_per_channel = inputWidth * inputHeight;
+  const int output_elems_per_channel = outputWidth * outputHeight;
 
   // special case: same-size matching grids
   if (inputHeight == outputHeight && inputWidth == outputWidth) {
-    for (int h2 = 0; h2 < outputHeight; ++h2) {
+#pragma omp parallel for num_threads(nthreads)
+    for (int index = 0; index < output_elems_per_channel; index++) {
+      const int h2 = index / outputWidth;
       const int h1 = h2;
-      for (int w2 = 0; w2 < outputWidth; ++w2) {
-        const int w1 = w2;
-        DType* pos1 = &dataInput[h1 * inputWidth + w1];
-        const DType* pos2 = &dataOutput[h2 * outputWidth + w2];
-        for (int c = 0; c < channels; ++c) {
-          pos1[0] += pos2[0];
-          pos1 += inputWidth * inputHeight;
-          pos2 += outputWidth * outputHeight;
-        }
+      const int w2 = index % outputWidth;
+      const int w1 = w2;
+      DType* pos1 = &dataInput[h1 * inputWidth + w1];
+      const DType* pos2 = &dataOutput[index];
+      for (int c = 0; c < channels; ++c) {
+        *pos1 += *pos2;
+        pos1 += input_elems_per_channel;
+        pos2 += output_elems_per_channel;
       }
     }
     return;
@@ -134,28 +148,36 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                        (outputHeight - 1) : 0.f;
   const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1)/
                        (outputWidth - 1) : 0.f;
-  for (int h2 = 0; h2 < outputHeight; ++h2) {
+
+#pragma omp parallel for num_threads(nthreads)
+  for (int index = 0; index < output_elems_per_channel; index++) {
+    const int h2 = index / outputWidth;
+    const int w2 = index % outputWidth;
+
     const float h1r = rheight * h2;
     const int h1 = h1r;
     const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
     const DType h1lambda = h1r - h1;
     const DType h0lambda = (DType)1. - h1lambda;
-    for (int w2 = 0; w2 < outputWidth; ++w2) {
-      const float w1r = rwidth * w2;
-      const int w1 = w1r;
-      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
-      const DType w1lambda = w1r - w1;
-      const DType w0lambda = (DType)1. - w1lambda;
-      DType* posInput = &dataInput[h1 * inputWidth + w1];
-      const DType* posOutput = &dataOutput[h2 * outputWidth + w2];
-      for (int c = 0; c < channels; ++c) {
-        posInput[0] += h0lambda * w0lambda * posOutput[0];
-        posInput[w1p] += h0lambda * w1lambda * posOutput[0];
-        posInput[h1p * inputWidth] += h1lambda * w0lambda * posOutput[0];
-        posInput[h1p * inputWidth + w1p] += h1lambda * w1lambda * posOutput[0];
-        posInput += inputWidth * inputHeight;
-        posOutput += outputWidth * outputHeight;
+
+    const float w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+    const DType w1lambda = w1r - w1;
+    const DType w0lambda = (DType)1. - w1lambda;
+
+    DType* posInput = &dataInput[h1 * inputWidth + w1];
+    const DType* posOutput = &dataOutput[index];
+    for (int c = 0; c < channels; ++c) {
+      #pragma omp critical
+      {
+        *posInput += h0lambda * w0lambda * (*posOutput);
+        *(posInput + w1p) += h0lambda * w1lambda * (*posOutput);
+        *(posInput + h1p * inputWidth) += h1lambda * w0lambda * (*posOutput);
+        *(posInput + h1p * inputWidth + w1p) += h1lambda * w1lambda * (*posOutput);
       }
+      posInput += input_elems_per_channel;
+      posOutput += output_elems_per_channel;
     }
   }
 
@@ -165,19 +187,19 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
     int inputWidthLike = gradInputLike.size(3);
     DType *dataInputLike = gradInputLike.dptr_;
     int channelsLike = nbatch * gradInputLike.size(1);
-    for (int h_like = 0; h_like < inputHeightLike; ++h_like) {
-      for (int w_like = 0; w_like < inputWidthLike; ++w_like) {
-        DType *posInput = &dataInputLike[h_like * inputWidthLike + w_like];
-        for (int c = 0; c < channelsLike; ++c) {
-          posInput[0] = 0;
-          posInput += inputWidthLike * inputHeightLike;
-        }
+
+    const int inputLike_elems_per_channel = inputHeightLike * inputWidthLike;
+#pragma omp parallel for num_threads(nthreads)
+    for (int index = 0; index < inputLike_elems_per_channel; index++) {
+      DType *posInput = &dataInputLike[index];
+      for (int c = 0; c < channelsLike; ++c) {
+        *posInput = 0;
+        posInput += inputLike_elems_per_channel;
       }
     }
   }
 }
 
-
 DMLC_REGISTER_PARAMETER(BilinearSampleParam);
 
 NNVM_REGISTER_OP(_contrib_BilinearResize2D)
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 5b4f81d..f9814ab 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1144,6 +1144,22 @@ def test_flatten_slice_after_conv():
 
 
 @with_seed()
+def test_bilinear_resize_op():
+    ctx_list = [{'ctx': mx.cpu(0), 'data': (2, 2, 20, 20), 'type_dict': {'data': np.float32}},
+                {'ctx': mx.gpu(0), 'data': (2, 2, 20, 20), 'type_dict': {'data': np.float32}}]
+
+    data = mx.sym.Variable('data')
+    sym = mx.sym.contrib.BilinearResize2D(data, height=10, width=5)
+    check_consistency(sym, ctx_list)
+
+    sym = mx.sym.contrib.BilinearResize2D(data, None, scale_height=2, scale_width=0.5, mode='odd_scale')
+    check_consistency(sym, ctx_list)
+
+    sym = mx.sym.contrib.BilinearResize2D(data, None, scale_height=0.5, scale_width=2, mode='to_even_up')
+    check_consistency(sym, ctx_list)
+
+
+@with_seed()
 def test_global_pooling():
     def test_1d_pooling(pool_type, p_value=2):
         data = (2, 3, 20)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 915a83f..d195ea9 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7649,7 +7649,7 @@ def test_bilinear_resize_op():
                 w1r = 1.0 * w2 * rwidth
                 w1 = int(np.floor(w1r))
                 w1lambda = w1r - w1
-                w1p = 1 if w1 < (inputHeight - 1) else 0
+                w1p = 1 if w1 < (inputWidth - 1) else 0
                 for b in range(batch):
                     for c in range(channel):
                         y[b][c][h2][w2] = (1-h1lambda)*((1-w1lambda)*x[b][c][h1][w1] + \