You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/22 20:43:47 UTC

[incubator-mxnet] 16/20: Vision (#8856)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit fb199fc246153df0a54e3e94bae7b4ca330f33fa
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Tue Nov 28 17:14:26 2017 -0800

    Vision (#8856)
    
    * refactor
    
    * fix
    
    * fix
---
 src/operator/image/image_random-inl.h | 756 ++++++++++++++++++----------------
 src/operator/image/image_random.cc    |  84 ++--
 2 files changed, 455 insertions(+), 385 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 3bee843..9d10a30 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -38,18 +38,19 @@
 
 namespace mxnet {
 namespace op {
+namespace image {
 
-inline bool CheckIsImage(const TBlob &image) {
-  CHECK_EQ(image.type_flag_, mshadow::kUint8) << "input type is not an image.";
-  CHECK_EQ(image.ndim(), 3) << "input dimension is not 3.";
-  CHECK(image.shape_[2] == 1 || image.shape_[2] == 3) << "image channel should be 1 or 3.";
-}
-
-static void RandomFlip(const nnvm::NodeAttrs &attrs,
-                       const OpContext &ctx,
-                       const std::vector<TBlob> &inputs,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &outputs) {
+inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
+                          std::vector<TShape> *in_attrs,
+                          std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TShape &shp = (*in_attrs)[0];
+  if (!shp.ndim()) return false;
+  CHECK_EQ(shp.ndim(), 3)
+      << "Input image must have shape (height, width, channels), but got " << shp;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
+  return true;
 }
 
 inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
@@ -57,47 +58,39 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  CHECK_EQ((*in_attrs)[0], mshadow::kUint8)
-    << "`to_tensor` only supports uint8 input";
   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
   return (*in_attrs)[0] != -1;
 }
 
-inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
-                          std::vector<TShape> *in_attrs,
-                          std::vector<TShape> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), 1U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  TShape &shp = (*in_attrs)[0];
-  CHECK_EQ(shp.ndim(), 3U) << "`to_tensor` only supports 3 dimensions";
-  TShape ret(3);
-  ret[0] = shp[2];
-  ret[1] = shp[0];
-  ret[2] = shp[1];
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
-  return true;
-}
-
-static void ToTensor(const nnvm::NodeAttrs &attrs,
+void ToTensor(const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<TBlob> &inputs,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &outputs) {
   CHECK_EQ(req[0], kWriteTo)
     << "`to_tensor` does not support inplace";
-  CheckIsImage(inputs[0]);
 
   int length = inputs[0].shape_[0] * inputs[0].shape_[1];
   int channel = inputs[0].shape_[2];
 
-  float* output = outputs[0].dptr<float>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    float* output = outputs[0].dptr<float>();
+    DType* input = inputs[0].dptr<DType>();
 
-  for (int l = 0; l < length; ++l) {
-    for (int c = 0; c < channel; ++c) {
-      output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
+    for (int l = 0; l < length; ++l) {
+      for (int c = 0; c < channel; ++c) {
+        output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
+      }
     }
-  }
+  });
+}
+
+inline bool TensorShape(const nnvm::NodeAttrs& attrs,
+                       std::vector<TShape> *in_attrs,
+                       std::vector<TShape> *out_attrs) {
+  TShape& dshape = (*in_attrs)[0];
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
 }
 
 struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
@@ -117,20 +110,28 @@ inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
   const auto& dshape = (*in_attrs)[0];
   if (!dshape.ndim()) return false;
-  CHECK_EQ(dshape.ndim(), 3)
-      << "Input must have 3 dimensions";
 
+  CHECK_EQ(dshape.ndim(), 3)
+      << "Input tensor must have shape (channels, height, width), but got "
+      << dshape;
   auto nchannels = dshape[0];
+  CHECK(nchannels == 3 || nchannels == 1)
+      << "The first dimension of input tensor must be the channel dimension with "
+      << "either 1 or 3 elements, but got input with shape " << dshape;
   CHECK(param.mean.ndim() == 1 || param.mean.ndim() == nchannels)
-      << "mean must have either 1 or " << nchannels << " elements";
+      << "Invalid mean for input with shape " << dshape
+      << ". mean must have either 1 or " << nchannels
+      << " elements, but got " << param.mean;
   CHECK(param.std.ndim() == 1 || param.std.ndim() == nchannels)
-      << "std must have either 1 or " << nchannels << " elements";
+      << "Invalid std for input with shape " << dshape
+      << ". std must have either 1 or " << nchannels
+      << " elements, but got " << param.std;
 
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
 }
 
-
-static void Normalize(const nnvm::NodeAttrs &attrs,
+void Normalize(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
@@ -154,214 +155,241 @@ static void Normalize(const nnvm::NodeAttrs &attrs,
   });
 }
 
-struct FlipParam : public dmlc::Parameter<FlipParam> {
-  int axis;
-  DMLC_DECLARE_PARAMETER(FlipParam) {
-    DMLC_DECLARE_FIELD(axis)
-    .describe("0 or 1. 0 for horizontal flip, 1 for vertical flip.");
-  }
-};
+template<typename DType>
+inline DType saturate_cast(const float& src) {
+  return static_cast<DType>(src);
+}
 
-#define SWAP_IF_INPLACE(dst, dst_idx, src, src_idx) \
-  if (dst == src) {                                 \
-    std::swap(dst[dst_idx], src[src_idx]);          \
-  } else {                                          \
-    dst[dst_idx] = src[src_idx];                    \
-  }
+template<>
+inline uint8_t saturate_cast(const float& src) {
+  return std::min(std::max(src, 0.f), 255.f);
+}
+
+inline bool ImageShape(const nnvm::NodeAttrs& attrs,
+                       std::vector<TShape> *in_attrs,
+                       std::vector<TShape> *out_attrs) {
+  TShape& dshape = (*in_attrs)[0];
+  CHECK_EQ(dshape.ndim(), 3)
+      << "Input image must have shape (height, width, channels), but got " << dshape;
+  auto nchannels = dshape[dshape.ndim()-1];
+  CHECK(nchannels == 3 || nchannels == 1)
+      << "The last dimension of input image must be the channel dimension with "
+      << "either 1 or 3 elements, but got input with shape " << dshape;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
+}
 
 template<typename DType>
-static void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
-  const int height = shape[0];
-  const int width = shape[1];
-  const int nchannel = shape[2];
-
-  const int length = width * nchannel;
-  const int height_stride = (src == dst && axis == 1) ? (height >> 1) : height;
-  const int width_stride = (src == dst && axis == 0) ? (width >> 1) : width;
-
-  for (int h = 0; h < height_stride; ++h) {
-    const int h_dst = (axis == 0) ? h : (height - h);
-    for (int w = 0; w < width_stride; ++w) {
-      const int w_dst = (axis == 0) ? (width - w) : w;
-      const int idx_dst = h_dst * length + w_dst * nchannel;
-      const int idx_src = h * length + w * nchannel;
-      SWAP_IF_INPLACE(dst, idx_dst, src, idx_src);
-      if (nchannel > 1) {
-        SWAP_IF_INPLACE(dst, idx_dst + 1, src, idx_src + 1);
-        SWAP_IF_INPLACE(dst, idx_dst + 2, src, idx_src + 2);
+void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
+  int head = 1, mid = shape[axis], tail = 1;
+  for (int i = 0; i < axis; ++i) head *= shape[i];
+  for (int i = axis+1; i < shape.ndim(); ++i) tail *= shape[i];
+  // View buffer as (head, mid, tail); swap slice j with mirror mid-1-j along axis.
+  for (int i = 0; i < head; ++i) {
+    for (int j = 0; j < ((mid + 1) >> 1); ++j) {  // odd mid: middle slice maps onto itself
+      int idx1 = (i*mid + j)*tail;
+      int idx2 = idx1 + (mid - (j << 1) - 1)*tail;  // == (i*mid + mid-1-j)*tail
+      for (int k = 0; k < tail; ++k, ++idx1, ++idx2) {
+        DType tmp = src[idx1];
+        dst[idx1] = src[idx2];
+        dst[idx2] = tmp;
       }
     }
   }
 }
 
-static void Flip(const nnvm::NodeAttrs &attrs,
-                  const OpContext &ctx,
-                  const std::vector<TBlob> &inputs,
-                  const std::vector<OpReqType> &req,
-                  const std::vector<TBlob> &outputs) {
-  const FlipParam &param = nnvm::get<FlipParam>(attrs.parsed);
-  CHECK(param.axis == 0 || param.axis == 1) << "flip axis must be 0 or 1.";
-  CheckIsImage(inputs[0]);
-  const TShape& ishape = inputs[0].shape_;
+void RandomHorizontalFlip(
+    const nnvm::NodeAttrs &attrs,
+    const OpContext &ctx,
+    const std::vector<TBlob> &inputs,
+    const std::vector<OpReqType> &req,
+    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
+  if (std::bernoulli_distribution()(prnd->GetRndEngine())) return;
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    FlipImpl(inputs[0].shape_, inputs[0].dptr<DType>(),
+             outputs[0].dptr<DType>(), 1);
+  });
+}
+
+void RandomVerticalFlip(
+    const nnvm::NodeAttrs &attrs,
+    const OpContext &ctx,
+    const std::vector<TBlob> &inputs,
+    const std::vector<OpReqType> &req,
+    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
+  if (std::bernoulli_distribution()(prnd->GetRndEngine())) return;
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    FlipImpl(ishape, inputs[0].dptr<DType>(), outputs[0].dptr<DType>(), param.axis);
+    FlipImpl(inputs[0].shape_, inputs[0].dptr<DType>(),
+             outputs[0].dptr<DType>(), 0);
   });
 }
 
