You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/19 22:23:27 UTC

[incubator-mxnet] branch master updated: Add standard ResNet data augmentation for ImageRecordIter (#11027)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ccee176  Add standard ResNet data augmentation for ImageRecordIter (#11027)
ccee176 is described below

commit ccee17672b23fa864f5c2e67d6bcea5ccff2979e
Author: Tong He <he...@gmail.com>
AuthorDate: Tue Jun 19 15:23:19 2018 -0700

    Add standard ResNet data augmentation for ImageRecordIter (#11027)
    
    * add resnet augmentation
    
    * add test
    
    * fix scope
    
    * fix warning
    
    * fix lint
    
    * fix lint
    
    * add color jitter and pca noise
    
    * fix center crop
    
    * merge
    
    * fix lint
    
    * Trigger CI
    
    * fix
    
    * fix augmentation implementation
    
    * add checks for parameters
    
    * modify training script
    
    * fix compile error
    
    * Trigger CI
    
    * Trigger CI
    
    * modify error message
    
    * Trigger CI
    
    * Trigger CI
    
    * Trigger CI
    
    * improve script in example
    
    * fix script
    
    * clean up code
    
    * Trigger CI
    
    * make min_aspect_ratio optional; move rotation and padding before random resized crop
    
    * fix
    
    * Trigger CI
    
    * Trigger CI
    
    * Trigger CI
    
    * fix default values
    
    * Trigger CI
---
 example/image-classification/common/data.py    |  48 +++--
 example/image-classification/train_imagenet.py |   4 +-
 src/io/image_aug_default.cc                    | 241 +++++++++++++++++++++++--
 tests/python/train/test_resnet_aug.py          | 173 ++++++++++++++++++
 4 files changed, 435 insertions(+), 31 deletions(-)

diff --git a/example/image-classification/common/data.py b/example/image-classification/common/data.py
index 05f5ddc..bfaadb3 100755
--- a/example/image-classification/common/data.py
+++ b/example/image-classification/common/data.py
@@ -43,9 +43,9 @@ def add_data_args(parser):
 def add_data_aug_args(parser):
     aug = parser.add_argument_group(
         'Image augmentations', 'implemented in src/io/image_aug_default.cc')
-    aug.add_argument('--random-crop', type=int, default=1,
+    aug.add_argument('--random-crop', type=int, default=0,
                      help='if or not randomly crop the image')
-    aug.add_argument('--random-mirror', type=int, default=1,
+    aug.add_argument('--random-mirror', type=int, default=0,
                      help='if or not randomly flip horizontally')
     aug.add_argument('--max-random-h', type=int, default=0,
                      help='max change of hue, whose range is [0, 180]')
@@ -53,8 +53,13 @@ def add_data_aug_args(parser):
                      help='max change of saturation, whose range is [0, 255]')
     aug.add_argument('--max-random-l', type=int, default=0,
                      help='max change of intensity, whose range is [0, 255]')
+    aug.add_argument('--min-random-aspect-ratio', type=float, default=None,
+                     help='min value of aspect ratio, whose value is either None or a positive value.')
     aug.add_argument('--max-random-aspect-ratio', type=float, default=0,
-                     help='max change of aspect ratio, whose range is [0, 1]')
+                     help='max value of aspect ratio. If min_random_aspect_ratio is None, '
+                          'the aspect ratio range is [1-max_random_aspect_ratio, '
+                          '1+max_random_aspect_ratio], otherwise it is '
+                          '[min_random_aspect_ratio, max_random_aspect_ratio].')
     aug.add_argument('--max-random-rotate-angle', type=int, default=0,
                      help='max angle to rotate, whose range is [0, 360]')
     aug.add_argument('--max-random-shear-ratio', type=float, default=0,
@@ -63,16 +68,28 @@ def add_data_aug_args(parser):
                      help='max ratio to scale')
     aug.add_argument('--min-random-scale', type=float, default=1,
                      help='min ratio to scale, should >= img_size/input_shape. otherwise use --pad-size')
+    aug.add_argument('--max-random-area', type=float, default=1,
+                     help='max area to crop in random resized crop, whose range is [0, 1]')
+    aug.add_argument('--min-random-area', type=float, default=1,
+                     help='min area to crop in random resized crop, whose range is [0, 1]')
+    aug.add_argument('--brightness', type=float, default=0,
+                     help='brightness jittering, whose range is [0, 1]')
+    aug.add_argument('--contrast', type=float, default=0,
+                     help='contrast jittering, whose range is [0, 1]')
+    aug.add_argument('--saturation', type=float, default=0,
+                     help='saturation jittering, whose range is [0, 1]')
+    aug.add_argument('--pca-noise', type=float, default=0,
+                     help='pca noise, whose range is [0, 1]')
+    aug.add_argument('--random-resized-crop', type=int, default=0,
+                     help='whether to use random resized crop')
     return aug
 
-def set_data_aug_level(aug, level):
-    if level >= 1:
-        aug.set_defaults(random_crop=1, random_mirror=1)
-    if level >= 2:
-        aug.set_defaults(max_random_h=36, max_random_s=50, max_random_l=50)
-    if level >= 3:
-        aug.set_defaults(max_random_rotate_angle=10, max_random_shear_ratio=0.1, max_random_aspect_ratio=0.25)
-
+def set_resnet_aug(aug):
+    # standard data augmentation setting for resnet training
+    aug.set_defaults(random_crop=1, random_resized_crop=1)
+    aug.set_defaults(min_random_area=0.08)
+    aug.set_defaults(max_random_aspect_ratio=4./3., min_random_aspect_ratio=3./4.)
+    aug.set_defaults(brightness=0.4, contrast=0.4, saturation=0.4, pca_noise=0.1)
 
 class SyntheticDataIter(DataIter):
     def __init__(self, num_classes, data_shape, max_iter, dtype):
@@ -135,8 +152,16 @@ def get_rec_iter(args, kv=None):
         max_random_scale    = args.max_random_scale,
         pad                 = args.pad_size,
         fill_value          = 127,
+        random_resized_crop = args.random_resized_crop,
         min_random_scale    = args.min_random_scale,
         max_aspect_ratio    = args.max_random_aspect_ratio,
+        min_aspect_ratio    = args.min_random_aspect_ratio,
+        max_random_area     = args.max_random_area,
+        min_random_area     = args.min_random_area,
+        brightness          = args.brightness,
+        contrast            = args.contrast,
+        saturation          = args.saturation,
+        pca_noise           = args.pca_noise,
         random_h            = args.max_random_h,
         random_s            = args.max_random_s,
         random_l            = args.max_random_l,
@@ -156,6 +181,7 @@ def get_rec_iter(args, kv=None):
         mean_r              = rgb_mean[0],
         mean_g              = rgb_mean[1],
         mean_b              = rgb_mean[2],
+        resize              = 256,
         data_name           = 'data',
         label_name          = 'softmax_label',
         batch_size          = args.batch_size,
diff --git a/example/image-classification/train_imagenet.py b/example/image-classification/train_imagenet.py
index f465fbc..a90b6ae 100644
--- a/example/image-classification/train_imagenet.py
+++ b/example/image-classification/train_imagenet.py
@@ -30,8 +30,8 @@ if __name__ == '__main__':
     fit.add_fit_args(parser)
     data.add_data_args(parser)
     data.add_data_aug_args(parser)
-    # use a large aug level
-    data.set_data_aug_level(parser, 3)
+    # uncomment to set standard augmentation for resnet training
+    # data.set_resnet_aug(parser)
     parser.set_defaults(
         # network
         network          = 'resnet',
diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc
index 22af7d9..f7d08b9 100644
--- a/src/io/image_aug_default.cc
+++ b/src/io/image_aug_default.cc
@@ -46,10 +46,14 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
   int resize;
   /*! \brief whether we do random cropping */
   bool rand_crop;
+  /*! \brief whether we do random resized cropping */
+  bool random_resized_crop;
   /*! \brief [-max_rotate_angle, max_rotate_angle] */
   int max_rotate_angle;
   /*! \brief max aspect ratio */
   float max_aspect_ratio;
+  /*! \brief min aspect ratio */
+  dmlc::optional<float> min_aspect_ratio;
   /*! \brief random shear the image [-max_shear_ratio, max_shear_ratio] */
   float max_shear_ratio;
   /*! \brief max crop size */
@@ -58,12 +62,24 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
   int min_crop_size;
   /*! \brief max scale ratio */
   float max_random_scale;
-  /*! \brief min scale_ratio */
+  /*! \brief min scale ratio */
   float min_random_scale;
+  /*! \brief max area */
+  float max_random_area;
+  /*! \brief min area */
+  float min_random_area;
   /*! \brief min image size */
   float min_img_size;
   /*! \brief max image size */
   float max_img_size;
+  /*! \brief max random brightness */
+  float brightness;
+  /*! \brief max random contrast */
+  float contrast;
+  /*! \brief max random saturation */
+  float saturation;
+  /*! \brief pca noise level */
+  float pca_noise;
   /*! \brief max random in H channel */
   int random_h;
   /*! \brief max random in S channel */
@@ -87,33 +103,65 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
                   "before applying other augmentations.");
     DMLC_DECLARE_FIELD(rand_crop).set_default(false)
         .describe("If or not randomly crop the image");
+    DMLC_DECLARE_FIELD(random_resized_crop).set_default(false)
+        .describe("If or not perform random resized cropping "
+                  "on the image, as a standard preprocessing "
+                  "for resnet training on ImageNet data.");
     DMLC_DECLARE_FIELD(max_rotate_angle).set_default(0.0f)
         .describe("Rotate by a random degree in ``[-v, v]``");
     DMLC_DECLARE_FIELD(max_aspect_ratio).set_default(0.0f)
+        .describe("Change the aspect (namely width/height) to a random value. "
+                  "If min_aspect_ratio is None then the aspect ratio is sampled from "
+                  "``[1 - max_aspect_ratio, 1 + max_aspect_ratio]``, "
+                  "else it is in ``[min_aspect_ratio, max_aspect_ratio]``");
+    DMLC_DECLARE_FIELD(min_aspect_ratio).set_default(dmlc::optional<float>())
         .describe("Change the aspect (namely width/height) to a random value "
-                  "in ``[1 - max_aspect_ratio, 1 + max_aspect_ratio]``");
+                  "in ``[min_aspect_ratio, max_aspect_ratio]``");
     DMLC_DECLARE_FIELD(max_shear_ratio).set_default(0.0f)
         .describe("Apply a shear transformation (namely ``(x,y)->(x+my,y)``) "
                   "with ``m`` randomly chose from "
                   "``[-max_shear_ratio, max_shear_ratio]``");
     DMLC_DECLARE_FIELD(max_crop_size).set_default(-1)
         .describe("Crop both width and height into a random size in "
-                  "``[min_crop_size, max_crop_size]``");
+                  "``[min_crop_size, max_crop_size]``. "
+                  "Ignored if ``random_resized_crop`` is True.");
     DMLC_DECLARE_FIELD(min_crop_size).set_default(-1)
         .describe("Crop both width and height into a random size in "
-                  "``[min_crop_size, max_crop_size]``");
+                  "``[min_crop_size, max_crop_size]``. "
+                  "Ignored if ``random_resized_crop`` is True.");
     DMLC_DECLARE_FIELD(max_random_scale).set_default(1.0f)
         .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly"
