You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/29 01:14:31 UTC

[GitHub] piiswrong closed pull request #8856: Vision

piiswrong closed pull request #8856: Vision
URL: https://github.com/apache/incubator-mxnet/pull/8856
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not display its
diff once it has been merged, so the full diff is reproduced below:

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 3bee84321b..9d10a302dc 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -38,18 +38,19 @@
 
 namespace mxnet {
 namespace op {
+namespace image {
 
-inline bool CheckIsImage(const TBlob &image) {
-  CHECK_EQ(image.type_flag_, mshadow::kUint8) << "input type is not an image.";
-  CHECK_EQ(image.ndim(), 3) << "input dimension is not 3.";
-  CHECK(image.shape_[2] == 1 || image.shape_[2] == 3) << "image channel should be 1 or 3.";
-}
-
-static void RandomFlip(const nnvm::NodeAttrs &attrs,
-                       const OpContext &ctx,
-                       const std::vector<TBlob> &inputs,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &outputs) {
+inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
+                          std::vector<TShape> *in_attrs,
+                          std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TShape &shp = (*in_attrs)[0];
+  if (!shp.ndim()) return false;
+  CHECK_EQ(shp.ndim(), 3)
+      << "Input image must have shape (height, width, channels), but got " << shp;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
+  return true;
 }
 
 inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
@@ -57,47 +58,39 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  CHECK_EQ((*in_attrs)[0], mshadow::kUint8)
-    << "`to_tensor` only supports uint8 input";
   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
   return (*in_attrs)[0] != -1;
 }
 
-inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
-                          std::vector<TShape> *in_attrs,
-                          std::vector<TShape> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), 1U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  TShape &shp = (*in_attrs)[0];
-  CHECK_EQ(shp.ndim(), 3U) << "`to_tensor` only supports 3 dimensions";
-  TShape ret(3);
-  ret[0] = shp[2];
-  ret[1] = shp[0];
-  ret[2] = shp[1];
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
-  return true;
-}
-
-static void ToTensor(const nnvm::NodeAttrs &attrs,
+void ToTensor(const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<TBlob> &inputs,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &outputs) {
   CHECK_EQ(req[0], kWriteTo)
     << "`to_tensor` does not support inplace";
-  CheckIsImage(inputs[0]);
 
   int length = inputs[0].shape_[0] * inputs[0].shape_[1];
   int channel = inputs[0].shape_[2];
 
-  float* output = outputs[0].dptr<float>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    float* output = outputs[0].dptr<float>();
+    DType* input = inputs[0].dptr<DType>();
 
-  for (int l = 0; l < length; ++l) {
-    for (int c = 0; c < channel; ++c) {
-      output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
+    for (int l = 0; l < length; ++l) {
+      for (int c = 0; c < channel; ++c) {
+        output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
+      }
     }
-  }
+  });
+}
+
+inline bool TensorShape(const nnvm::NodeAttrs& attrs,
+                       std::vector<TShape> *in_attrs,
+                       std::vector<TShape> *out_attrs) {
+  TShape& dshape = (*in_attrs)[0];
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
 }
 
 struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
@@ -117,20 +110,28 @@ inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
   const auto& dshape = (*in_attrs)[0];
   if (!dshape.ndim()) return false;
-  CHECK_EQ(dshape.ndim(), 3)
-      << "Input must have 3 dimensions";
 
+  CHECK_EQ(dshape.ndim(), 3)
+      << "Input tensor must have shape (channels, height, width), but got "
+      << dshape;
   auto nchannels = dshape[0];
+  CHECK(nchannels == 3 || nchannels == 1)
+      << "The first dimension of input tensor must be the channel dimension with "
+      << "either 1 or 3 elements, but got input with shape " << dshape;
   CHECK(param.mean.ndim() == 1 || param.mean.ndim() == nchannels)
-      << "mean must have either 1 or " << nchannels << " elements";
+      << "Invalid mean for input with shape " << dshape
+      << ". mean must have either 1 or " << nchannels
+      << " elements, but got " << param.mean;
   CHECK(param.std.ndim() == 1 || param.std.ndim() == nchannels)
-      << "std must have either 1 or " << nchannels << " elements";
+      << "Invalid std for input with shape " << dshape
+      << ". std must have either 1 or " << nchannels
+      << " elements, but got " << param.std;
 
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
 }
 
-
-static void Normalize(const nnvm::NodeAttrs &attrs,
+void Normalize(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
@@ -154,214 +155,241 @@ static void Normalize(const nnvm::NodeAttrs &attrs,
   });
 }
 
-struct FlipParam : public dmlc::Parameter<FlipParam> {
-  int axis;
-  DMLC_DECLARE_PARAMETER(FlipParam) {
-    DMLC_DECLARE_FIELD(axis)
-    .describe("0 or 1. 0 for horizontal flip, 1 for vertical flip.");
-  }
-};
+template<typename DType>
+inline DType saturate_cast(const float& src) {
+  return static_cast<DType>(src);
+}
 