-struct RandomBrightnessParam : public dmlc::Parameter<RandomBrightnessParam> {
-  float max_brightness;
-  DMLC_DECLARE_PARAMETER(RandomBrightnessParam) {
-    DMLC_DECLARE_FIELD(max_brightness)
+struct RandomEnhanceParam : public dmlc::Parameter<RandomEnhanceParam> {
+  float min_factor;
+  float max_factor;
+  DMLC_DECLARE_PARAMETER(RandomEnhanceParam) {
+    DMLC_DECLARE_FIELD(min_factor)
+    .set_lower_bound(0.0)
+    .describe("Minimum factor.");
+    DMLC_DECLARE_FIELD(max_factor)
     .set_lower_bound(0.0)
-    .describe("Max Brightness.");
+    .describe("Maximum factor.");
   }
 };
 
-static void RandomBrightness(const nnvm::NodeAttrs &attrs,
-                             const OpContext &ctx,
-                             const std::vector<TBlob> &inputs,
-                             const std::vector<OpReqType> &req,
-                             const std::vector<TBlob> &outputs) {
+inline void AdjustBrightnessImpl(const float& alpha_b,
+                                 const OpContext &ctx,
+                                 const std::vector<TBlob> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  const RandomBrightnessParam &param = nnvm::get<RandomBrightnessParam>(attrs.parsed);
-
   int length = inputs[0].Size();
 
-  uint8_t* output = outputs[0].dptr<uint8_t>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
+    for (int l = 0; l < length; ++l) {
+      float val = static_cast<float>(input[l]) * alpha_b;
+      output[l] = saturate_cast<DType>(val);
+    }
+  });
+}
+
+void RandomBrightness(const nnvm::NodeAttrs &attrs,
+                      const OpContext &ctx,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
+
 
   Stream<cpu> *s = ctx.get_stream<cpu>();
   Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
-  float alpha_b = 1.0 + std::uniform_real_distribution<float>(
-      -param.max_brightness, param.max_brightness)(prnd->GetRndEngine());
+  float alpha_b = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
 
-  for (int l = 0; l < length; ++l) {
-    float val = static_cast<float>(input[l]) * alpha_b;
-    val = std::min(std::max(val, 0.f), 255.f);
-    output[l] = static_cast<uint8_t>(val);
-  }
+  AdjustBrightnessImpl(alpha_b, ctx, inputs, req, outputs);
 }
 
+inline void AdjustContrastImpl(const float& alpha_c,
+                               const OpContext &ctx,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  static const float coef[] = { 0.299f, 0.587f, 0.114f };
 
-struct RandomContrastParam : public dmlc::Parameter<RandomContrastParam> {
-  float max_contrast;
-  DMLC_DECLARE_PARAMETER(RandomContrastParam) {
-    DMLC_DECLARE_FIELD(max_contrast)
-    .set_lower_bound(0.0)
-    .describe("Max Contrast.");
-  }
-};
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  int nchannels = inputs[0].shape_[2];
 
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
+
+    float sum = 0.f;
+    if (nchannels > 1) {
+      for (int l = 0; l < length; ++l) {
+        for (int c = 0; c < 3; ++c) sum += input[l*3 + c] * coef[c];
+      }
+    } else {
+      for (int l = 0; l < length; ++l) sum += input[l];
+    }
+    float gray_mean = sum / static_cast<float>(length);
+    float beta = (1 - alpha_c) * gray_mean;
 
-static void RandomContrast(const nnvm::NodeAttrs &attrs,
+    for (int l = 0; l < length * nchannels; ++l) {
+      float val = input[l] * alpha_c + beta;
+      output[l] = saturate_cast<DType>(val);
+    }
+  });
+}
+
+inline void RandomContrast(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &outputs) {
   using namespace mshadow;
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
+
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+  float alpha_c = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
+
+  AdjustContrastImpl(alpha_c, ctx, inputs, req, outputs);
+}
+
+inline void AdjustSaturationImpl(const float& alpha_s,
+                                 const OpContext &ctx,
+                                 const std::vector<TBlob> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<TBlob> &outputs) {
   static const float coef[] = { 0.299f, 0.587f, 0.114f };
-  const RandomContrastParam &param = nnvm::get<RandomContrastParam>(attrs.parsed);
 
   int length = inputs[0].shape_[0] * inputs[0].shape_[1];
   int nchannels = inputs[0].shape_[2];
 
-  uint8_t* output = outputs[0].dptr<uint8_t>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  float alpha_o = 1.f - alpha_s;
 
-  Stream<cpu> *s = ctx.get_stream<cpu>();
-  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
-  float alpha_c = 1.0 + std::uniform_real_distribution<float>(
-    -param.max_contrast, param.max_contrast)(prnd->GetRndEngine());
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
 
-  float sum = 0.f;
-  if (nchannels > 1) {
-    for (int l = 0; l < length; ++l) {
-      for (int c = 0; c < nchannels; ++c) sum += input[l*nchannels + c] * coef[c];
+    if (nchannels == 1) {
+      for (int l = 0; l < length; ++l) output[l] = input[l];
+      return;
     }
-  } else {
-    for (int l = 0; l < length; ++l) sum += input[l];
-  }
-  float gray_mean = sum / static_cast<float>(length);
-  float beta = (1 - alpha_c) * gray_mean;
 
-  for (int l = 0; l < length * nchannels; ++l) {
-    float val = input[l] * alpha_c + beta;
-    val = std::min(std::max(val, 0.f), 255.f);
-    output[l] = static_cast<uint8_t>(val);
-  }
+    for (int l = 0; l < length; ++l) {
+      float gray = 0.f;
+      for (int c = 0; c < 3; ++c) {
+        gray += input[l*3 + c] * coef[c];  // accumulate weighted sum over channels
+      }
+      gray *= alpha_o;
+      for (int c = 0; c < 3; ++c) {
+        float val = gray + input[l*3 + c] * alpha_s;
+        output[l*3 + c] = saturate_cast<DType>(val);
+      }
+    }
+  });
 }
 
