You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/01/29 06:09:03 UTC

[incubator-mxnet] branch master updated: Image normalize operator - GPU support, 3D/4D inputs (#13802)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3a1a80a  Image normalize operator - GPU support, 3D/4D inputs (#13802)
3a1a80a is described below

commit 3a1a80a537ac2fc8878425a98741e630ce3ba274
Author: Sandeep Krishnamurthy <sa...@gmail.com>
AuthorDate: Mon Jan 28 22:08:38 2019 -0800

    Image normalize operator - GPU support, 3D/4D inputs (#13802)
    
    * CPU version of normalize operator is working and unit test added
    
    * Add GPU implementation and tests
    
    * Working GPU normalize transforms
    
    * Add default values, fix imports, fix documentation
    
    * Add backward implementation for image normalize
    
    * Add tests for backward pass
    
    * Move operators back to their original files
    
    * Add review comments
    
    * Add 4D example
    
    * Make infer type generic
    
    * Fix inline function build error
    
    * Make functions inline to avoid multiple-definition conflicts across .cc and .cu files
    
    * Fix build errors
    
    * Fix failing GPU tests
---
 python/mxnet/gluon/data/vision/transforms.py    |  25 ++-
 src/operator/image/image_random-inl.h           | 222 ++++++++++++++++++++----
 src/operator/image/image_random.cc              |  87 +++++++++-
 src/operator/image/image_random.cu              |  40 +++++
 tests/python/gpu/test_gluon_transforms.py       |  72 ++++++++
 tests/python/unittest/test_gluon_data_vision.py |  41 ++++-
 tests/python/unittest/test_operator.py          |  67 +++++++
 7 files changed, 497 insertions(+), 57 deletions(-)

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 1750769..2f557f5 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -135,7 +135,7 @@ class ToTensor(HybridBlock):
 
 
 class Normalize(HybridBlock):
-    """Normalize an tensor of shape (C x H x W) with mean and
+    """Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
     standard deviation.
 
     Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels,
@@ -154,12 +154,31 @@ class Normalize(HybridBlock):
 
 
     Inputs:
-        - **data**: input tensor with (C x H x W) shape.
+        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.
 
     Outputs:
         - **out**: output tensor with the shape as `data`.
+
+    Examples
+    --------
+    >>> transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    >>> image = mx.nd.random.uniform(0, 1, (3, 4, 2))
+    >>> transformer(image)
+    [[[ 0.18293785  0.19761486]
+      [ 0.23839645  0.28142193]
+      [ 0.20092112  0.28598186]
+      [ 0.18162774  0.28241724]]
+     [[-0.2881726  -0.18821815]
+      [-0.17705294 -0.30780914]
+      [-0.2812064  -0.3512327 ]
+      [-0.05411351 -0.4716435 ]]
+     [[-1.0363373  -1.7273437 ]
+      [-1.6165586  -1.5223348 ]
+      [-1.208275   -1.1878313 ]
+      [-1.4711051  -1.5200229 ]]]
+    <NDArray 3x4x2 @cpu(0)>
     """
-    def __init__(self, mean, std):
+    def __init__(self, mean=0.0, std=1.0):
         super(Normalize, self).__init__()
         self._mean = mean
         self._std = std
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index c64ed28..74807b9 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -31,7 +31,6 @@
 #include <vector>
 #include <cmath>
 #include <limits>
-#include <algorithm>
 #include <utility>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
@@ -62,7 +61,7 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
   return (*in_attrs)[0] != -1;
 }
 
-void ToTensor(const nnvm::NodeAttrs &attrs,
+inline void ToTensor(const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<TBlob> &inputs,
                      const std::vector<OpReqType> &req,
@@ -85,32 +84,53 @@ void ToTensor(const nnvm::NodeAttrs &attrs,
   });
 }
 
+// Normalize Operator
+// Parameter registration for image Normalize operator
 struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
   nnvm::Tuple<float> mean;
   nnvm::Tuple<float> std;
+
   DMLC_DECLARE_PARAMETER(NormalizeParam) {
     DMLC_DECLARE_FIELD(mean)
-    .describe("Sequence of mean for each channel.");
+    .set_default(nnvm::Tuple<float> {0.0f, 0.0f, 0.0f, 0.0f})
+    .describe("Sequence of means for each channel. "
+              "Default value is 0.");
     DMLC_DECLARE_FIELD(std)
-    .describe("Sequence of standard deviations for each channel.");
+    .set_default(nnvm::Tuple<float> {1.0f, 1.0f, 1.0f, 1.0f})
+    .describe("Sequence of standard deviations for each channel. "
+              "Default value is 1.");
   }
 };
 
-inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
+// Shape and Type inference for image Normalize operator
+
+// Shape inference
+inline bool NormalizeOpShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape> *in_attrs,
                           std::vector<TShape> *out_attrs) {
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
+
   const auto& dshape = (*in_attrs)[0];
   if (!dshape.ndim()) return false;
 
-  CHECK_EQ(dshape.ndim(), 3)
-      << "Input tensor must have shape (channels, height, width), but got "
-      << dshape;
-  auto nchannels = dshape[0];
-  CHECK(nchannels == 3 || nchannels == 1)
+  CHECK((dshape.ndim() == 3) || (dshape.ndim() == 4))
+      << "Input tensor must have shape (channels, height, width), or "
+      << "(N, channels, height, width), but got " << dshape;
+
+  uint32_t nchannels;
+  if (dshape.ndim() == 3) {
+    nchannels = dshape[0];
+    CHECK(nchannels == 3 || nchannels == 1)
       << "The first dimension of input tensor must be the channel dimension with "
       << "either 1 or 3 elements, but got input with shape " << dshape;
-  CHECK(param.mean.ndim() == 1 || param.mean.ndim() == nchannels)
+  } else if (dshape.ndim() == 4) {
+    nchannels = dshape[1];
+    CHECK(nchannels == 3 || nchannels == 1)
+      << "The second dimension of input tensor must be the channel dimension with "
+      << "either 1 or 3 elements, but got input with shape " << dshape;
+  }
+
+  CHECK((param.mean.ndim() == 1) || (param.mean.ndim() == nchannels))
       << "Invalid mean for input with shape " << dshape
       << ". mean must have either 1 or " << nchannels
       << " elements, but got " << param.mean;
@@ -123,28 +143,156 @@ inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-void Normalize(const nnvm::NodeAttrs &attrs,
+// Type Inference
+inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  return out_attrs->at(0) != -1;
+}
+
+template<int req>
+struct normalize_forward {
+    template<typename DType>
+    MSHADOW_XINLINE static void Map(int j, DType* out_data, const DType* in_data,
+                                    const int i, const int length, const int step,
+                                    const DType mean, const DType std_dev) {
+        KERNEL_ASSIGN(out_data[step + i*length + j], req,
+                      (in_data[step + i*length + j] - mean) / std_dev);
+    }
+};
+
+template<typename xpu>
+void NormalizeImpl(const OpContext &ctx,
+                          const std::vector<TBlob> &inputs,
+                          const std::vector<TBlob> &outputs,
+                          const std::vector<OpReqType> &req,
+                          const NormalizeParam &param,
+                          const int length,
+                          const uint32_t channel,
+                          const int step = 0) {
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        DType* input = inputs[0].dptr<DType>();
+        DType* output = outputs[0].dptr<DType>();
+
+        for (uint32_t i = 0; i < channel; ++i) {
+            DType mean = param.mean[param.mean.ndim() > i ? i : 0];
+            DType std_dev = param.std[param.std.ndim() > i ? i : 0];
+            mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
+                s, length, output, input,
+                i, length, step, mean, std_dev);
+        }
+      });
+    });
+}
+
+template<typename xpu>
+void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
 
-  int nchannels = inputs[0].shape_[0];
-  int length = inputs[0].shape_[1] * inputs[0].shape_[2];
+  // 3D input (c, h, w)
+  if (inputs[0].ndim() == 3) {
+    const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
+    const uint32_t channel = inputs[0].shape_[0];
+    NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
+  } else if (inputs[0].ndim() == 4) {
+    // 4D input (n, c, h, w)
+    const int batch_size = inputs[0].shape_[0];
+    const int length = inputs[0].shape_[2] * inputs[0].shape_[3];
+    const uint32_t channel = inputs[0].shape_[1];
+    const int step = channel * length;
+
+    #pragma omp parallel for
+    for (auto n = 0; n < batch_size; ++n) {
+      NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
+    }
+  }
+}
 
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    DType* input = inputs[0].dptr<DType>();
-    DType* output = outputs[0].dptr<DType>();
+// Backward function
+template<int req>
+struct normalize_backward {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int j, DType* in_grad, const DType* out_grad,
+                                  const int i, const int length,
+                                  const int step, const DType std_dev) {
+    // d/dx{(x - mean) / std_dev} => (1 / std_dev)
+    KERNEL_ASSIGN(in_grad[step + i*length + j], req,
+                  out_grad[step + i*length + j] * (1.0 / std_dev));
+  }
+};
 
-    for (int i = 0; i < nchannels; ++i) {
-      DType mean = param.mean[param.mean.ndim() > 1 ? i : 0];
-      DType std = param.std[param.std.ndim() > 1 ? i : 0];
-      for (int j = 0; j < length; ++j) {
-        output[i*length + j] = (input[i*length + j] - mean) / std;
-      }
+template<typename xpu>
+void NormalizeBackwardImpl(const OpContext &ctx,
+                           const std::vector<TBlob> &inputs,
+                           const std::vector<TBlob> &outputs,
+                           const std::vector<OpReqType> &req,
+                           const NormalizeParam &param,
+                           const int length,
+                           const uint32_t channel,
+                           const int step = 0) {
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    const TBlob& out_grad = inputs[0];
+    const TBlob& in_grad = outputs[0];
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        for (uint32_t i = 0; i < channel; ++i) {
+            DType std_dev = param.std[param.std.ndim() > i ? i : 0];
+            mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
+                s, length, in_grad.dptr<DType>(), out_grad.dptr<DType>(),
+                i, length, step, std_dev);
+        }
+      });
+    });
+}
+
+template<typename xpu>
+void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+
+  const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
+
+  // Note: inputs[0] is out_grad
+  const TBlob& in_data = inputs[1];
+
+  // 3D input (c, h, w)
+  if (in_data.ndim() == 3) {
+    const int length = in_data.shape_[1] * in_data.shape_[2];
+    const uint32_t channel = in_data.shape_[0];
+    NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
+  } else if (in_data.ndim() == 4) {
+    // 4D input (n, c, h, w)
+    const int batch_size = in_data.shape_[0];
+    const int length = in_data.shape_[2] * in_data.shape_[3];
+    const uint32_t channel = in_data.shape_[1];
+    const int step = channel * length;
+
+    #pragma omp parallel for
+    for (auto n = 0; n < batch_size; ++n) {
+      NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
     }
-  });
+  }
 }
 
 template<typename DType>
@@ -190,7 +338,7 @@ void FlipImpl(const TShape &shape, DType *src, DType *dst) {
   }
 }
 
-void FlipLeftRight(const nnvm::NodeAttrs &attrs,
+inline void FlipLeftRight(const nnvm::NodeAttrs &attrs,
                    const OpContext &ctx,
                    const std::vector<TBlob> &inputs,
                    const std::vector<OpReqType> &req,
@@ -202,7 +350,7 @@ void FlipLeftRight(const nnvm::NodeAttrs &attrs,
   });
 }
 
-void FlipTopBottom(const nnvm::NodeAttrs &attrs,
+inline void FlipTopBottom(const nnvm::NodeAttrs &attrs,
                    const OpContext &ctx,
                    const std::vector<TBlob> &inputs,
                    const std::vector<OpReqType> &req,
@@ -214,7 +362,7 @@ void FlipTopBottom(const nnvm::NodeAttrs &attrs,
   });
 }
 
-void RandomFlipLeftRight(
+inline void RandomFlipLeftRight(
     const nnvm::NodeAttrs &attrs,
     const OpContext &ctx,
     const std::vector<TBlob> &inputs,
@@ -235,7 +383,7 @@ void RandomFlipLeftRight(
   });
 }
 
-void RandomFlipTopBottom(
+inline void RandomFlipTopBottom(
     const nnvm::NodeAttrs &attrs,
     const OpContext &ctx,
     const std::vector<TBlob> &inputs,
@@ -287,7 +435,7 @@ inline void AdjustBrightnessImpl(const float& alpha_b,
   });
 }
 
-void RandomBrightness(const nnvm::NodeAttrs &attrs,
+inline void RandomBrightness(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
@@ -405,7 +553,7 @@ inline void RandomSaturation(const nnvm::NodeAttrs &attrs,
   AdjustSaturationImpl(alpha_s, ctx, inputs, req, outputs);
 }
 
-void RGB2HLSConvert(const float& src_r,
+inline void RGB2HLSConvert(const float& src_r,
                     const float& src_g,
                     const float& src_b,
                     float *dst_h,
@@ -443,7 +591,7 @@ void RGB2HLSConvert(const float& src_r,
   *dst_s = s;
 }
 
-void HLS2RGBConvert(const float& src_h,
+inline void HLS2RGBConvert(const float& src_h,
                     const float& src_l,
                     const float& src_s,
                     float *dst_r,
@@ -494,7 +642,7 @@ void HLS2RGBConvert(const float& src_h,
   *dst_r = r * 255.f;
 }
 
-void AdjustHueImpl(float alpha,
+inline void AdjustHueImpl(float alpha,
                    const OpContext &ctx,
                    const std::vector<TBlob> &inputs,
                    const std::vector<OpReqType> &req,
@@ -521,7 +669,7 @@ void AdjustHueImpl(float alpha,
   });
 }
 
-void RandomHue(const nnvm::NodeAttrs &attrs,
+inline void RandomHue(const nnvm::NodeAttrs &attrs,
                const OpContext &ctx,
                const std::vector<TBlob> &inputs,
                const std::vector<OpReqType> &req,
@@ -554,7 +702,7 @@ struct RandomColorJitterParam : public dmlc::Parameter<RandomColorJitterParam> {
   }
 };
 
-void RandomColorJitter(const nnvm::NodeAttrs &attrs,
+inline void RandomColorJitter(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
@@ -623,7 +771,7 @@ struct RandomLightingParam : public dmlc::Parameter<RandomLightingParam> {
   }
 };
 
-void AdjustLightingImpl(const nnvm::Tuple<float>& alpha,
+inline void AdjustLightingImpl(const nnvm::Tuple<float>& alpha,
                         const OpContext &ctx,
                         const std::vector<TBlob> &inputs,
                         const std::vector<OpReqType> &req,
@@ -658,7 +806,7 @@ void AdjustLightingImpl(const nnvm::Tuple<float>& alpha,
   });
 }
 
-void AdjustLighting(const nnvm::NodeAttrs &attrs,
+inline void AdjustLighting(const nnvm::NodeAttrs &attrs,
                     const OpContext &ctx,
                     const std::vector<TBlob> &inputs,
                     const std::vector<OpReqType> &req,
@@ -668,7 +816,7 @@ void AdjustLighting(const nnvm::NodeAttrs &attrs,
   AdjustLightingImpl(param.alpha, ctx, inputs, req, outputs);
 }
 
-void RandomLighting(const nnvm::NodeAttrs &attrs,
+inline void RandomLighting(const nnvm::NodeAttrs &attrs,
                     const OpContext &ctx,
                     const std::vector<TBlob> &inputs,
                     const std::vector<OpReqType> &req,
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 26f520b..7901747 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -49,21 +49,92 @@ NNVM_REGISTER_OP(_image_to_tensor)
 .add_argument("data", "NDArray-or-Symbol", "The input.");
 
 NNVM_REGISTER_OP(_image_normalize)
-.describe(R"code()code" ADD_FILELINE)
+.describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
+    standard deviation.
+
+    Given mean `(m1, ..., mn)` and std `(s\ :sub:`1`\ , ..., s\ :sub:`n`)` for `n` channels,
+    this transform normalizes each channel of the input tensor with:
+
+.. math::
+
+        output[i] = (input[i] - m\ :sub:`i`\ ) / s\ :sub:`i`
+
+    If mean or std is scalar, the same value will be applied to all channels.
+
+    Default value for mean is 0.0 and stand deviation is 1.0.
+
+Example:
+
+    .. code-block:: python
+        image = mx.nd.random.uniform(0, 1, (3, 4, 2))
+        normalize(image, mean=(0, 1, 2), std=(3, 2, 1))
+            [[[ 0.18293785  0.19761486]
+              [ 0.23839645  0.28142193]
+              [ 0.20092112  0.28598186]
+              [ 0.18162774  0.28241724]]
+             [[-0.2881726  -0.18821815]
+              [-0.17705294 -0.30780914]
+              [-0.2812064  -0.3512327 ]
+              [-0.05411351 -0.4716435 ]]
+             [[-1.0363373  -1.7273437 ]
+              [-1.6165586  -1.5223348 ]
+              [-1.208275   -1.1878313 ]
+              [-1.4711051  -1.5200229 ]]]
+            <NDArray 3x4x2 @cpu(0)>
+
+        image = mx.nd.random.uniform(0, 1, (2, 3, 4, 2))
+        normalize(image, mean=(0, 1, 2), std=(3, 2, 1))
+            [[[[ 0.18934818  0.13092826]
+               [ 0.3085322   0.27869293]
+               [ 0.02367868  0.11246539]
+               [ 0.0290431   0.2160573 ]]
+              [[-0.4898908  -0.31587923]
+               [-0.08369008 -0.02142242]
+               [-0.11092162 -0.42982462]
+               [-0.06499392 -0.06495637]]
+              [[-1.0213816  -1.526392  ]
+               [-1.2008414  -1.1990893 ]
+               [-1.5385206  -1.4795225 ]
+               [-1.2194707  -1.3211205 ]]]
+             [[[ 0.03942481  0.24021089]
+               [ 0.21330701  0.1940066 ]
+               [ 0.04778443  0.17912441]
+               [ 0.31488964  0.25287187]]
+              [[-0.23907584 -0.4470462 ]
+               [-0.29266903 -0.2631998 ]
+               [-0.3677222  -0.40683383]
+               [-0.11288315 -0.13154092]]
+              [[-1.5438497  -1.7834496 ]
+               [-1.431566   -1.8647819 ]
+               [-1.9812102  -1.675859  ]
+               [-1.3823645  -1.8503251 ]]]]
+            <NDArray 2x3x4x2 @cpu(0)>
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<NormalizeParam>)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<NormalizeParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", NormalizeShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", NormalizeOpShape)
+.set_attr<nnvm::FInferType>("FInferType", NormalizeOpType)
+.set_attr<FCompute>("FCompute<cpu>", NormalizeOpForward<cpu>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
+  [](const NodeAttrs& attrs) {
     return std::vector<std::pair<int, int> >{{0, 0}};
   })
-.set_attr<FCompute>("FCompute<cpu>", Normalize)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
-.add_argument("data", "NDArray-or-Symbol", "The input.")
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_image_normalize"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
 .add_arguments(NormalizeParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_backward_image_normalize)
+.set_attr_parser(ParamParser<NormalizeParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", NormalizeOpBackward<cpu>);
+
 MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_left_right)
 .describe(R"code()code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", FlipLeftRight);
diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu
new file mode 100644
index 0000000..404c3d2
--- /dev/null
+++ b/src/operator/image/image_random.cu
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file image_random.cu
+ * \brief GPU Implementation of image transformation operators
+ */
+#include "./image_random-inl.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+namespace image {
+
+NNVM_REGISTER_OP(_image_normalize)
+.set_attr<FCompute>("FCompute<gpu>", NormalizeOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_image_normalize)
+.set_attr<FCompute>("FCompute<gpu>", NormalizeOpBackward<gpu>);
+
+
+}  // namespace image
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py
new file mode 100644
index 0000000..c7afc76
--- /dev/null
+++ b/tests/python/gpu/test_gluon_transforms.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import print_function
+import os
+import sys
+import mxnet as mx
+import mxnet.ndarray as nd
+import numpy as np
+from mxnet import gluon
+from mxnet.base import MXNetError
+from mxnet.gluon.data.vision import transforms
+from mxnet.test_utils import assert_almost_equal, set_default_context
+from mxnet.test_utils import almost_equal
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+from common import assertRaises, setup_module, with_seed, teardown
+
+
+set_default_context(mx.gpu(0))
+
+@with_seed()
+def test_normalize():
+    # 3D Input
+    data_in_3d = nd.random.uniform(0, 1, (3, 300, 300))
+    out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
+    data_expected_3d = data_in_3d.asnumpy()
+    data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
+    data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
+    data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
+    assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())
+
+    # 4D Input
+    data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
+    out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
+    data_expected_4d = data_in_4d.asnumpy()
+    data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
+    data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
+    data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
+    data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
+    data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
+    data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
+    assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())
+
+    # Default normalize values i.e., mean=0, std=1
+    data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300))
+    out_nd_3d_def = transforms.Normalize()(data_in_3d_def)
+    data_expected_3d_def = data_in_3d_def.asnumpy()
+    assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy())
+
+    # Invalid Input - Neither 3D or 4D input
+    invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
+    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+
+    # Invalid Input - Channel neither 1 or 3
+    invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
+    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
\ No newline at end of file
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index 2ff9c5c..c83778f 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -19,10 +19,11 @@ import mxnet as mx
 import mxnet.ndarray as nd
 import numpy as np
 from mxnet import gluon
+from mxnet.base import MXNetError
 from mxnet.gluon.data.vision import transforms
 from mxnet.test_utils import assert_almost_equal
 from mxnet.test_utils import almost_equal
-from common import setup_module, with_seed, teardown
+from common import assertRaises, setup_module, with_seed, teardown
 
 
 @with_seed()
@@ -35,14 +36,36 @@ def test_to_tensor():
 
 @with_seed()
 def test_normalize():
-    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
-    data_in = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
-    out_nd = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in)
-    data_expected = data_in.asnumpy()
-    data_expected[:][:][0] = data_expected[:][:][0] / 3.0
-    data_expected[:][:][1] = (data_expected[:][:][1] - 1.0) / 2.0
-    data_expected[:][:][2] = data_expected[:][:][2] - 2.0
-    assert_almost_equal(data_expected, out_nd.asnumpy())
+    # 3D Input
+    data_in_3d = nd.random.uniform(0, 1, (3, 300, 300))
+    out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
+    data_expected_3d = data_in_3d.asnumpy()
+    data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
+    data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
+    data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
+    assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())
+
+    # 4D Input
+    data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
+    out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
+    data_expected_4d = data_in_4d.asnumpy()
+    data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
+    data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
+    data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
+    data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
+    data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
+    data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
+    assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())
+
+    # Invalid Input - Neither 3D or 4D input
+    invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
+    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+
+    # Invalid Input - Channel neither 1 or 3
+    invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
+    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 670cc7e..ce61beb 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7326,6 +7326,73 @@ def test_invalid_max_pooling_pad_type_same():
         name='pooling',
         pooling_convention="same")
 
+
+@with_seed()
+def test_image_normalize():
+    # Part 1 - Test 3D Input
+    shape_3d = (3, 28, 28)
+    mean = (0, 1, 2)
+    std = (3, 2, 1)
+
+    data_in_3d = mx.nd.random.uniform(0, 1, shape_3d)
+    data_expected_3d = data_in_3d.asnumpy()
+    data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
+    data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
+    data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
+
+    data = mx.symbol.Variable('data')
+    img_norm_sym = mx.sym.image.normalize(data=data, mean=mean, std=std)
+
+    # check forward
+    check_symbolic_forward(img_norm_sym, [data_in_3d], [data_expected_3d],
+                           rtol=1e-5, atol=1e-5)
+
+    # Gradient is 1/std_dev
+    grad_expected_3d = np.ones(shape_3d)
+    grad_expected_3d[:][:][0] = 1 / 3.0
+    grad_expected_3d[:][:][1] = 1 / 2.0
+    grad_expected_3d[:][:][2] = 1 / 1.0
+
+    # check backward
+    check_symbolic_backward(img_norm_sym, location=[data_in_3d], out_grads=[mx.nd.ones(shape_3d)],
+                            expected=[grad_expected_3d], rtol=1e-5, atol=1e-5)
+
+    # check backward using finite difference
+    check_numeric_gradient(img_norm_sym, [data_in_3d], atol=0.001)
+
+    # Part 2 - Test 4D Input
+    shape_4d = (2, 3, 28, 28)
+
+    data_in_4d = mx.nd.random.uniform(0, 1, shape_4d)
+    data_expected_4d = data_in_4d.asnumpy()
+    data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
+    data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
+    data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
+    data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
+    data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
+    data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
+
+    # check forward
+    check_symbolic_forward(img_norm_sym, [data_in_4d], [data_expected_4d],
+                           rtol=1e-5, atol=1e-5)
+
+    # Gradient is 1/std_dev
+    grad_expected_4d = np.ones(shape_4d)
+    grad_expected_4d[0][:][:][0] = 1 / 3.0
+    grad_expected_4d[0][:][:][1] = 1 / 2.0
+    grad_expected_4d[0][:][:][2] = 1 / 1.0
+    grad_expected_4d[1][:][:][0] = 1 / 3.0
+    grad_expected_4d[1][:][:][1] = 1 / 2.0
+    grad_expected_4d[1][:][:][2] = 1 / 1.0
+
+    # check backward
+    check_symbolic_backward(img_norm_sym, location=[data_in_4d], out_grads=[mx.nd.ones(shape_4d)],
+                            expected=[grad_expected_4d], rtol=1e-5, atol=1e-5)
+
+    # check backward using finite difference
+    check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()