-#define SWAP_IF_INPLACE(dst, dst_idx, src, src_idx) \
-  if (dst == src) {                                 \
-    std::swap(dst[dst_idx], src[src_idx]);          \
-  } else {                                          \
-    dst[dst_idx] = src[src_idx];                    \
-  }
+template<>
+inline uint8_t saturate_cast(const float& src) {
+  return std::min(std::max(src, 0.f), 255.f);
+}
+
+inline bool ImageShape(const nnvm::NodeAttrs& attrs,
+                       std::vector<TShape> *in_attrs,
+                       std::vector<TShape> *out_attrs) {
+  TShape& dshape = (*in_attrs)[0];
+  CHECK_EQ(dshape.ndim(), 3)
+      << "Input image must have shape (height, width, channels), but got " << dshape;
+  auto nchannels = dshape[dshape.ndim()-1];
+  CHECK(nchannels == 3 || nchannels == 1)
+      << "The last dimension of input image must be the channel dimension with "
+      << "either 1 or 3 elements, but got input with shape " << dshape;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
+}
 
 template<typename DType>
-static void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
-  const int height = shape[0];
-  const int width = shape[1];
-  const int nchannel = shape[2];
-
-  const int length = width * nchannel;
-  const int height_stride = (src == dst && axis == 1) ? (height >> 1) : height;
-  const int width_stride = (src == dst && axis == 0) ? (width >> 1) : width;
-
-  for (int h = 0; h < height_stride; ++h) {
-    const int h_dst = (axis == 0) ? h : (height - h);
-    for (int w = 0; w < width_stride; ++w) {
-      const int w_dst = (axis == 0) ? (width - w) : w;
-      const int idx_dst = h_dst * length + w_dst * nchannel;
-      const int idx_src = h * length + w * nchannel;
-      SWAP_IF_INPLACE(dst, idx_dst, src, idx_src);
-      if (nchannel > 1) {
-        SWAP_IF_INPLACE(dst, idx_dst + 1, src, idx_src + 1);
-        SWAP_IF_INPLACE(dst, idx_dst + 2, src, idx_src + 2);
+void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
+  int head = 1, mid = shape[axis], tail = 1;
+  for (int i = 0; i < axis; ++i) head *= shape[i];
+  for (int i = axis+1; i < shape.ndim(); ++i) tail *= shape[i];
+
+  for (int i = 0; i < head; ++i) {
+    for (int j = 0; j < (mid >>2); ++j) {
+      int idx1 = (i*mid + j)*tail;
+      int idx2 = idx1 + (mid - (j<<2))*tail;
+      for (int k = 0; k < tail; ++k, ++idx1, ++idx2) {
+        DType tmp = src[idx1];
+        dst[idx1] = src[idx2];
+        dst[idx2] = tmp;
       }
     }
   }
 }
 
-static void Flip(const nnvm::NodeAttrs &attrs,
-                  const OpContext &ctx,
-                  const std::vector<TBlob> &inputs,
-                  const std::vector<OpReqType> &req,
-                  const std::vector<TBlob> &outputs) {
-  const FlipParam &param = nnvm::get<FlipParam>(attrs.parsed);
-  CHECK(param.axis == 0 || param.axis == 1) << "flip axis must be 0 or 1.";
-  CheckIsImage(inputs[0]);
-  const TShape& ishape = inputs[0].shape_;
+void RandomHorizontalFlip(
+    const nnvm::NodeAttrs &attrs,
+    const OpContext &ctx,
+    const std::vector<TBlob> &inputs,
+    const std::vector<OpReqType> &req,
+    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
+  if (std::bernoulli_distribution()(prnd->GetRndEngine())) return;
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    FlipImpl(inputs[0].shape_, inputs[0].dptr<DType>(),
+             outputs[0].dptr<DType>(), 1);
+  });
+}
+
+void RandomVerticalFlip(
+    const nnvm::NodeAttrs &attrs,
+    const OpContext &ctx,
+    const std::vector<TBlob> &inputs,
+    const std::vector<OpReqType> &req,
+    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
+  if (std::bernoulli_distribution()(prnd->GetRndEngine())) return;
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    FlipImpl(ishape, inputs[0].dptr<DType>(), outputs[0].dptr<DType>(), param.axis);
+    FlipImpl(inputs[0].shape_, inputs[0].dptr<DType>(),
+             outputs[0].dptr<DType>(), 0);
   });
 }
 