-struct RandomSaturationParam : public dmlc::Parameter<RandomSaturationParam> {
-  float max_saturation;
-  DMLC_DECLARE_PARAMETER(RandomSaturationParam) {
-    DMLC_DECLARE_FIELD(max_saturation)
-    .set_default(0.0)
-    .describe("Max Saturation.");
-  }
-};
-
-static void RandomSaturation(const nnvm::NodeAttrs &attrs,
+inline void RandomSaturation(const nnvm::NodeAttrs &attrs,
                              const OpContext &ctx,
                              const std::vector<TBlob> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  const RandomSaturationParam &param = nnvm::get<RandomSaturationParam>(attrs.parsed);
-  static const float coef[] = { 0.299f, 0.587f, 0.114f };
-
-  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
-  int nchannels = inputs[0].shape_[2];
-
-  uint8_t* output = outputs[0].dptr<uint8_t>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
 
   Stream<cpu> *s = ctx.get_stream<cpu>();
   Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
-  float alpha_s = 1.f + std::uniform_real_distribution<float>(
-    -param.max_saturation, param.max_saturation)(prnd->GetRndEngine());
-  float alpha_o = 1.f - alpha_s;
-
-  if (nchannels == 1) {
-    for (int l = 0; l < length * nchannels; ++l) output[l] = input[l];
-    return;
-  }
+  float alpha_s = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
 
-  for (int l = 0; l < length; ++l) {
-    float gray = 0.f;
-    for (int c = 0; c < nchannels; ++c) {
-      gray = input[l*nchannels + c] * coef[c];
-    }
-    gray *= alpha_o;
-    for (int c = 0; c < nchannels; ++c) {
-      float val = gray + input[l*nchannels + c] * alpha_s;
-      val = std::min(std::max(val, 0.f), 255.f);
-      output[l*nchannels + c] = static_cast<uint8_t>(val);
-    }
-  }
+  AdjustSaturationImpl(alpha_s, ctx, inputs, req, outputs);
 }
 
-struct RandomHueParam : public dmlc::Parameter<RandomHueParam> {
-  float max_hue;
-  DMLC_DECLARE_PARAMETER(RandomHueParam) {
-    DMLC_DECLARE_FIELD(max_hue)
-    .set_default(0.0)
-    .describe("Max Hue.");
-  }
-};
-
-template <typename DType> static
-void RGB2HLSConvert(const DType src_r,
-                    const DType src_g,
-                    const DType src_b,
-                    DType *dst_h,
-                    DType *dst_l,
-                    DType *dst_s
-                   ) {
-  DType b = src_b, g = src_g, r = src_r;
-  DType h = 0.f, s = 0.f, l;
-  DType vmin;
-  DType vmax;
-  DType diff;
+void RGB2HLSConvert(const float& src_r,
+                    const float& src_g,
+                    const float& src_b,
+                    float *dst_h,
+                    float *dst_l,
+                    float *dst_s) {
+  float b = src_b / 255.f, g = src_g / 255.f, r = src_r / 255.f;
+  float h = 0.f, s = 0.f, l;
+  float vmin;
+  float vmax;
+  float diff;
 
   vmax = vmin = r;
   vmax = fmax(vmax, g);
@@ -372,7 +400,7 @@ void RGB2HLSConvert(const DType src_r,
   diff = vmax - vmin;
   l = (vmax + vmin) * 0.5f;
 
-  if (diff > std::numeric_limits<DType>::epsilon()) {
+  if (diff > std::numeric_limits<float>::epsilon()) {
     s = (l < 0.5f) * diff / (vmax + vmin);
     s += (l >= 0.5f) * diff / (2.0f - vmax - vmin);
 
@@ -389,23 +417,20 @@ void RGB2HLSConvert(const DType src_r,
   *dst_s = s;
 }
 
-
-static  int c_HlsSectorData[6][3] = {
-  { 1, 3, 0 },
-  { 1, 0, 2 },
-  { 3, 0, 1 },
-  { 0, 2, 1 },
-  { 0, 1, 3 },
-  { 2, 1, 0 }
-};
-
-template <typename DType>  static  void HLS2RGBConvert(const DType src_h,
-    const DType src_l,
-    const DType src_s,
-    DType *dst_r,
-    DType *dst_g,
-    DType *dst_b) {
-
+void HLS2RGBConvert(const float& src_h,
+                    const float& src_l,
+                    const float& src_s,
+                    float *dst_r,
+                    float *dst_g,
+                    float *dst_b) {
+  static const int c_HlsSectorData[6][3] = {
+    { 1, 3, 0 },
+    { 1, 0, 2 },
+    { 3, 0, 1 },
+    { 0, 2, 1 },
+    { 0, 1, 3 },
+    { 2, 1, 0 }
+  };
 
   float h = src_h, l = src_l, s = src_s;
   float b = l, g = l, r = l;
@@ -415,6 +440,8 @@ template <typename DType>  static  void HLS2RGBConvert(const DType src_h,
     p2 += (l > 0.5f) * (l + s - l * s);
     float p1 = 2 * l - p2;
 
+    h *= 1.f / 60.f;
+
     if (h < 0) {
       do { h += 6; } while (h < 0);
     } else if (h >= 6) {
@@ -436,177 +463,202 @@ template <typename DType>  static  void HLS2RGBConvert(const DType src_h,
     r = tab[c_HlsSectorData[sector][2]];
   }
 
-  *dst_b = b;
-  *dst_g = g;
-  *dst_r = r;
+  *dst_b = b * 255.f;
+  *dst_g = g * 255.f;
+  *dst_r = r * 255.f;
 }
 
-template<typename xpu, typename DType>
-static  void RandomHueKernal(const TBlob &input,
-                             const TBlob &output,
-                             Stream<xpu> *s,
-                             int hight,
-                             int weight,
-                             DType alpha) {
-  auto input_3d = input.get<xpu, 3, DType>(s);
-  auto output_3d = output.get<xpu, 3, DType>(s);
-  for (int h_index = 0; h_index < hight; ++h_index) {
-    for (int w_index = 0; w_index < weight; ++w_index) {
-      DType h;
-      DType l;
-      DType s;
-      RGB2HLSConvert(input_3d[0][h_index][w_index],
-                     input_3d[1][h_index][w_index],
-                     input_3d[2][h_index][w_index],
-                     &h, &l, &s);
-      h += alpha;
-      h = std::max(DType(0), std::min(DType(180), h));
-
-      HLS2RGBConvert(
-        h, l, s,
-        &output_3d[0][h_index][w_index],
-        &output_3d[1][h_index][w_index],
-        &output_3d[2][h_index][w_index]);
+void AdjustHueImpl(float alpha,
+                   const OpContext &ctx,
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<OpReqType> &req,
+                   const std::vector<TBlob> &outputs) {
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  if (inputs[0].shape_[2] == 1) return;
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* input = inputs[0].dptr<DType>();
+    DType* output = outputs[0].dptr<DType>();
+
+    for (int i = 0; i < length; ++i) {
+      float h, l, s;
+      float r = static_cast<float>(*(input++));
+      float g = static_cast<float>(*(input++));
+      float b = static_cast<float>(*(input++));
+      RGB2HLSConvert(r, g, b, &h, &l, &s);
+      h += alpha * 360.f;
+      HLS2RGBConvert(h, l, s, &r, &g, &b);
+      *(output++) = saturate_cast<DType>(r);
+      *(output++) = saturate_cast<DType>(g);
+      *(output++) = saturate_cast<DType>(b);
     }
-  }
+  });
 }
 
-template<typename xpu>
-static void RandomHue(const nnvm::NodeAttrs &attrs,
-                      const OpContext &ctx,
-                      const std::vector<TBlob> &inputs,
-                      const std::vector<OpReqType> &req,
-                      const std::vector<TBlob> &outputs) {
+void RandomHue(const nnvm::NodeAttrs &attrs,
+               const OpContext &ctx,
+               const std::vector<TBlob> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  auto input = inputs[0];
-  auto output = outputs[0];
-  int channel = input.shape_[0];
-  int hight = input.shape_[1];
-  int weight = input.shape_[2];
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  Random<xpu> *prnd = ctx.requested[kRandom].get_random<xpu, real_t>(s);
-
-  const RandomHueParam &param = nnvm::get<RandomHueParam>(attrs.parsed);
-  float alpha =  std::uniform_real_distribution<float>(
-    -param.max_hue, param.max_hue)(prnd->GetRndEngine());
-  auto output_float = output.get<xpu, 3, float>(s);
-
-  MSHADOW_TYPE_SWITCH(input.type_flag_, DType, {
-    RandomHueKernal<xpu, DType>(input, output, s, hight, weight, alpha);
-  });
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+  float alpha = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
+
+  AdjustHueImpl(alpha, ctx, inputs, req, outputs);
 }
 
-static void RandomColorJitter(const nnvm::NodeAttrs &attrs,
-                              const OpContext &ctx,
-                              const std::vector<TBlob> &inputs,
-                              const std::vector<OpReqType> &req,
-                              const std::vector<TBlob> &outputs) {
+struct RandomColorJitterParam : public dmlc::Parameter<RandomColorJitterParam> {
+  float brightness;
+  float contrast;
+  float saturation;
+  float hue;
+  DMLC_DECLARE_PARAMETER(RandomColorJitterParam) {
+    DMLC_DECLARE_FIELD(brightness)
+    .describe("How much to jitter brightness.");
+    DMLC_DECLARE_FIELD(contrast)
+    .describe("How much to jitter contrast.");
+    DMLC_DECLARE_FIELD(saturation)
+    .describe("How much to jitter saturation.");
+    DMLC_DECLARE_FIELD(hue)
+    .describe("How much to jitter hue.");
+  }
+};
+
+void RandomColorJitter(const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const RandomColorJitterParam &param = nnvm::get<RandomColorJitterParam>(attrs.parsed);
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+
+  int order[4] = {0, 1, 2, 3};
+  std::shuffle(order, order + 4, prnd->GetRndEngine());
+  bool flag = false;
+
+  for (int i = 0; i < 4; ++i) {
+    switch (order[i]) {
+      case 0:
+        if (param.brightness > 0) {
+          float alpha_b = 1.0 + std::uniform_real_distribution<float>(
+              -param.brightness, param.brightness)(prnd->GetRndEngine());
+          AdjustBrightnessImpl(alpha_b, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+      case 1:
+        if (param.contrast > 0) {
+          float alpha_c = 1.0 + std::uniform_real_distribution<float>(
+              -param.contrast, param.contrast)(prnd->GetRndEngine());
+          AdjustContrastImpl(alpha_c, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+      case 2:
+        if (param.saturation > 0) {
+          float alpha_s = 1.f + std::uniform_real_distribution<float>(
+              -param.saturation, param.saturation)(prnd->GetRndEngine());
+          AdjustSaturationImpl(alpha_s, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+      case 3:
+        if (param.hue > 0) {
+          float alpha_h = std::uniform_real_distribution<float>(
+              -param.hue, param.hue)(prnd->GetRndEngine());
+          AdjustHueImpl(alpha_h, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+    }
+  }
 }
 
 struct AdjustLightingParam : public dmlc::Parameter<AdjustLightingParam> {
-  nnvm::Tuple<float> alpha_rgb;
-  nnvm::Tuple<float> eigval;
-  nnvm::Tuple<float> eigvec;
+  nnvm::Tuple<float> alpha;
   DMLC_DECLARE_PARAMETER(AdjustLightingParam) {
-    DMLC_DECLARE_FIELD(alpha_rgb)
-    .set_default({0, 0, 0})
+    DMLC_DECLARE_FIELD(alpha)
     .describe("The lighting alphas for the R, G, B channels.");
-    DMLC_DECLARE_FIELD(eigval)
-    .describe("Eigen value.")
-    .set_default({ 55.46, 4.794, 1.148 });
-    DMLC_DECLARE_FIELD(eigvec)
-    .describe("Eigen vector.")
-    .set_default({ -0.5675,  0.7192,  0.4009,
-                   -0.5808, -0.0045, -0.8140,
-                   -0.5808, -0.0045, -0.8140 });
   }
 };
 
 struct RandomLightingParam : public dmlc::Parameter<RandomLightingParam> {
   float alpha_std;
-  nnvm::Tuple<float> eigval;
-  nnvm::Tuple<float> eigvec;
   DMLC_DECLARE_PARAMETER(RandomLightingParam) {
     DMLC_DECLARE_FIELD(alpha_std)
     .set_default(0.05)
     .describe("Level of the lighting noise.");
-    DMLC_DECLARE_FIELD(eigval)
-    .describe("Eigen value.")
-    .set_default({ 55.46, 4.794, 1.148 });
-    DMLC_DECLARE_FIELD(eigvec)
-    .describe("Eigen vector.")
-    .set_default({ -0.5675,  0.7192,  0.4009,
-                   -0.5808, -0.0045, -0.8140,
-                   -0.5808, -0.0045, -0.8140 });
   }
 };
 
-void AdjustLightingImpl(uint8_t* dst, const uint8_t* src,
-                        float alpha_r, float alpha_g, float alpha_b,
-                        const nnvm::Tuple<float> eigval, const nnvm::Tuple<float> eigvec,
-                        int H, int W) {
-    alpha_r *= eigval[0];
-    alpha_g *= eigval[1];
-    alpha_b *= eigval[2];
-    float pca_r = alpha_r * eigvec[0] + alpha_g * eigvec[1] + alpha_b * eigvec[2];
-    float pca_g = alpha_r * eigvec[3] + alpha_g * eigvec[4] + alpha_b * eigvec[5];
-    float pca_b = alpha_r * eigvec[6] + alpha_g * eigvec[7] + alpha_b * eigvec[8];
-    for (int i = 0; i < H * W; i++) {
-        int base_ind = 3 * i;
-        float in_r = static_cast<float>(src[base_ind]);
-        float in_g = static_cast<float>(src[base_ind + 1]);
-        float in_b = static_cast<float>(src[base_ind + 2]);
-        dst[base_ind] = std::min(255, std::max(0, static_cast<int>(in_r + pca_r)));
-        dst[base_ind + 1] = std::min(255, std::max(0, static_cast<int>(in_g + pca_g)));
-        dst[base_ind + 2] = std::min(255, std::max(0, static_cast<int>(in_b + pca_b)));
+void AdjustLightingImpl(const nnvm::Tuple<float>& alpha,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  static const float eig[3][3] = {
+      { 55.46 * -0.5675, 4.794 * 0.7192,  1.148 * 0.4009 },
+      { 55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140 },
+      { 55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203 }
+    };
+
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  int channels = inputs[0].shape_[2];
+  if (channels == 1) return;
+
+  float pca_r = eig[0][0] * alpha[0] + eig[0][1] * alpha[1] + eig[0][2] * alpha[2];
+  float pca_g = eig[1][0] * alpha[0] + eig[1][1] * alpha[1] + eig[1][2] * alpha[2];
+  float pca_b = eig[2][0] * alpha[0] + eig[2][1] * alpha[1] + eig[2][2] * alpha[2];
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
+
+    for (int i = 0; i < length; i++) {
+      int base_ind = 3 * i;
+      float in_r = static_cast<float>(input[base_ind]);
+      float in_g = static_cast<float>(input[base_ind + 1]);
+      float in_b = static_cast<float>(input[base_ind + 2]);
+      output[base_ind] = saturate_cast<DType>(in_r + pca_r);
+      output[base_ind + 1] = saturate_cast<DType>(in_g + pca_g);
+      output[base_ind + 2] = saturate_cast<DType>(in_b + pca_b);
     }
+  });
 }
 
-static void AdjustLighting(const nnvm::NodeAttrs &attrs,
-                           const OpContext &ctx,
-                           const std::vector<TBlob> &inputs,
-                           const std::vector<OpReqType> &req,
-                           const std::vector<TBlob> &outputs) {
-    using namespace mshadow;
-    const AdjustLightingParam &param = nnvm::get<AdjustLightingParam>(attrs.parsed);
-    CHECK_EQ(param.eigval.ndim(), 3) << "There should be 3 numbers in the eigval.";
-    CHECK_EQ(param.eigvec.ndim(), 9) << "There should be 9 numbers in the eigvec.";
-    CHECK_EQ(inputs[0].ndim(), 3);
-    CHECK_EQ(inputs[0].size(2), 3);
-    int H = inputs[0].size(0);
-    int W = inputs[0].size(1);
-    AdjustLightingImpl(outputs[0].dptr<uint8_t>(), inputs[0].dptr<uint8_t>(),
-                       param.alpha_rgb[0], param.alpha_rgb[1], param.alpha_rgb[2],
-                       param.eigval, param.eigvec, H, W);
+void AdjustLighting(const nnvm::NodeAttrs &attrs,
+                    const OpContext &ctx,
+                    const std::vector<TBlob> &inputs,
+                    const std::vector<OpReqType> &req,
+                    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const AdjustLightingParam &param = nnvm::get<AdjustLightingParam>(attrs.parsed);
+  AdjustLightingImpl(param.alpha, ctx, inputs, req, outputs);
 }
 
-static void RandomLighting(const nnvm::NodeAttrs &attrs,
-                           const OpContext &ctx,
-                           const std::vector<TBlob> &inputs,
-                           const std::vector<OpReqType> &req,
-                           const std::vector<TBlob> &outputs) {
-    using namespace mshadow;
-    const RandomLightingParam &param = nnvm::get<RandomLightingParam>(attrs.parsed);
-    CHECK_EQ(param.eigval.ndim(), 3) << "There should be 3 numbers in the eigval.";
-    CHECK_EQ(param.eigvec.ndim(), 9) << "There should be 9 numbers in the eigvec.";
-    CHECK_EQ(inputs[0].ndim(), 3);
-    CHECK_EQ(inputs[0].size(2), 3);
-    int H = inputs[0].size(0);
-    int W = inputs[0].size(1);
-    Stream<cpu> *s = ctx.get_stream<cpu>();
-    Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
-    std::normal_distribution<float> dist(0, param.alpha_std);
-    float alpha_r = dist(prnd->GetRndEngine());
-    float alpha_g = dist(prnd->GetRndEngine());
-    float alpha_b = dist(prnd->GetRndEngine());
-    AdjustLightingImpl(outputs[0].dptr<uint8_t>(), inputs[0].dptr<uint8_t>(),
-                       alpha_r, alpha_g, alpha_b,
-                       param.eigval, param.eigvec, H, W);
+void RandomLighting(const nnvm::NodeAttrs &attrs,
+                    const OpContext &ctx,
+                    const std::vector<TBlob> &inputs,
+                    const std::vector<OpReqType> &req,
+                    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const RandomLightingParam &param = nnvm::get<RandomLightingParam>(attrs.parsed);
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
+  std::normal_distribution<float> dist(0, param.alpha_std);
+  float alpha_r = dist(prnd->GetRndEngine());
+  float alpha_g = dist(prnd->GetRndEngine());
+  float alpha_b = dist(prnd->GetRndEngine());
+  AdjustLightingImpl({alpha_r, alpha_g, alpha_b}, ctx, inputs, req, outputs);
 }
 
-
-
-
+}  // namespace image
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 29edeed..5a21bf8 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -30,6 +30,13 @@
 
 namespace mxnet {
 namespace op {
+namespace image {
+
+DMLC_REGISTER_PARAMETER(NormalizeParam);
+DMLC_REGISTER_PARAMETER(RandomEnhanceParam);
+DMLC_REGISTER_PARAMETER(AdjustLightingParam);
+DMLC_REGISTER_PARAMETER(RandomLightingParam);
+DMLC_REGISTER_PARAMETER(RandomColorJitterParam);
 
 NNVM_REGISTER_OP(_image_to_tensor)
 .describe(R"code()code" ADD_FILELINE)
@@ -42,13 +49,12 @@ NNVM_REGISTER_OP(_image_to_tensor)
 .add_argument("data", "NDArray-or-Symbol", "The input.");
 
 
-DMLC_REGISTER_PARAMETER(NormalizeParam);
 NNVM_REGISTER_OP(_image_normalize)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NormalizeParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", NormalizeShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -59,33 +65,31 @@ NNVM_REGISTER_OP(_image_normalize)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(NormalizeParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(FlipParam);
-NNVM_REGISTER_OP(_image_flip)
+
+NNVM_REGISTER_OP(_image_random_horizontal_flip)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<FlipParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                [](const NodeAttrs& attrs){
-                                  return std::vector<std::pair<int, int> >{{0, 0}};
-                                })
-.set_attr<FCompute>("FCompute<cpu>", Flip)
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", RandomHorizontalFlip)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
-.add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(FlipParam::__FIELDS__());
+.add_argument("data", "NDArray-or-Symbol", "The input.");
+
 
-DMLC_REGISTER_PARAMETER(RandomBrightnessParam);
 NNVM_REGISTER_OP(_image_random_brightness)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomBrightnessParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -94,18 +98,18 @@ NNVM_REGISTER_OP(_image_random_brightness)
 .set_attr<FCompute>("FCompute<cpu>", RandomBrightness)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomBrightnessParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
+
 
-DMLC_REGISTER_PARAMETER(RandomContrastParam);
 NNVM_REGISTER_OP(_image_random_contrast)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomContrastParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -114,18 +118,18 @@ NNVM_REGISTER_OP(_image_random_contrast)
 .set_attr<FCompute>("FCompute<cpu>", RandomContrast)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomContrastParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
+
 
-DMLC_REGISTER_PARAMETER(RandomSaturationParam);
 NNVM_REGISTER_OP(_image_random_saturation)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomSaturationParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -134,31 +138,44 @@ NNVM_REGISTER_OP(_image_random_saturation)
 .set_attr<FCompute>("FCompute<cpu>", RandomSaturation)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomSaturationParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(RandomHueParam);
 NNVM_REGISTER_OP(_image_random_hue)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomHueParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCompute>("FCompute<cpu>", RandomHue<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", RandomHue)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomHueParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_image_random_color_jitter)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<RandomColorJitterParam>)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
+  return std::vector<ResourceRequest>{ResourceRequest::kRandom};
+})
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", RandomColorJitter)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(RandomColorJitterParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(AdjustLightingParam);
 NNVM_REGISTER_OP(_image_adjust_lighting)
 .describe(R"code(Adjust the lighting level of the input. Follow the AlexNet style.)code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<AdjustLightingParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -169,7 +186,7 @@ NNVM_REGISTER_OP(_image_adjust_lighting)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(AdjustLightingParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(RandomLightingParam);
+
 NNVM_REGISTER_OP(_image_random_lighting)
 .describe(R"code(Randomly add PCA noise. Follow the AlexNet style.)code" ADD_FILELINE)
 .set_num_inputs(1)
@@ -178,7 +195,7 @@ NNVM_REGISTER_OP(_image_random_lighting)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -189,5 +206,6 @@ NNVM_REGISTER_OP(_image_random_lighting)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(RandomLightingParam::__FIELDS__());
 
+}  // namespace image
 }  // namespace op
 }  // namespace mxnet

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.