-                  " chosen from ``[min_random_scale, max_random_scale]``");
+                  " chosen from ``[min_random_scale, max_random_scale]``. "
+                  "Ignored if ``random_resized_crop`` is True.");
     DMLC_DECLARE_FIELD(min_random_scale).set_default(1.0f)
         .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly"
-                  " chosen from ``[min_random_scale, max_random_scale]``");
+                  " chosen from ``[min_random_scale, max_random_scale]``. "
+                  "Ignored if ``random_resized_crop`` is True.");
+    DMLC_DECLARE_FIELD(max_random_area).set_default(1.0f)
+        .describe("Change the area (namely width * height) to a random value "
+                  "in ``[min_random_area, max_random_area]``. "
+                  "Ignored if ``random_resized_crop`` is False.");
+    DMLC_DECLARE_FIELD(min_random_area).set_default(1.0f)
+        .describe("Change the area (namely width * height) to a random value "
+                  "in ``[min_random_area, max_random_area]``. "
+                  "Ignored if ``random_resized_crop`` is False.");
     DMLC_DECLARE_FIELD(max_img_size).set_default(1e10f)
         .describe("Set the maximal width and height after all resize and"
                   " rotate argumentation  are applied");
     DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f)
         .describe("Set the minimal width and height after all resize and"
                   " rotate argumentation  are applied");