-struct RandomBrightnessParam : public dmlc::Parameter<RandomBrightnessParam> {
-  float max_brightness;
-  DMLC_DECLARE_PARAMETER(RandomBrightnessParam) {
-    DMLC_DECLARE_FIELD(max_brightness)
+struct RandomEnhanceParam : public dmlc::Parameter<RandomEnhanceParam> {
+  float min_factor;
+  float max_factor;
+  DMLC_DECLARE_PARAMETER(RandomEnhanceParam) {
+    DMLC_DECLARE_FIELD(min_factor)
+    .set_lower_bound(0.0)
+    .describe("Minimum factor.");
+    DMLC_DECLARE_FIELD(max_factor)
     .set_lower_bound(0.0)
-    .describe("Max Brightness.");
+    .describe("Maximum factor.");
   }
 };
 
-static void RandomBrightness(const nnvm::NodeAttrs &attrs,
-                             const OpContext &ctx,
-                             const std::vector<TBlob> &inputs,
-                             const std::vector<OpReqType> &req,
-                             const std::vector<TBlob> &outputs) {
+inline void AdjustBrightnessImpl(const float& alpha_b,
+                                 const OpContext &ctx,
+                                 const std::vector<TBlob> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  const RandomBrightnessParam &param = nnvm::get<RandomBrightnessParam>(attrs.parsed);
-
   int length = inputs[0].Size();
 
-  uint8_t* output = outputs[0].dptr<uint8_t>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
+    for (int l = 0; l < length; ++l) {
+      float val = static_cast<float>(input[l]) * alpha_b;
+      output[l] = saturate_cast<DType>(val);
+    }
+  });
+}
+
+void RandomBrightness(const nnvm::NodeAttrs &attrs,
+                      const OpContext &ctx,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
+
 
   Stream<cpu> *s = ctx.get_stream<cpu>();
   Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
-  float alpha_b = 1.0 + std::uniform_real_distribution<float>(
-      -param.max_brightness, param.max_brightness)(prnd->GetRndEngine());
+  float alpha_b = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
 
-  for (int l = 0; l < length; ++l) {
-    float val = static_cast<float>(input[l]) * alpha_b;
-    val = std::min(std::max(val, 0.f), 255.f);
-    output[l] = static_cast<uint8_t>(val);
-  }
+  AdjustBrightnessImpl(alpha_b, ctx, inputs, req, outputs);
 }
 
+inline void AdjustContrastImpl(const float& alpha_c,
+                               const OpContext &ctx,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  static const float coef[] = { 0.299f, 0.587f, 0.114f };
 
-struct RandomContrastParam : public dmlc::Parameter<RandomContrastParam> {
-  float max_contrast;
-  DMLC_DECLARE_PARAMETER(RandomContrastParam) {
-    DMLC_DECLARE_FIELD(max_contrast)
-    .set_lower_bound(0.0)
-    .describe("Max Contrast.");
-  }
-};
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  int nchannels = inputs[0].shape_[2];
 
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
+
+    float sum = 0.f;
+    if (nchannels > 1) {
+      for (int l = 0; l < length; ++l) {
+        for (int c = 0; c < 3; ++c) sum += input[l*3 + c] * coef[c];
+      }
+    } else {
+      for (int l = 0; l < length; ++l) sum += input[l];
+    }
+    float gray_mean = sum / static_cast<float>(length);
+    float beta = (1 - alpha_c) * gray_mean;
 
-static void RandomContrast(const nnvm::NodeAttrs &attrs,
+    for (int l = 0; l < length * nchannels; ++l) {
+      float val = input[l] * alpha_c + beta;
+      output[l] = saturate_cast<DType>(val);
+    }
+  });
+}
+
+inline void RandomContrast(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &outputs) {
   using namespace mshadow;
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
+
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+  float alpha_c = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
+
+  AdjustContrastImpl(alpha_c, ctx, inputs, req, outputs);
+}
+
+inline void AdjustSaturationImpl(const float& alpha_s,
+                                 const OpContext &ctx,
+                                 const std::vector<TBlob> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<TBlob> &outputs) {
   static const float coef[] = { 0.299f, 0.587f, 0.114f };
-  const RandomContrastParam &param = nnvm::get<RandomContrastParam>(attrs.parsed);
 
   int length = inputs[0].shape_[0] * inputs[0].shape_[1];
   int nchannels = inputs[0].shape_[2];
 
-  uint8_t* output = outputs[0].dptr<uint8_t>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  float alpha_o = 1.f - alpha_s;
 
-  Stream<cpu> *s = ctx.get_stream<cpu>();
-  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
-  float alpha_c = 1.0 + std::uniform_real_distribution<float>(
-    -param.max_contrast, param.max_contrast)(prnd->GetRndEngine());
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
 
-  float sum = 0.f;
-  if (nchannels > 1) {
-    for (int l = 0; l < length; ++l) {
-      for (int c = 0; c < nchannels; ++c) sum += input[l*nchannels + c] * coef[c];
+    if (nchannels == 1) {
+      for (int l = 0; l < length; ++l) output[l] = input[l];
+      return;
     }
-  } else {
-    for (int l = 0; l < length; ++l) sum += input[l];
-  }
-  float gray_mean = sum / static_cast<float>(length);
-  float beta = (1 - alpha_c) * gray_mean;
 
-  for (int l = 0; l < length * nchannels; ++l) {
-    float val = input[l] * alpha_c + beta;
-    val = std::min(std::max(val, 0.f), 255.f);
-    output[l] = static_cast<uint8_t>(val);
-  }
+    for (int l = 0; l < length; ++l) {
+      float gray = 0.f;
+      for (int c = 0; c < 3; ++c) {
+        gray = input[l*3 + c] * coef[c];
+      }
+      gray *= alpha_o;
+      for (int c = 0; c < 3; ++c) {
+        float val = gray + input[l*3 + c] * alpha_s;
+        output[l*3 + c] = saturate_cast<DType>(val);
+      }
+    }
+  });
 }
 
