You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/02/06 01:04:33 UTC

[incubator-mxnet] branch master updated: Fix performance regression in normalize operator (#14055)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new df4a4fd  Fix performance regression in normalize operator (#14055)
df4a4fd is described below

commit df4a4fdbfc2a0577833641077a39e0237cdcf4af
Author: Sandeep Krishnamurthy <sa...@gmail.com>
AuthorDate: Tue Feb 5 17:04:07 2019 -0800

    Fix performance regression in normalize operator (#14055)
    
    * Parallelize over channels in the forward pass
    
    * Parallelize over channels in the normalize backward pass
    
    * Fix lint issues
    
    * Trying to fix CI build failure on GPU
    
    * Fix failing GPU test on CI. Do not pass normalize param as-is to the GPU kernel
    
    * Fix to_tensor tests
    
    * Pass mean and std_dev to the kernel as native types
    
    * Fix CI failure. Do not pass mean and std as vectors to the kernel
---
 src/operator/image/image_random-inl.h     | 136 +++++++++++++++++++++---------
 tests/python/gpu/test_gluon_transforms.py |  33 ++------
 2 files changed, 106 insertions(+), 63 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index c9dd85a..4480163 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -217,37 +217,50 @@ inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs,
 template<int req>
 struct normalize_forward {
     template<typename DType>
-    MSHADOW_XINLINE static void Map(int j, DType* out_data, const DType* in_data,
-                                    const int i, const int length, const int step,
-                                    const DType mean, const DType std_dev) {
-        KERNEL_ASSIGN(out_data[step + i*length + j], req,
-                      (in_data[step + i*length + j] - mean) / std_dev);
+    MSHADOW_XINLINE static void Map(uint32_t c, DType* out_data, const DType* in_data,
+                                    const float mean_d0, const float mean_d1, const float mean_d2,
+                                    const float std_d0, const float std_d1, const float std_d2,
+                                    const int length, const int step) {
+        float mean, std;
+        switch (c) {
+          case 0 : mean = mean_d0;
+                   std = std_d0;
+                   break;
+          case 1 : mean = mean_d1;
+                   std = std_d1;
+                   break;
+          case 2 : mean = mean_d2;
+                   std = std_d2;
+                   break;
+        }
+        #pragma omp parallel for
+        for (int i = 0; i < length; ++i) {
+          KERNEL_ASSIGN(out_data[step + c*length + i], req,
+                        (in_data[step + c*length + i] - mean) / std);
+        }
     }
 };
 
 template<typename xpu>
 void NormalizeImpl(const OpContext &ctx,
-                          const std::vector<TBlob> &inputs,
-                          const std::vector<TBlob> &outputs,
-                          const std::vector<OpReqType> &req,
-                          const NormalizeParam &param,
-                          const int length,
-                          const uint32_t channel,
-                          const int step = 0) {
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<TBlob> &outputs,
+                   const std::vector<OpReqType> &req,
+                   const float mean_d0, const float mean_d1,
+                   const float mean_d2, const float std_d0,
+                   const float std_d1, const float std_d2,
+                   const int length,
+                   const uint32_t channel,
+                   const int step = 0) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
 
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
         DType* input = inputs[0].dptr<DType>();
         DType* output = outputs[0].dptr<DType>();
-
-        for (uint32_t i = 0; i < channel; ++i) {
-            DType mean = param.mean[param.mean.ndim() > i ? i : 0];
-            DType std_dev = param.std[param.std.ndim() > i ? i : 0];
-            mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
-                s, length, output, input,
-                i, length, step, mean, std_dev);
-        }
+        mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
+            s, channel, output, input, mean_d0, mean_d1, mean_d2,
+            std_d0, std_d1, std_d2, length, step);
       });
     });
 }
@@ -264,11 +277,35 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
 
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
 
+  // Note: We need mean and std_dev in the kernel.
+  // It is costly (device copy) to pass it as vector, for gpu kernel.
+  // Hence, passing it as below for performance.
+  float mean_d0, mean_d1, mean_d2;
+  float std_d0, std_d1, std_d2;
+
+  // Mean and Std can be 1 or 3 D only.
+  if (param.mean.ndim() == 1) {
+    mean_d0 = mean_d1 = mean_d2 = param.mean[0];
+  } else {
+    mean_d0 = param.mean[0];
+    mean_d1 = param.mean[1];
+    mean_d2 = param.mean[2];
+  }
+
+  if (param.std.ndim() == 1) {
+    std_d0 = std_d1 = std_d2 = param.std[0];
+  } else {
+    std_d0 = param.std[0];
+    std_d1 = param.std[1];
+    std_d2 = param.std[2];
+  }
+
   // 3D input (c, h, w)
   if (inputs[0].ndim() == 3) {
     const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
     const uint32_t channel = inputs[0].shape_[0];
-    NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
+    NormalizeImpl<xpu>(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2,
+                       std_d0, std_d1, std_d2, length, channel);
   } else if (inputs[0].ndim() == 4) {
     // 4D input (n, c, h, w)
     const int batch_size = inputs[0].shape_[0];
@@ -278,7 +315,8 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
 
     #pragma omp parallel for
     for (auto n = 0; n < batch_size; ++n) {
-      NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
+      NormalizeImpl<xpu>(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2,
+                       std_d0, std_d1, std_d2, length, channel, n*step);
     }
   }
 }