+    DMLC_DECLARE_FIELD(brightness).set_default(0.0f)
+        .describe("Add a random value in ``[-brightness, brightness]`` to "
+                  "the brightness of image.");
+    DMLC_DECLARE_FIELD(contrast).set_default(0.0f)
+        .describe("Add a random value in ``[-contrast, contrast]`` to "
+                  "the contrast of image.");
+    DMLC_DECLARE_FIELD(saturation).set_default(0.0f)
+        .describe("Add a random value in ``[-saturation, saturation]`` to "
+                  "the saturation of image.");
+    DMLC_DECLARE_FIELD(pca_noise).set_default(0.0f)
+        .describe("Add PCA based noise to the image.");
     DMLC_DECLARE_FIELD(random_h).set_default(0)
         .describe("Add a random value in ``[-random_h, random_h]`` to "
                   "the H channel in HSL color space.");
@@ -197,6 +245,18 @@ class DefaultImageAugmenter : public ImageAugmenter {
   cv::Mat Process(const cv::Mat &src, std::vector<float> *label,
                   common::RANDOM_ENGINE *prnd) override {
     using mshadow::index_t;
+    bool is_cropped = false;
+
+    float max_aspect_ratio = 1.0f;
+    float min_aspect_ratio = 1.0f;
+    if (param_.min_aspect_ratio.has_value()) {
+      max_aspect_ratio = param_.max_aspect_ratio;
+      min_aspect_ratio = param_.min_aspect_ratio.value();
+    } else {
+      max_aspect_ratio = 1 + param_.max_aspect_ratio;
+      min_aspect_ratio = 1 - param_.max_aspect_ratio;
+    }
+
     cv::Mat res;
     if (param_.resize != -1) {
       int new_height, new_width;
@@ -220,8 +280,9 @@ class DefaultImageAugmenter : public ImageAugmenter {
 
     // normal augmentation by affine transformation.
     if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f
-        || param_.rotate > 0 || rotate_list_.size() > 0 || param_.max_random_scale != 1.0
-        || param_.min_random_scale != 1.0 || param_.max_aspect_ratio != 0.0f
+        || param_.rotate > 0 || rotate_list_.size() > 0
+        || param_.max_random_scale != 1.0f || param_.min_random_scale != 1.0
+        || min_aspect_ratio != 1.0f || max_aspect_ratio != 1.0f
         || param_.max_img_size != 1e10f || param_.min_img_size != 0.0f) {
       std::uniform_real_distribution<float> rand_uniform(0, 1);
       // shear
@@ -236,11 +297,17 @@ class DefaultImageAugmenter : public ImageAugmenter {
       float a = cos(angle / 180.0 * M_PI);
       float b = sin(angle / 180.0 * M_PI);
       // scale
-      float scale = rand_uniform(*prnd) *
-          (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale;
+      float scale = 1.0f;
+      if (!param_.random_resized_crop) {
+        scale = rand_uniform(*prnd) *
+            (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale;
+      }
       // aspect ratio
-      float ratio = rand_uniform(*prnd) *
-          param_.max_aspect_ratio * 2 - param_.max_aspect_ratio + 1;
+      float ratio = 1.0f;
+      if (!param_.random_resized_crop) {
+        ratio = rand_uniform(*prnd) *
+            (max_aspect_ratio - min_aspect_ratio) + min_aspect_ratio;
+      }
       float hs = 2 * scale / (1 + ratio);
       float ws = ratio * hs;
       // new width and height
@@ -276,8 +343,59 @@ class DefaultImageAugmenter : public ImageAugmenter {
                          cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value));
     }
 
-    // crop logic
-    if (param_.max_crop_size != -1 || param_.min_crop_size != -1) {
+    if (param_.random_resized_crop) {
+      // random resize crop
+      CHECK(param_.min_random_scale == 1.0f &&
+        param_.max_random_scale == 1.0f &&
+        param_.min_crop_size == -1 &&
+        param_.max_crop_size == -1 &&
+        !param_.rand_crop) <<
+        "\nSetting random_resized_crop to true conflicts with "
+        "min_random_scale, max_random_scale, "
+        "min_crop_size, max_crop_size, "
+        "and rand_crop.";
+
+      if (param_.max_random_area != 1.0f || param_.min_random_area != 1.0f
+          || max_aspect_ratio != 1.0f || min_aspect_ratio != 1.0f) {
+            CHECK(min_aspect_ratio > 0.0f);
+            CHECK(param_.min_random_area <= param_.max_random_area);
+            CHECK(min_aspect_ratio <= max_aspect_ratio);
+            std::uniform_real_distribution<float> rand_uniform_area(param_.min_random_area,
+                                                                    param_.max_random_area);
+            std::uniform_real_distribution<float> rand_uniform_ratio(min_aspect_ratio,
+                                                                     max_aspect_ratio);
+            std::uniform_real_distribution<float> rand_uniform(0, 1);
+            float area = res.rows * res.cols;
+            for (int i = 0; i < 10; ++i) {
+              float rand_area = rand_uniform_area(*prnd);
+              float ratio = rand_uniform_ratio(*prnd);
+              float target_area = area * rand_area;
+              int y_area = std::round(std::sqrt(target_area / ratio));
+              int x_area = std::round(std::sqrt(target_area * ratio));
+              if (rand_uniform(*prnd) > 0.5) {
+                // randomly transpose the sampled crop shape
+                std::swap(y_area, x_area);
+              }
+              // reject empty crops (cv::resize cannot take an empty ROI) and ones that do not fit
+              if (y_area > 0 && x_area > 0 && y_area <= res.rows && x_area <= res.cols) {
+                index_t rand_y_area =
+                    std::uniform_int_distribution<index_t>(0, res.rows - y_area)(*prnd);
+                index_t rand_x_area =
+                    std::uniform_int_distribution<index_t>(0, res.cols - x_area)(*prnd);
+                cv::Rect roi(rand_x_area, rand_y_area, x_area, y_area);
+                int interpolation_method = GetInterMethod(param_.inter_method, x_area, y_area,
+                                                          param_.data_shape[2],
+                                                          param_.data_shape[1], prnd);
+                cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]),
+                           0, 0, interpolation_method);
+                is_cropped = true;
+                break;
+              }
+            }
+      }
+    } else if (!param_.random_resized_crop &&
+        (param_.max_crop_size != -1 || param_.min_crop_size != -1)) {
+      // random_crop
       CHECK(res.cols >= param_.max_crop_size && res.rows >= \
               param_.max_crop_size && param_.max_crop_size >= param_.min_crop_size)
           << "input image size smaller than max_crop_size";
@@ -296,7 +414,28 @@ class DefaultImageAugmenter : public ImageAugmenter {
                                                 param_.data_shape[2], param_.data_shape[1], prnd);
       cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1])
                 , 0, 0, interpolation_method);
-    } else {
+      is_cropped = true;
+    }
+
+    if (!is_cropped) {
+      // center crop
+      int interpolation_method = GetInterMethod(param_.inter_method, res.cols, res.rows,
+                                                param_.data_shape[2],
+                                                param_.data_shape[1], prnd);
+      if (res.rows < param_.data_shape[1]) {
+        index_t new_cols = static_cast<index_t>(static_cast<float>(param_.data_shape[1]) /
+                                                static_cast<float>(res.rows) *
+                                                static_cast<float>(res.cols));
+        cv::resize(res, res, cv::Size(new_cols, param_.data_shape[1]),
+                   0, 0, interpolation_method);
+      }
+      if (res.cols < param_.data_shape[2]) {
+        index_t new_rows = static_cast<index_t>(static_cast<float>(param_.data_shape[2]) /
+                                                static_cast<float>(res.cols) *
+                                                static_cast<float>(res.rows));
+        cv::resize(res, res, cv::Size(param_.data_shape[2], new_rows),
+                   0, 0, interpolation_method);
+      }
       CHECK(static_cast<index_t>(res.rows) >= param_.data_shape[1]
             && static_cast<index_t>(res.cols) >= param_.data_shape[2])
           << "input image size smaller than input shape";
@@ -312,13 +451,48 @@ class DefaultImageAugmenter : public ImageAugmenter {
       res = res(roi);
     }
 
+    // color jitter: brightness, contrast and saturation applied in a random order
+    if (param_.brightness > 0.0f || param_.contrast > 0.0f || param_.saturation > 0.0f) {
+      // each factor is drawn uniformly from [1 - x, 1 + x]
+      float alpha_b = 1.0 + std::uniform_real_distribution<float>(-param_.brightness,
+                                                                  param_.brightness)(*prnd);
+      float alpha_c = 1.0 + std::uniform_real_distribution<float>(-param_.contrast,
+                                                                  param_.contrast)(*prnd);
+      float alpha_s = 1.0 + std::uniform_real_distribution<float>(-param_.saturation,
+                                                                  param_.saturation)(*prnd);
+      int rand_order[3] = {0, 1, 2};
+      std::shuffle(std::begin(rand_order), std::end(rand_order), *prnd);
+      for (int i = 0; i < 3; ++i) {
+        if (rand_order[i] == 0) {
+          // brightness: scale every channel by alpha_b
+          res.convertTo(res, -1, alpha_b, 0);
+        }
+        if (rand_order[i] == 1) {
+          // contrast: blend with the mean gray level (res is BGR at this point)
+          cvtColor(res, temp_, CV_BGR2GRAY);
+          float gray_mean = cv::mean(temp_)[0];
+          res.convertTo(res, -1, alpha_c, (1 - alpha_c) * gray_mean);
+        }
+        if (rand_order[i] == 2) {
+          // saturation: blend with the per-pixel grayscale image
+          cvtColor(res, temp_, CV_BGR2GRAY);
+          cvtColor(temp_, temp_, CV_GRAY2BGR);
+          cv::addWeighted(res, alpha_s, temp_, 1 - alpha_s, 0.0, res);
+        }
+      }
+    }
+
     // color space augmentation
     if (param_.random_h != 0 || param_.random_s != 0 || param_.random_l != 0) {
       std::uniform_real_distribution<float> rand_uniform(0, 1);
       cvtColor(res, res, CV_BGR2HLS);
-      int h = rand_uniform(*prnd) * param_.random_h * 2 - param_.random_h;
-      int s = rand_uniform(*prnd) * param_.random_s * 2 - param_.random_s;
-      int l = rand_uniform(*prnd) * param_.random_l * 2 - param_.random_l;
+      // use an approximation of gaussian distribution to reduce extreme value
+      float rh = rand_uniform(*prnd); rh += 4 * rand_uniform(*prnd); rh = rh / 5;
+      float rs = rand_uniform(*prnd); rs += 4 * rand_uniform(*prnd); rs = rs / 5;
+      float rl = rand_uniform(*prnd); rl += 4 * rand_uniform(*prnd); rl = rl / 5;
+      int h = rh * param_.random_h * 2 - param_.random_h;
+      int s = rs * param_.random_s * 2 - param_.random_s;
+      int l = rl * param_.random_l * 2 - param_.random_l;
       int temp[3] = {h, l, s};
       int limit[3] = {180, 255, 255};
       for (int i = 0; i < res.rows; ++i) {
@@ -333,14 +507,45 @@ class DefaultImageAugmenter : public ImageAugmenter {
       }
       cvtColor(res, res, CV_HLS2BGR);
     }
+
+    // pca noise
+    if (param_.pca_noise > 0.0f) {
+      std::normal_distribution<float> rand_normal(0, param_.pca_noise);
+      float pca_alpha_r = rand_normal(*prnd);
+      float pca_alpha_g = rand_normal(*prnd);
+      float pca_alpha_b = rand_normal(*prnd);
+      float pca_r = eigvec[0][0] * pca_alpha_r + eigvec[0][1] * pca_alpha_g +
+           eigvec[0][2] * pca_alpha_b;
+      float pca_g = eigvec[1][0] * pca_alpha_r + eigvec[1][1] * pca_alpha_g +
+           eigvec[1][2] * pca_alpha_b;
+      float pca_b = eigvec[2][0] * pca_alpha_r + eigvec[2][1] * pca_alpha_g +
+           eigvec[2][2] * pca_alpha_b;
+      float pca[3] = { pca_b, pca_g, pca_r };
+      for (int i = 0; i < res.rows; ++i) {
+        for (int j = 0; j < res.cols; ++j) {
+          for (int k = 0; k < 3; ++k) {
+            int vp = res.at<cv::Vec3b>(i, j)[k];
+            vp += pca[k];
+            vp = std::max(0, std::min(255, vp));
+            res.at<cv::Vec3b>(i, j)[k] = vp;
+          }
+        }
+      }
+    }
     return res;
   }
 
+
  private:
   // temporal space
   cv::Mat temp_;
   // rotation param
   cv::Mat rotateM_;
+  // eigval and eigvec for adding pca noise
+  // store eigval * eigvec as eigvec
+  float eigvec[3][3] = { { 55.46f * -0.5675f, 4.794f * 0.7192f,  1.148f * 0.4009f },
+                         { 55.46f * -0.5808f, 4.794f * -0.0045f, 1.148f * -0.8140f },
+                         { 55.46f * -0.5836f, 4.794f * -0.6948f, 1.148f * 0.4203f } };
   // parameters
   DefaultImageAugmentParam param_;
   /*! \brief list of possible rotate angle */
diff --git a/tests/python/train/test_resnet_aug.py b/tests/python/train/test_resnet_aug.py
new file mode 100644
index 0000000..62c531b
--- /dev/null
+++ b/tests/python/train/test_resnet_aug.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+import sys
+sys.path.insert(0, '../../python')
+import mxnet as mx
+import numpy as np
+import os, pickle, gzip
+import logging
+from mxnet.test_utils import get_cifar10
+
+batch_size = 128
+
+# small mlp network
+def get_net():
+    data = mx.symbol.Variable('data')
+    float_data = mx.symbol.Cast(data=data, dtype="float32")
+    fc1 = mx.symbol.FullyConnected(float_data, name='fc1', num_hidden=128)
+    act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+    fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
+    act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
+    fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
+    softmax = mx.symbol.SoftmaxOutput(fc3, name="softmax")
+    return softmax
+
+# check data
+get_cifar10()
+
+def get_iterator(kv):
+    data_shape = (3, 28, 28)
+
+    train = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/train.rec",
+        mean_img    = "data/cifar/mean.bin",
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        random_resized_crop = True,
+        min_aspect_ratio = 0.75,
+        max_aspect_ratio = 1.33,
+        min_random_area = 0.08,
+        max_random_area = 1,
+        brightness  = 0.4,
+        contrast    = 0.4,
+        saturation  = 0.4,
+        pca_noise   = 0.1,
+        rand_mirror = True,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank)
+    train = mx.io.PrefetchingIter(train)
+
+    val = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/test.rec",
+        mean_img    = "data/cifar/mean.bin",
+        rand_crop   = False,
+        rand_mirror = False,
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank)
+
+    return (train, val)
+
+num_epoch = 1
+
+def run_cifar10(train, val, use_module):
+    train.reset()
+    val.reset()
+    devs = [mx.cpu(0)]
+    net = get_net()
+    mod = mx.mod.Module(net, context=devs)
+    optim_args = {'learning_rate': 0.001, 'wd': 0.00001, 'momentum': 0.9}
+    eval_metrics = ['accuracy']
+    if use_module:
+        executor = mx.mod.Module(net, context=devs)
+        executor.fit(
+            train,
+            eval_data=val,
+            optimizer_params=optim_args,
+            eval_metric=eval_metrics,
+            num_epoch=num_epoch,
+            arg_params=None,
+            aux_params=None,
+            begin_epoch=0,
+            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
+            epoch_end_callback=None)
+    else:
+        executor = mx.model.FeedForward.create(
+            net,
+            train,
+            ctx=devs,
+            eval_data=val,
+            eval_metric=eval_metrics,
+            num_epoch=num_epoch,
+            arg_params=None,
+            aux_params=None,
+            begin_epoch=0,
+            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
+            epoch_end_callback=None,
+            **optim_args)
+
+    ret = executor.score(val, eval_metrics)
+    if use_module:
+        ret = list(ret)
+        logging.info('final accuracy = %f', ret[0][1])
+        assert (ret[0][1] > 0.08)
+    else:
+        logging.info('final accuracy = %f', ret[0])
+        assert (ret[0] > 0.08)
+
+class CustomDataIter(mx.io.DataIter):
+    def __init__(self, data):
+        super(CustomDataIter, self).__init__()
+        self.data = data
+        self.batch_size = data.provide_data[0][1][0]
+
+        # use legacy tuple
+        self.provide_data = [(n, s) for n, s in data.provide_data]
+        self.provide_label = [(n, s) for n, s in data.provide_label]
+
+    def reset(self):
+        self.data.reset()
+
+    def next(self):
+        return self.data.next()
+
+    def iter_next(self):
+        return self.data.iter_next()
+
+    def getdata(self):
+        return self.data.getdata()
+
+    def getlabel(self):
+        return self.data.getlabel()
+
+    def getindex(self):
+        return self.data.getindex()
+
+    def getpad(self):
+        return self.data.getpad()
+
+def test_cifar10():
+    # print logging by default
+    logging.basicConfig(level=logging.DEBUG)
+    console = logging.StreamHandler()
+    console.setLevel(logging.DEBUG)
+    logging.getLogger('').addHandler(console)
+    kv = mx.kvstore.create("local")
+    # test float32 input
+    (train, val) = get_iterator(kv)
+    run_cifar10(train, val, use_module=False)
+    run_cifar10(train, val, use_module=True)
+
+    # test legecay tuple in provide_data and provide_label
+    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=False)
+    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=True)
+
+if __name__ == "__main__":
+    test_cifar10()