-struct RandomSaturationParam : public dmlc::Parameter<RandomSaturationParam> {
-  float max_saturation;
-  DMLC_DECLARE_PARAMETER(RandomSaturationParam) {
-    DMLC_DECLARE_FIELD(max_saturation)
-    .set_default(0.0)
-    .describe("Max Saturation.");
-  }
-};
-
-static void RandomSaturation(const nnvm::NodeAttrs &attrs,
+inline void RandomSaturation(const nnvm::NodeAttrs &attrs,
                              const OpContext &ctx,
                              const std::vector<TBlob> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  const RandomSaturationParam &param = nnvm::get<RandomSaturationParam>(attrs.parsed);
-  static const float coef[] = { 0.299f, 0.587f, 0.114f };
-
-  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
-  int nchannels = inputs[0].shape_[2];
-
-  uint8_t* output = outputs[0].dptr<uint8_t>();
-  uint8_t* input = inputs[0].dptr<uint8_t>();
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
 
   Stream<cpu> *s = ctx.get_stream<cpu>();
   Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
-  float alpha_s = 1.f + std::uniform_real_distribution<float>(
-    -param.max_saturation, param.max_saturation)(prnd->GetRndEngine());
-  float alpha_o = 1.f - alpha_s;
-
-  if (nchannels == 1) {
-    for (int l = 0; l < length * nchannels; ++l) output[l] = input[l];
-    return;
-  }
+  float alpha_s = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
 
-  for (int l = 0; l < length; ++l) {
-    float gray = 0.f;
-    for (int c = 0; c < nchannels; ++c) {
-      gray = input[l*nchannels + c] * coef[c];
-    }
-    gray *= alpha_o;
-    for (int c = 0; c < nchannels; ++c) {
-      float val = gray + input[l*nchannels + c] * alpha_s;
-      val = std::min(std::max(val, 0.f), 255.f);
-      output[l*nchannels + c] = static_cast<uint8_t>(val);
-    }
-  }
+  AdjustSaturationImpl(alpha_s, ctx, inputs, req, outputs);
 }
 
-struct RandomHueParam : public dmlc::Parameter<RandomHueParam> {
-  float max_hue;
-  DMLC_DECLARE_PARAMETER(RandomHueParam) {
-    DMLC_DECLARE_FIELD(max_hue)
-    .set_default(0.0)
-    .describe("Max Hue.");
-  }
-};
-
-template <typename DType> static
-void RGB2HLSConvert(const DType src_r,
-                    const DType src_g,
-                    const DType src_b,
-                    DType *dst_h,
-                    DType *dst_l,
-                    DType *dst_s
-                   ) {
-  DType b = src_b, g = src_g, r = src_r;
-  DType h = 0.f, s = 0.f, l;
-  DType vmin;
-  DType vmax;
-  DType diff;
+void RGB2HLSConvert(const float& src_r,
+                    const float& src_g,
+                    const float& src_b,
+                    float *dst_h,
+                    float *dst_l,
+                    float *dst_s) {
+  float b = src_b / 255.f, g = src_g / 255.f, r = src_r / 255.f;
+  float h = 0.f, s = 0.f, l;
+  float vmin;
+  float vmax;
+  float diff;
 
   vmax = vmin = r;
   vmax = fmax(vmax, g);
@@ -372,7 +400,7 @@ void RGB2HLSConvert(const DType src_r,
   diff = vmax - vmin;
   l = (vmax + vmin) * 0.5f;
 
-  if (diff > std::numeric_limits<DType>::epsilon()) {
+  if (diff > std::numeric_limits<float>::epsilon()) {
     s = (l < 0.5f) * diff / (vmax + vmin);
     s += (l >= 0.5f) * diff / (2.0f - vmax - vmin);
 
@@ -389,23 +417,20 @@ void RGB2HLSConvert(const DType src_r,
   *dst_s = s;
 }
 
-
-static  int c_HlsSectorData[6][3] = {
-  { 1, 3, 0 },
-  { 1, 0, 2 },
-  { 3, 0, 1 },
-  { 0, 2, 1 },
-  { 0, 1, 3 },
-  { 2, 1, 0 }
-};
-
-template <typename DType>  static  void HLS2RGBConvert(const DType src_h,
-    const DType src_l,
-    const DType src_s,
-    DType *dst_r,
-    DType *dst_g,
-    DType *dst_b) {
-
+void HLS2RGBConvert(const float& src_h,
+                    const float& src_l,
+                    const float& src_s,
+                    float *dst_r,
+                    float *dst_g,
+                    float *dst_b) {
+  static const int c_HlsSectorData[6][3] = {
+    { 1, 3, 0 },
+    { 1, 0, 2 },
+    { 3, 0, 1 },
+    { 0, 2, 1 },
+    { 0, 1, 3 },
+    { 2, 1, 0 }
+  };
 
   float h = src_h, l = src_l, s = src_s;
   float b = l, g = l, r = l;
@@ -415,6 +440,8 @@ template <typename DType>  static  void HLS2RGBConvert(const DType src_h,
     p2 += (l > 0.5f) * (l + s - l * s);
     float p1 = 2 * l - p2;
 
+    h *= 1.f / 60.f;
+
     if (h < 0) {
       do { h += 6; } while (h < 0);
     } else if (h >= 6) {
@@ -436,177 +463,202 @@ template <typename DType>  static  void HLS2RGBConvert(const DType src_h,
     r = tab[c_HlsSectorData[sector][2]];
   }
 
-  *dst_b = b;
-  *dst_g = g;
-  *dst_r = r;
+  *dst_b = b * 255.f;
+  *dst_g = g * 255.f;
+  *dst_r = r * 255.f;
 }
 
-template<typename xpu, typename DType>
-static  void RandomHueKernal(const TBlob &input,
-                             const TBlob &output,
-                             Stream<xpu> *s,
-                             int hight,
-                             int weight,
-                             DType alpha) {
-  auto input_3d = input.get<xpu, 3, DType>(s);
-  auto output_3d = output.get<xpu, 3, DType>(s);
-  for (int h_index = 0; h_index < hight; ++h_index) {
-    for (int w_index = 0; w_index < weight; ++w_index) {
-      DType h;
-      DType l;
-      DType s;
-      RGB2HLSConvert(input_3d[0][h_index][w_index],
-                     input_3d[1][h_index][w_index],
-                     input_3d[2][h_index][w_index],
-                     &h, &l, &s);
-      h += alpha;
-      h = std::max(DType(0), std::min(DType(180), h));
-
-      HLS2RGBConvert(
-        h, l, s,
-        &output_3d[0][h_index][w_index],
-        &output_3d[1][h_index][w_index],
-        &output_3d[2][h_index][w_index]);
+void AdjustHueImpl(float alpha,
+                   const OpContext &ctx,
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<OpReqType> &req,
+                   const std::vector<TBlob> &outputs) {
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  if (inputs[0].shape_[2] == 1) return;
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* input = inputs[0].dptr<DType>();
+    DType* output = outputs[0].dptr<DType>();
+
+    for (int i = 0; i < length; ++i) {
+      float h, l, s;
+      float r = static_cast<float>(*(input++));
+      float g = static_cast<float>(*(input++));
+      float b = static_cast<float>(*(input++));
+      RGB2HLSConvert(r, g, b, &h, &l, &s);
+      h += alpha * 360.f;
+      HLS2RGBConvert(h, l, s, &r, &g, &b);
+      *(output++) = saturate_cast<DType>(r);
+      *(output++) = saturate_cast<DType>(g);
+      *(output++) = saturate_cast<DType>(b);
     }
-  }
+  });
 }
 
-template<typename xpu>
-static void RandomHue(const nnvm::NodeAttrs &attrs,
-                      const OpContext &ctx,
-                      const std::vector<TBlob> &inputs,
-                      const std::vector<OpReqType> &req,
-                      const std::vector<TBlob> &outputs) {
+void RandomHue(const nnvm::NodeAttrs &attrs,
+               const OpContext &ctx,
+               const std::vector<TBlob> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<TBlob> &outputs) {
   using namespace mshadow;
-  auto input = inputs[0];
-  auto output = outputs[0];
-  int channel = input.shape_[0];
-  int hight = input.shape_[1];
-  int weight = input.shape_[2];
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  Random<xpu> *prnd = ctx.requested[kRandom].get_random<xpu, real_t>(s);
-
-  const RandomHueParam &param = nnvm::get<RandomHueParam>(attrs.parsed);
-  float alpha =  std::uniform_real_distribution<float>(
-    -param.max_hue, param.max_hue)(prnd->GetRndEngine());
-  auto output_float = output.get<xpu, 3, float>(s);
-
-  MSHADOW_TYPE_SWITCH(input.type_flag_, DType, {
-    RandomHueKernal<xpu, DType>(input, output, s, hight, weight, alpha);
-  });
+  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+  float alpha = std::uniform_real_distribution<float>(
+      param.min_factor, param.max_factor)(prnd->GetRndEngine());
+
+  AdjustHueImpl(alpha, ctx, inputs, req, outputs);
 }
 
-static void RandomColorJitter(const nnvm::NodeAttrs &attrs,
-                              const OpContext &ctx,
-                              const std::vector<TBlob> &inputs,
-                              const std::vector<OpReqType> &req,
-                              const std::vector<TBlob> &outputs) {
+struct RandomColorJitterParam : public dmlc::Parameter<RandomColorJitterParam> {
+  float brightness;
+  float contrast;
+  float saturation;
+  float hue;
+  DMLC_DECLARE_PARAMETER(RandomColorJitterParam) {
+    DMLC_DECLARE_FIELD(brightness)
+    .describe("How much to jitter brightness.");
+    DMLC_DECLARE_FIELD(contrast)
+    .describe("How much to jitter contrast.");
+    DMLC_DECLARE_FIELD(saturation)
+    .describe("How much to jitter saturation.");
+    DMLC_DECLARE_FIELD(hue)
+    .describe("How much to jitter hue.");
+  }
+};
+
+void RandomColorJitter(const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const RandomColorJitterParam &param = nnvm::get<RandomColorJitterParam>(attrs.parsed);
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+
+  int order[4] = {0, 1, 2, 3};
+  std::shuffle(order, order + 4, prnd->GetRndEngine());
+  bool flag = false;
+
+  for (int i = 0; i < 4; ++i) {
+    switch (order[i]) {
+      case 0:
+        if (param.brightness > 0) {
+          float alpha_b = 1.0 + std::uniform_real_distribution<float>(
+              -param.brightness, param.brightness)(prnd->GetRndEngine());
+          AdjustBrightnessImpl(alpha_b, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+      case 1:
+        if (param.contrast > 0) {
+          float alpha_c = 1.0 + std::uniform_real_distribution<float>(
+              -param.contrast, param.contrast)(prnd->GetRndEngine());
+          AdjustContrastImpl(alpha_c, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+      case 2:
+        if (param.saturation > 0) {
+          float alpha_s = 1.f + std::uniform_real_distribution<float>(
+              -param.saturation, param.saturation)(prnd->GetRndEngine());
+          AdjustSaturationImpl(alpha_s, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+      case 3:
+        if (param.hue > 0) {
+          float alpha_h = std::uniform_real_distribution<float>(
+              -param.hue, param.hue)(prnd->GetRndEngine());
+          AdjustHueImpl(alpha_h, ctx, flag ? outputs : inputs, req, outputs);
+          flag = true;
+        }
+        break;
+    }
+  }
 }
 
 struct AdjustLightingParam : public dmlc::Parameter<AdjustLightingParam> {
-  nnvm::Tuple<float> alpha_rgb;
-  nnvm::Tuple<float> eigval;
-  nnvm::Tuple<float> eigvec;
+  nnvm::Tuple<float> alpha;
   DMLC_DECLARE_PARAMETER(AdjustLightingParam) {
-    DMLC_DECLARE_FIELD(alpha_rgb)
-    .set_default({0, 0, 0})
+    DMLC_DECLARE_FIELD(alpha)
     .describe("The lighting alphas for the R, G, B channels.");
-    DMLC_DECLARE_FIELD(eigval)
-    .describe("Eigen value.")
-    .set_default({ 55.46, 4.794, 1.148 });
-    DMLC_DECLARE_FIELD(eigvec)
-    .describe("Eigen vector.")
-    .set_default({ -0.5675,  0.7192,  0.4009,
-                   -0.5808, -0.0045, -0.8140,
-                   -0.5808, -0.0045, -0.8140 });
   }
 };
 
 struct RandomLightingParam : public dmlc::Parameter<RandomLightingParam> {
   float alpha_std;
-  nnvm::Tuple<float> eigval;
-  nnvm::Tuple<float> eigvec;
   DMLC_DECLARE_PARAMETER(RandomLightingParam) {
     DMLC_DECLARE_FIELD(alpha_std)
     .set_default(0.05)
     .describe("Level of the lighting noise.");
-    DMLC_DECLARE_FIELD(eigval)
-    .describe("Eigen value.")
-    .set_default({ 55.46, 4.794, 1.148 });
-    DMLC_DECLARE_FIELD(eigvec)
-    .describe("Eigen vector.")
-    .set_default({ -0.5675,  0.7192,  0.4009,
-                   -0.5808, -0.0045, -0.8140,
-                   -0.5808, -0.0045, -0.8140 });
   }
 };
 
-void AdjustLightingImpl(uint8_t* dst, const uint8_t* src,
-                        float alpha_r, float alpha_g, float alpha_b,
-                        const nnvm::Tuple<float> eigval, const nnvm::Tuple<float> eigvec,
-                        int H, int W) {
-    alpha_r *= eigval[0];
-    alpha_g *= eigval[1];
-    alpha_b *= eigval[2];
-    float pca_r = alpha_r * eigvec[0] + alpha_g * eigvec[1] + alpha_b * eigvec[2];
-    float pca_g = alpha_r * eigvec[3] + alpha_g * eigvec[4] + alpha_b * eigvec[5];
-    float pca_b = alpha_r * eigvec[6] + alpha_g * eigvec[7] + alpha_b * eigvec[8];
-    for (int i = 0; i < H * W; i++) {
-        int base_ind = 3 * i;
-        float in_r = static_cast<float>(src[base_ind]);
-        float in_g = static_cast<float>(src[base_ind + 1]);
-        float in_b = static_cast<float>(src[base_ind + 2]);
-        dst[base_ind] = std::min(255, std::max(0, static_cast<int>(in_r + pca_r)));
-        dst[base_ind + 1] = std::min(255, std::max(0, static_cast<int>(in_g + pca_g)));
-        dst[base_ind + 2] = std::min(255, std::max(0, static_cast<int>(in_b + pca_b)));
+void AdjustLightingImpl(const nnvm::Tuple<float>& alpha,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  static const float eig[3][3] = {
+      { 55.46 * -0.5675, 4.794 * 0.7192,  1.148 * 0.4009 },
+      { 55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140 },
+      { 55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203 }
+    };
+
+  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+  int channels = inputs[0].shape_[2];
+  if (channels == 1) return;
+
+  float pca_r = eig[0][0] * alpha[0] + eig[0][1] * alpha[1] + eig[0][2] * alpha[2];
+  float pca_g = eig[1][0] * alpha[0] + eig[1][1] * alpha[1] + eig[1][2] * alpha[2];
+  float pca_b = eig[2][0] * alpha[0] + eig[2][1] * alpha[1] + eig[2][2] * alpha[2];
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    DType* output = outputs[0].dptr<DType>();
+    DType* input = inputs[0].dptr<DType>();
+
+    for (int i = 0; i < length; i++) {
+      int base_ind = 3 * i;
+      float in_r = static_cast<float>(input[base_ind]);
+      float in_g = static_cast<float>(input[base_ind + 1]);
+      float in_b = static_cast<float>(input[base_ind + 2]);
+      output[base_ind] = saturate_cast<DType>(in_r + pca_r);
+      output[base_ind + 1] = saturate_cast<DType>(in_g + pca_g);
+      output[base_ind + 2] = saturate_cast<DType>(in_b + pca_b);
     }
+  });
 }
 
-static void AdjustLighting(const nnvm::NodeAttrs &attrs,
-                           const OpContext &ctx,
-                           const std::vector<TBlob> &inputs,
-                           const std::vector<OpReqType> &req,
-                           const std::vector<TBlob> &outputs) {
-    using namespace mshadow;
-    const AdjustLightingParam &param = nnvm::get<AdjustLightingParam>(attrs.parsed);
-    CHECK_EQ(param.eigval.ndim(), 3) << "There should be 3 numbers in the eigval.";
-    CHECK_EQ(param.eigvec.ndim(), 9) << "There should be 9 numbers in the eigvec.";
-    CHECK_EQ(inputs[0].ndim(), 3);
-    CHECK_EQ(inputs[0].size(2), 3);
-    int H = inputs[0].size(0);
-    int W = inputs[0].size(1);
-    AdjustLightingImpl(outputs[0].dptr<uint8_t>(), inputs[0].dptr<uint8_t>(),
-                       param.alpha_rgb[0], param.alpha_rgb[1], param.alpha_rgb[2],
-                       param.eigval, param.eigvec, H, W);
+void AdjustLighting(const nnvm::NodeAttrs &attrs,
+                    const OpContext &ctx,
+                    const std::vector<TBlob> &inputs,
+                    const std::vector<OpReqType> &req,
+                    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const AdjustLightingParam &param = nnvm::get<AdjustLightingParam>(attrs.parsed);
+  AdjustLightingImpl(param.alpha, ctx, inputs, req, outputs);
 }
 
-static void RandomLighting(const nnvm::NodeAttrs &attrs,
-                           const OpContext &ctx,
-                           const std::vector<TBlob> &inputs,
-                           const std::vector<OpReqType> &req,
-                           const std::vector<TBlob> &outputs) {
-    using namespace mshadow;
-    const RandomLightingParam &param = nnvm::get<RandomLightingParam>(attrs.parsed);
-    CHECK_EQ(param.eigval.ndim(), 3) << "There should be 3 numbers in the eigval.";
-    CHECK_EQ(param.eigvec.ndim(), 9) << "There should be 9 numbers in the eigvec.";
-    CHECK_EQ(inputs[0].ndim(), 3);
-    CHECK_EQ(inputs[0].size(2), 3);
-    int H = inputs[0].size(0);
-    int W = inputs[0].size(1);
-    Stream<cpu> *s = ctx.get_stream<cpu>();
-    Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
-    std::normal_distribution<float> dist(0, param.alpha_std);
-    float alpha_r = dist(prnd->GetRndEngine());
-    float alpha_g = dist(prnd->GetRndEngine());
-    float alpha_b = dist(prnd->GetRndEngine());
-    AdjustLightingImpl(outputs[0].dptr<uint8_t>(), inputs[0].dptr<uint8_t>(),
-                       alpha_r, alpha_g, alpha_b,
-                       param.eigval, param.eigvec, H, W);
+void RandomLighting(const nnvm::NodeAttrs &attrs,
+                    const OpContext &ctx,
+                    const std::vector<TBlob> &inputs,
+                    const std::vector<OpReqType> &req,
+                    const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const RandomLightingParam &param = nnvm::get<RandomLightingParam>(attrs.parsed);
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
+  std::normal_distribution<float> dist(0, param.alpha_std);
+  float alpha_r = dist(prnd->GetRndEngine());
+  float alpha_g = dist(prnd->GetRndEngine());
+  float alpha_b = dist(prnd->GetRndEngine());
+  AdjustLightingImpl({alpha_r, alpha_g, alpha_b}, ctx, inputs, req, outputs);
 }
 
-
-
-
+}  // namespace image
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 29edeedeaa..5a21bf8c60 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -30,6 +30,13 @@
 
 namespace mxnet {
 namespace op {
+namespace image {
+
+DMLC_REGISTER_PARAMETER(NormalizeParam);
+DMLC_REGISTER_PARAMETER(RandomEnhanceParam);
+DMLC_REGISTER_PARAMETER(AdjustLightingParam);
+DMLC_REGISTER_PARAMETER(RandomLightingParam);
+DMLC_REGISTER_PARAMETER(RandomColorJitterParam);
 
 NNVM_REGISTER_OP(_image_to_tensor)
 .describe(R"code()code" ADD_FILELINE)
@@ -42,13 +49,12 @@ NNVM_REGISTER_OP(_image_to_tensor)
 .add_argument("data", "NDArray-or-Symbol", "The input.");
 
 
-DMLC_REGISTER_PARAMETER(NormalizeParam);
 NNVM_REGISTER_OP(_image_normalize)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NormalizeParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", NormalizeShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -59,33 +65,31 @@ NNVM_REGISTER_OP(_image_normalize)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(NormalizeParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(FlipParam);
-NNVM_REGISTER_OP(_image_flip)
+
+NNVM_REGISTER_OP(_image_random_horizontal_flip)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<FlipParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                [](const NodeAttrs& attrs){
-                                  return std::vector<std::pair<int, int> >{{0, 0}};
-                                })
-.set_attr<FCompute>("FCompute<cpu>", Flip)
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", RandomHorizontalFlip)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
-.add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(FlipParam::__FIELDS__());
+.add_argument("data", "NDArray-or-Symbol", "The input.");
+
 
-DMLC_REGISTER_PARAMETER(RandomBrightnessParam);
 NNVM_REGISTER_OP(_image_random_brightness)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomBrightnessParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -94,18 +98,18 @@ NNVM_REGISTER_OP(_image_random_brightness)
 .set_attr<FCompute>("FCompute<cpu>", RandomBrightness)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomBrightnessParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
+
 
-DMLC_REGISTER_PARAMETER(RandomContrastParam);
 NNVM_REGISTER_OP(_image_random_contrast)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomContrastParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -114,18 +118,18 @@ NNVM_REGISTER_OP(_image_random_contrast)
 .set_attr<FCompute>("FCompute<cpu>", RandomContrast)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomContrastParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
+
 
-DMLC_REGISTER_PARAMETER(RandomSaturationParam);
 NNVM_REGISTER_OP(_image_random_saturation)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomSaturationParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -134,31 +138,44 @@ NNVM_REGISTER_OP(_image_random_saturation)
 .set_attr<FCompute>("FCompute<cpu>", RandomSaturation)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomSaturationParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(RandomHueParam);
 NNVM_REGISTER_OP(_image_random_hue)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<RandomHueParam>)
+.set_attr_parser(ParamParser<RandomEnhanceParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCompute>("FCompute<cpu>", RandomHue<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", RandomHue)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.")
-.add_arguments(RandomHueParam::__FIELDS__());
+.add_arguments(RandomEnhanceParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_image_random_color_jitter)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<RandomColorJitterParam>)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
+  return std::vector<ResourceRequest>{ResourceRequest::kRandom};
+})
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", RandomColorJitter)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(RandomColorJitterParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(AdjustLightingParam);
 NNVM_REGISTER_OP(_image_adjust_lighting)
 .describe(R"code(Adjust the lighting level of the input. Follow the AlexNet style.)code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<AdjustLightingParam>)
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -169,7 +186,7 @@ NNVM_REGISTER_OP(_image_adjust_lighting)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(AdjustLightingParam::__FIELDS__());
 
-DMLC_REGISTER_PARAMETER(RandomLightingParam);
+
 NNVM_REGISTER_OP(_image_random_lighting)
 .describe(R"code(Randomly add PCA noise. Follow the AlexNet style.)code" ADD_FILELINE)
 .set_num_inputs(1)
@@ -178,7 +195,7 @@ NNVM_REGISTER_OP(_image_random_lighting)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   return std::vector<ResourceRequest>{ResourceRequest::kRandom};
 })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ImageShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
@@ -189,5 +206,6 @@ NNVM_REGISTER_OP(_image_random_lighting)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(RandomLightingParam::__FIELDS__());
 
+}  // namespace image
 }  // namespace op
 }  // namespace mxnet


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services