@@ -287,12 +325,25 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
 template<int req>
 struct normalize_backward {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int j, DType* in_grad, const DType* out_grad,
-                                  const int i, const int length,
-                                  const int step, const DType std_dev) {
+  MSHADOW_XINLINE static void Map(uint32_t c, DType* in_grad, const DType* out_grad,
+                                  const float std_d0, const float std_d1, const float std_d2,
+                                  const int length, const int step) {
     // d/dx{(x - mean) / std_dev} => (1 / std_dev)
-    KERNEL_ASSIGN(in_grad[step + i*length + j], req,
-                  out_grad[step + i*length + j] * (1.0 / std_dev));
+    float std_dev;
+    switch (c) {
+        case 0 : std_dev = std_d0;
+                 break;
+        case 1 : std_dev = std_d1;
+                 break;
+        case 2 : std_dev = std_d2;
+                 break;
+    }
+
+    #pragma omp parallel for
+    for (int i = 0; i < length; ++i) {
+      KERNEL_ASSIGN(in_grad[step + c*length + i], req,
+                    out_grad[step + c*length + i] * (1.0 / std_dev));
+    }
   }
 };
 
@@ -301,21 +352,18 @@ void NormalizeBackwardImpl(const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
                            const std::vector<TBlob> &outputs,
                            const std::vector<OpReqType> &req,
-                           const NormalizeParam &param,
+                           const float std_d0, const float std_d1, const float std_d2,
                            const int length,
                            const uint32_t channel,
                            const int step = 0) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    const TBlob& out_grad = inputs[0];
-    const TBlob& in_grad = outputs[0];
+
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        for (uint32_t i = 0; i < channel; ++i) {
-            DType std_dev = param.std[param.std.ndim() > i ? i : 0];
-            mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
-                s, length, in_grad.dptr<DType>(), out_grad.dptr<DType>(),
-                i, length, step, std_dev);
-        }
+        DType* out_grad = inputs[0].dptr<DType>();
+        DType* in_grad = outputs[0].dptr<DType>();
+        mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
+            s, channel, in_grad, out_grad, std_d0, std_d1, std_d2, length, step);
       });
     });
 }
@@ -331,6 +379,16 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(req.size(), 1U);
 
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
+  float std_d0, std_d1, std_d2;
+
+  // Std can be 1 or 3 D only
+  if (param.std.ndim() == 1) {
+    std_d0 = std_d1 = std_d2 = param.std[0];
+  } else {
+    std_d0 = param.std[0];
+    std_d1 = param.std[1];
+    std_d2 = param.std[2];
+  }
 
   // Note: inputs[0] is out_grad
   const TBlob& in_data = inputs[1];
@@ -339,7 +397,7 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
   if (in_data.ndim() == 3) {
     const int length = in_data.shape_[1] * in_data.shape_[2];
     const uint32_t channel = in_data.shape_[0];
-    NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
+    NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, std_d0, std_d1, std_d2, length, channel);
   } else if (in_data.ndim() == 4) {
     // 4D input (n, c, h, w)
     const int batch_size = in_data.shape_[0];
@@ -349,7 +407,9 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
 
     #pragma omp parallel for
     for (auto n = 0; n < batch_size; ++n) {
-      NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
+      NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req,
+                                 std_d0, std_d1, std_d2, length,
+                                 channel, n*step);
     }
   }
 }
diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py
index 3927d4c..23b34d3 100644
--- a/tests/python/gpu/test_gluon_transforms.py
+++ b/tests/python/gpu/test_gluon_transforms.py
@@ -80,32 +80,15 @@ def test_to_tensor():
         data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
 
     # 4D Input
-    data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
-    out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
-    data_expected_4d = data_in_4d.asnumpy()
-    data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
-    data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
-    data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
-    data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
-    data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
-    data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
-    assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())
-
-    # Default normalize values i.e., mean=0, std=1
-    data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300))
-    out_nd_3d_def = transforms.Normalize()(data_in_3d_def)
-    data_expected_3d_def = data_in_3d_def.asnumpy()
-    assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy())
-
-    # Invalid Input - Neither 3D or 4D input
-    invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
-    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
-    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+    data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
+    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+    assert_almost_equal(out_nd.asnumpy(), np.transpose(
+                        data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
 
-    # Invalid Input - Channel neither 1 or 3
-    invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
-    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
-    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+    # Invalid Input
+    invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
+    transformer = transforms.ToTensor()
+    assertRaises(MXNetError, transformer, invalid_data_in)
 
 @with_seed()
 def test_resize():