You are viewing a plain-text version of this content. The canonical link to it (labelled "here" in the original message) is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/19 22:23:21 UTC

[GitHub] piiswrong closed pull request #11027: Add standard ResNet data augmentation for ImageRecordIter

piiswrong closed pull request #11027: Add standard ResNet data augmentation for ImageRecordIter
URL: https://github.com/apache/incubator-mxnet/pull/11027
 
 
   

This is a pull request (PR) merged from a forked repository.
Because GitHub hides the original diff once the PR is merged, the diff
is reproduced below for the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; otherwise it would not be shown, since GitHub hides it after the merge:

diff --git a/example/image-classification/common/data.py b/example/image-classification/common/data.py
index 05f5ddc4506..bfaadb3ff6b 100755
--- a/example/image-classification/common/data.py
+++ b/example/image-classification/common/data.py
@@ -43,9 +43,9 @@ def add_data_args(parser):
 def add_data_aug_args(parser):
     aug = parser.add_argument_group(
         'Image augmentations', 'implemented in src/io/image_aug_default.cc')
-    aug.add_argument('--random-crop', type=int, default=1,
+    aug.add_argument('--random-crop', type=int, default=0,
                      help='if or not randomly crop the image')
-    aug.add_argument('--random-mirror', type=int, default=1,
+    aug.add_argument('--random-mirror', type=int, default=0,
                      help='if or not randomly flip horizontally')
     aug.add_argument('--max-random-h', type=int, default=0,
                      help='max change of hue, whose range is [0, 180]')
@@ -53,8 +53,13 @@ def add_data_aug_args(parser):
                      help='max change of saturation, whose range is [0, 255]')
     aug.add_argument('--max-random-l', type=int, default=0,
                      help='max change of intensity, whose range is [0, 255]')
+    aug.add_argument('--min-random-aspect-ratio', type=float, default=None,
+                     help='min value of aspect ratio, whose value is either None or a positive value.')
     aug.add_argument('--max-random-aspect-ratio', type=float, default=0,
-                     help='max change of aspect ratio, whose range is [0, 1]')
+                     help='max value of aspect ratio. If min_random_aspect_ratio is None, '
+                          'the aspect ratio range is [1-max_random_aspect_ratio, '
+                          '1+max_random_aspect_ratio], otherwise it is '
+                          '[min_random_aspect_ratio, max_random_aspect_ratio].')
     aug.add_argument('--max-random-rotate-angle', type=int, default=0,
                      help='max angle to rotate, whose range is [0, 360]')
     aug.add_argument('--max-random-shear-ratio', type=float, default=0,
@@ -63,16 +68,28 @@ def add_data_aug_args(parser):
                      help='max ratio to scale')
     aug.add_argument('--min-random-scale', type=float, default=1,
                      help='min ratio to scale, should >= img_size/input_shape. otherwise use --pad-size')
+    aug.add_argument('--max-random-area', type=float, default=1,
+                     help='max area to crop in random resized crop, whose range is [0, 1]')
+    aug.add_argument('--min-random-area', type=float, default=1,
+                     help='min area to crop in random resized crop, whose range is [0, 1]')
+    aug.add_argument('--brightness', type=float, default=0,
+                     help='brightness jittering, whose range is [0, 1]')
+    aug.add_argument('--contrast', type=float, default=0,
+                     help='contrast jittering, whose range is [0, 1]')
+    aug.add_argument('--saturation', type=float, default=0,
+                     help='saturation jittering, whose range is [0, 1]')
+    aug.add_argument('--pca-noise', type=float, default=0,
+                     help='pca noise, whose range is [0, 1]')
+    aug.add_argument('--random-resized-crop', type=int, default=0,
+                     help='whether to use random resized crop')
     return aug
 
-def set_data_aug_level(aug, level):
-    if level >= 1:
-        aug.set_defaults(random_crop=1, random_mirror=1)
-    if level >= 2:
-        aug.set_defaults(max_random_h=36, max_random_s=50, max_random_l=50)
-    if level >= 3:
-        aug.set_defaults(max_random_rotate_angle=10, max_random_shear_ratio=0.1, max_random_aspect_ratio=0.25)
-
+def set_resnet_aug(aug):
+    # standard data augmentation setting for resnet training
+    aug.set_defaults(random_crop=1, random_resized_crop=1)
+    aug.set_defaults(min_random_area=0.08)
+    aug.set_defaults(max_random_aspect_ratio=4./3., min_random_aspect_ratio=3./4.)
+    aug.set_defaults(brightness=0.4, contrast=0.4, saturation=0.4, pca_noise=0.1)
 
 class SyntheticDataIter(DataIter):
     def __init__(self, num_classes, data_shape, max_iter, dtype):
@@ -135,8 +152,16 @@ def get_rec_iter(args, kv=None):
         max_random_scale    = args.max_random_scale,
         pad                 = args.pad_size,
         fill_value          = 127,
+        random_resized_crop = args.random_resized_crop,
         min_random_scale    = args.min_random_scale,
         max_aspect_ratio    = args.max_random_aspect_ratio,
+        min_aspect_ratio    = args.min_random_aspect_ratio,
+        max_random_area     = args.max_random_area,
+        min_random_area     = args.min_random_area,
+        brightness          = args.brightness,
+        contrast            = args.contrast,
+        saturation          = args.saturation,
+        pca_noise           = args.pca_noise,
         random_h            = args.max_random_h,
         random_s            = args.max_random_s,
         random_l            = args.max_random_l,
@@ -156,6 +181,7 @@ def get_rec_iter(args, kv=None):
         mean_r              = rgb_mean[0],
         mean_g              = rgb_mean[1],
         mean_b              = rgb_mean[2],
+        resize              = 256,
         data_name           = 'data',
         label_name          = 'softmax_label',
         batch_size          = args.batch_size,
diff --git a/example/image-classification/train_imagenet.py b/example/image-classification/train_imagenet.py
index f465fbc5f46..a90b6aead23 100644
--- a/example/image-classification/train_imagenet.py
+++ b/example/image-classification/train_imagenet.py
@@ -30,8 +30,8 @@
     fit.add_fit_args(parser)
     data.add_data_args(parser)
     data.add_data_aug_args(parser)
-    # use a large aug level
-    data.set_data_aug_level(parser, 3)
+    # uncomment to set standard augmentation for resnet training
+    # data.set_resnet_aug(parser)
     parser.set_defaults(
         # network
         network          = 'resnet',
diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc
index 22af7d92750..f7d08b92f17 100644
--- a/src/io/image_aug_default.cc
+++ b/src/io/image_aug_default.cc
@@ -46,10 +46,14 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
   int resize;
   /*! \brief whether we do random cropping */
   bool rand_crop;
+  /*! \brief whether we do random resized cropping */
+  bool random_resized_crop;
   /*! \brief [-max_rotate_angle, max_rotate_angle] */
   int max_rotate_angle;
   /*! \brief max aspect ratio */
   float max_aspect_ratio;
+  /*! \brief min aspect ratio */
+  dmlc::optional<float> min_aspect_ratio;
   /*! \brief random shear the image [-max_shear_ratio, max_shear_ratio] */
   float max_shear_ratio;
   /*! \brief max crop size */
@@ -58,12 +62,24 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
   int min_crop_size;
   /*! \brief max scale ratio */
   float max_random_scale;
-  /*! \brief min scale_ratio */
+  /*! \brief min scale ratio */
   float min_random_scale;
+  /*! \brief max area */
+  float max_random_area;
+  /*! \brief min area */
+  float min_random_area;
   /*! \brief min image size */
   float min_img_size;
   /*! \brief max image size */
   float max_img_size;
+  /*! \brief max random brightness */
+  float brightness;
+  /*! \brief max random contrast */
+  float contrast;
+  /*! \brief max random saturation */
+  float saturation;
+  /*! \brief pca noise level */
+  float pca_noise;
   /*! \brief max random in H channel */
   int random_h;
   /*! \brief max random in S channel */
@@ -87,33 +103,65 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
                   "before applying other augmentations.");
     DMLC_DECLARE_FIELD(rand_crop).set_default(false)
         .describe("If or not randomly crop the image");
+    DMLC_DECLARE_FIELD(random_resized_crop).set_default(false)
+        .describe("If or not perform random resized cropping "
+                  "on the image, as a standard preprocessing "
+                  "for resnet training on ImageNet data.");
     DMLC_DECLARE_FIELD(max_rotate_angle).set_default(0.0f)
         .describe("Rotate by a random degree in ``[-v, v]``");
     DMLC_DECLARE_FIELD(max_aspect_ratio).set_default(0.0f)
+        .describe("Change the aspect (namely width/height) to a random value. "
+                  "If min_aspect_ratio is None then the aspect ratio ins sampled from "
+                  "[1 - max_aspect_ratio, 1 + max_aspect_ratio], "
+                  "else it is in ``[min_aspect_ratio, max_aspect_ratio]``");
+    DMLC_DECLARE_FIELD(min_aspect_ratio).set_default(dmlc::optional<float>())
         .describe("Change the aspect (namely width/height) to a random value "
-                  "in ``[1 - max_aspect_ratio, 1 + max_aspect_ratio]``");
+                  "in ``[min_aspect_ratio, max_aspect_ratio]``");
     DMLC_DECLARE_FIELD(max_shear_ratio).set_default(0.0f)
         .describe("Apply a shear transformation (namely ``(x,y)->(x+my,y)``) "
                   "with ``m`` randomly chose from "
                   "``[-max_shear_ratio, max_shear_ratio]``");
     DMLC_DECLARE_FIELD(max_crop_size).set_default(-1)
         .describe("Crop both width and height into a random size in "
-                  "``[min_crop_size, max_crop_size]``");
+                  "``[min_crop_size, max_crop_size].``"
+                  "Ignored if ``random_resized_crop`` is True.");
     DMLC_DECLARE_FIELD(min_crop_size).set_default(-1)
         .describe("Crop both width and height into a random size in "
-                  "``[min_crop_size, max_crop_size]``");
+                  "``[min_crop_size, max_crop_size].``"
+                  "Ignored if ``random_resized_crop`` is True.");
     DMLC_DECLARE_FIELD(max_random_scale).set_default(1.0f)
         .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly"
-                  " chosen from ``[min_random_scale, max_random_scale]``");
+                  " chosen from ``[min_random_scale, max_random_scale]``. "
+                  "Ignored if ``random_resized_crop`` is True.");
     DMLC_DECLARE_FIELD(min_random_scale).set_default(1.0f)
         .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly"
-                  " chosen from ``[min_random_scale, max_random_scale]``");
+                  " chosen from ``[min_random_scale, max_random_scale]``"
+                  "Ignored if ``random_resized_crop`` is True.");
+    DMLC_DECLARE_FIELD(max_random_area).set_default(1.0f)
+        .describe("Change the area (namely width * height) to a random value "
+                  "in ``[min_random_area, max_random_area]``. "
+                  "Ignored if ``random_resized_crop`` is False.");
+    DMLC_DECLARE_FIELD(min_random_area).set_default(1.0f)
+        .describe("Change the area (namely width * height) to a random value "
+                  "in ``[min_random_area, max_random_area]``. "
+                  "Ignored if ``random_resized_crop`` is False.");
     DMLC_DECLARE_FIELD(max_img_size).set_default(1e10f)
         .describe("Set the maximal width and height after all resize and"
                   " rotate argumentation  are applied");
     DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f)
         .describe("Set the minimal width and height after all resize and"
                   " rotate argumentation  are applied");
+    DMLC_DECLARE_FIELD(brightness).set_default(0.0f)
+        .describe("Add a random value in ``[-brightness, brightness]`` to "
+                  "the brightness of image.");
+    DMLC_DECLARE_FIELD(contrast).set_default(0.0f)
+        .describe("Add a random value in ``[-contrast, contrast]`` to "
+                  "the contrast of image.");
+    DMLC_DECLARE_FIELD(saturation).set_default(0.0f)
+        .describe("Add a random value in ``[-saturation, saturation]`` to "
+                  "the saturation of image.");
+        DMLC_DECLARE_FIELD(pca_noise).set_default(0.0f)
+                .describe("Add PCA based noise to the image.");
     DMLC_DECLARE_FIELD(random_h).set_default(0)
         .describe("Add a random value in ``[-random_h, random_h]`` to "
                   "the H channel in HSL color space.");
@@ -197,6 +245,18 @@ class DefaultImageAugmenter : public ImageAugmenter {
   cv::Mat Process(const cv::Mat &src, std::vector<float> *label,
                   common::RANDOM_ENGINE *prnd) override {
     using mshadow::index_t;
+    bool is_cropped = false;
+
+    float max_aspect_ratio = 1.0f;
+    float min_aspect_ratio = 1.0f;
+    if (param_.min_aspect_ratio.has_value()) {
+      max_aspect_ratio = param_.max_aspect_ratio;
+      min_aspect_ratio = param_.min_aspect_ratio.value();
+    } else {
+      max_aspect_ratio = 1 + param_.max_aspect_ratio;
+      min_aspect_ratio = 1 - param_.max_aspect_ratio;
+    }
+
     cv::Mat res;
     if (param_.resize != -1) {
       int new_height, new_width;
@@ -220,8 +280,9 @@ class DefaultImageAugmenter : public ImageAugmenter {
 
     // normal augmentation by affine transformation.
     if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f
-        || param_.rotate > 0 || rotate_list_.size() > 0 || param_.max_random_scale != 1.0
-        || param_.min_random_scale != 1.0 || param_.max_aspect_ratio != 0.0f
+        || param_.rotate > 0 || rotate_list_.size() > 0
+        || param_.max_random_scale != 1.0f || param_.min_random_scale != 1.0
+        || min_aspect_ratio != 1.0f || max_aspect_ratio != 1.0f
         || param_.max_img_size != 1e10f || param_.min_img_size != 0.0f) {
       std::uniform_real_distribution<float> rand_uniform(0, 1);
       // shear
@@ -236,11 +297,17 @@ class DefaultImageAugmenter : public ImageAugmenter {
       float a = cos(angle / 180.0 * M_PI);
       float b = sin(angle / 180.0 * M_PI);
       // scale
-      float scale = rand_uniform(*prnd) *
-          (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale;
+      float scale = 1.0f;
+      if (!param_.random_resized_crop) {
+        scale = rand_uniform(*prnd) *
+            (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale;
+      }
       // aspect ratio
-      float ratio = rand_uniform(*prnd) *
-          param_.max_aspect_ratio * 2 - param_.max_aspect_ratio + 1;
+      float ratio = 1.0f;
+      if (!param_.random_resized_crop) {
+        ratio = rand_uniform(*prnd) *
+            (max_aspect_ratio - min_aspect_ratio) + min_aspect_ratio;
+      }
       float hs = 2 * scale / (1 + ratio);
       float ws = ratio * hs;
       // new width and height
@@ -276,8 +343,59 @@ class DefaultImageAugmenter : public ImageAugmenter {
                          cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value));
     }
 
-    // crop logic
-    if (param_.max_crop_size != -1 || param_.min_crop_size != -1) {
+    if (param_.random_resized_crop) {
+      // random resize crop
+      CHECK(param_.min_random_scale == 1.0f &&
+        param_.max_random_scale == 1.0f &&
+        param_.min_crop_size == -1 &&
+        param_.max_crop_size == -1 &&
+        !param_.rand_crop) <<
+        "\nSetting random_resized_crop to true conflicts with "
+        "min_random_scale, max_random_scale, "
+        "min_crop_size, max_crop_size, "
+        "and rand_crop.";
+
+      if (param_.max_random_area != 1.0f || param_.min_random_area != 1.0f
+          || max_aspect_ratio != 1.0f || min_aspect_ratio != 1.0f) {
+            CHECK(min_aspect_ratio > 0.0f);
+            CHECK(param_.min_random_area <= param_.max_random_area);
+            CHECK(min_aspect_ratio <= max_aspect_ratio);
+            std::uniform_real_distribution<float> rand_uniform_area(param_.min_random_area,
+                                                                    param_.max_random_area);
+            std::uniform_real_distribution<float> rand_uniform_ratio(min_aspect_ratio,
+                                                                     max_aspect_ratio);
+            std::uniform_real_distribution<float> rand_uniform(0, 1);
+            float area = res.rows * res.cols;
+            for (int i = 0; i < 10; ++i) {
+              float rand_area = rand_uniform_area(*prnd);
+              float ratio = rand_uniform_ratio(*prnd);
+              float target_area = area * rand_area;
+              int y_area = std::round(std::sqrt(target_area / ratio));
+              int x_area = std::round(std::sqrt(target_area * ratio));
+              if (rand_uniform(*prnd) > 0.5) {
+                float temp_y_area = y_area;
+                y_area = x_area;
+                x_area = temp_y_area;
+              }
+              if (y_area <= res.rows && x_area <= res.cols) {
+                index_t rand_y_area =
+                    std::uniform_int_distribution<index_t>(0, res.rows - y_area)(*prnd);
+                index_t rand_x_area =
+                    std::uniform_int_distribution<index_t>(0, res.cols - x_area)(*prnd);
+                cv::Rect roi(rand_x_area, rand_y_area, x_area, y_area);
+                int interpolation_method = GetInterMethod(param_.inter_method, x_area, y_area,
+                                                          param_.data_shape[2],
+                                                          param_.data_shape[1], prnd);
+                cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]),
+                           0, 0, interpolation_method);
+                is_cropped = true;
+                break;
+              }
+            }
+      }
+    } else if (!param_.random_resized_crop &&
+        (param_.max_crop_size != -1 || param_.min_crop_size != -1)) {
+      // random_crop
       CHECK(res.cols >= param_.max_crop_size && res.rows >= \
               param_.max_crop_size && param_.max_crop_size >= param_.min_crop_size)
           << "input image size smaller than max_crop_size";
@@ -296,7 +414,28 @@ class DefaultImageAugmenter : public ImageAugmenter {
                                                 param_.data_shape[2], param_.data_shape[1], prnd);
       cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1])
                 , 0, 0, interpolation_method);
-    } else {
+      is_cropped = true;
+    }
+
+    if (!is_cropped) {
+      // center crop
+      int interpolation_method = GetInterMethod(param_.inter_method, res.cols, res.rows,
+                                                param_.data_shape[2],
+                                                param_.data_shape[1], prnd);
+      if (res.rows < param_.data_shape[1]) {
+        index_t new_cols = static_cast<index_t>(static_cast<float>(param_.data_shape[1]) /
+                                                static_cast<float>(res.rows) *
+                                                static_cast<float>(res.cols));
+        cv::resize(res, res, cv::Size(new_cols, param_.data_shape[1]),
+                   0, 0, interpolation_method);
+      }
+      if (res.cols < param_.data_shape[2]) {
+        index_t new_rows = static_cast<index_t>(static_cast<float>(param_.data_shape[2]) /
+                                                static_cast<float>(res.cols) *
+                                                static_cast<float>(res.rows));
+        cv::resize(res, res, cv::Size(param_.data_shape[2], new_rows),
+                   0, 0, interpolation_method);
+      }
       CHECK(static_cast<index_t>(res.rows) >= param_.data_shape[1]
             && static_cast<index_t>(res.cols) >= param_.data_shape[2])
           << "input image size smaller than input shape";
@@ -312,13 +451,48 @@ class DefaultImageAugmenter : public ImageAugmenter {
       res = res(roi);
     }
 
+    // color jitter
+    if (param_.brightness > 0.0f || param_.contrast > 0.0f || param_.saturation > 0.0f) {
+      std::uniform_real_distribution<float> rand_uniform(0, 1);
+      float alpha_b = 1.0 + std::uniform_real_distribution<float>(-param_.brightness,
+                                                                  param_.brightness)(*prnd);
+      float alpha_c = 1.0 + std::uniform_real_distribution<float>(-param_.contrast,
+                                                                  param_.contrast)(*prnd);
+      float alpha_s = 1.0 + std::uniform_real_distribution<float>(-param_.saturation,
+                                                                  param_.saturation)(*prnd);
+      int rand_order[3] = {0, 1, 2};
+      std::random_shuffle(std::begin(rand_order), std::end(rand_order));
+      for (int i = 0; i < 3; ++i) {
+        if (rand_order[i] == 0) {
+          // brightness
+          res.convertTo(res, -1, alpha_b, 0);
+        }
+        if (rand_order[i] == 1) {
+          // contrast
+          cvtColor(res, temp_, CV_RGB2GRAY);
+          float gray_mean = cv::mean(temp_)[0];
+          res.convertTo(res, -1, alpha_c, (1 - alpha_c) * gray_mean);
+        }
+        if (rand_order[i] == 2) {
+          // saturation
+          cvtColor(res, temp_, CV_RGB2GRAY);
+          cvtColor(temp_, temp_, CV_GRAY2BGR);
+          cv::addWeighted(res, alpha_s, temp_, 1 - alpha_s, 0.0, res);
+        }
+      }
+    }
+
     // color space augmentation
     if (param_.random_h != 0 || param_.random_s != 0 || param_.random_l != 0) {
       std::uniform_real_distribution<float> rand_uniform(0, 1);
       cvtColor(res, res, CV_BGR2HLS);
-      int h = rand_uniform(*prnd) * param_.random_h * 2 - param_.random_h;
-      int s = rand_uniform(*prnd) * param_.random_s * 2 - param_.random_s;
-      int l = rand_uniform(*prnd) * param_.random_l * 2 - param_.random_l;
+      // use an approximation of gaussian distribution to reduce extreme value
+      float rh = rand_uniform(*prnd); rh += 4 * rand_uniform(*prnd); rh = rh / 5;
+      float rs = rand_uniform(*prnd); rs += 4 * rand_uniform(*prnd); rs = rs / 5;
+      float rl = rand_uniform(*prnd); rl += 4 * rand_uniform(*prnd); rl = rl / 5;
+      int h = rh * param_.random_h * 2 - param_.random_h;
+      int s = rs * param_.random_s * 2 - param_.random_s;
+      int l = rl * param_.random_l * 2 - param_.random_l;
       int temp[3] = {h, l, s};
       int limit[3] = {180, 255, 255};
       for (int i = 0; i < res.rows; ++i) {
@@ -333,14 +507,45 @@ class DefaultImageAugmenter : public ImageAugmenter {
       }
       cvtColor(res, res, CV_HLS2BGR);
     }
+
+    // pca noise
+    if (param_.pca_noise > 0.0f) {
+      std::normal_distribution<float> rand_normal(0, param_.pca_noise);
+      float pca_alpha_r = rand_normal(*prnd);
+      float pca_alpha_g = rand_normal(*prnd);
+      float pca_alpha_b = rand_normal(*prnd);
+      float pca_r = eigvec[0][0] * pca_alpha_r + eigvec[0][1] * pca_alpha_g +
+           eigvec[0][2] * pca_alpha_b;
+      float pca_g = eigvec[1][0] * pca_alpha_r + eigvec[1][1] * pca_alpha_g +
+           eigvec[1][2] * pca_alpha_b;
+      float pca_b = eigvec[2][0] * pca_alpha_r + eigvec[2][1] * pca_alpha_g +
+           eigvec[2][2] * pca_alpha_b;
+      float pca[3] = { pca_b, pca_g, pca_r };
+      for (int i = 0; i < res.rows; ++i) {
+        for (int j = 0; j < res.cols; ++j) {
+          for (int k = 0; k < 3; ++k) {
+            int vp = res.at<cv::Vec3b>(i, j)[k];
+            vp += pca[k];
+            vp = std::max(0, std::min(255, vp));
+            res.at<cv::Vec3b>(i, j)[k] = vp;
+          }
+        }
+      }
+    }
     return res;
   }
 
+
  private:
   // temporal space
   cv::Mat temp_;
   // rotation param
   cv::Mat rotateM_;
+  // eigval and eigvec for adding pca noise
+  // store eigval * eigvec as eigvec
+  float eigvec[3][3] = { { 55.46f * -0.5675f, 4.794f * 0.7192f,  1.148f * 0.4009f },
+                         { 55.46f * -0.5808f, 4.794f * -0.0045f, 1.148f * -0.8140f },
+                         { 55.46f * -0.5836f, 4.794f * -0.6948f, 1.148f * 0.4203f } };
   // parameters
   DefaultImageAugmentParam param_;
   /*! \brief list of possible rotate angle */
diff --git a/tests/python/train/test_resnet_aug.py b/tests/python/train/test_resnet_aug.py
new file mode 100644
index 00000000000..62c531bb637
--- /dev/null
+++ b/tests/python/train/test_resnet_aug.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+import sys
+sys.path.insert(0, '../../python')
+import mxnet as mx
+import numpy as np
+import os, pickle, gzip
+import logging
+from mxnet.test_utils import get_cifar10
+
+batch_size = 128
+
+# small mlp network
+def get_net():
+    data = mx.symbol.Variable('data')
+    float_data = mx.symbol.Cast(data=data, dtype="float32")
+    fc1 = mx.symbol.FullyConnected(float_data, name='fc1', num_hidden=128)
+    act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+    fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
+    act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
+    fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
+    softmax = mx.symbol.SoftmaxOutput(fc3, name="softmax")
+    return softmax
+
+# check data
+get_cifar10()
+
+def get_iterator(kv):
+    data_shape = (3, 28, 28)
+
+    train = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/train.rec",
+        mean_img    = "data/cifar/mean.bin",
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        random_resized_crop = True,
+        min_aspect_ratio = 0.75,
+        max_aspect_ratio = 1.33,
+        min_random_area = 0.08,
+        max_random_area = 1,
+        brightness  = 0.4,
+        contrast    = 0.4,
+        saturation  = 0.4,
+        pca_noise   = 0.1,
+        rand_mirror = True,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank)
+    train = mx.io.PrefetchingIter(train)
+
+    val = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/test.rec",
+        mean_img    = "data/cifar/mean.bin",
+        rand_crop   = False,
+        rand_mirror = False,
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank)
+
+    return (train, val)
+
+num_epoch = 1
+
+def run_cifar10(train, val, use_module):
+    train.reset()
+    val.reset()
+    devs = [mx.cpu(0)]
+    net = get_net()
+    mod = mx.mod.Module(net, context=devs)
+    optim_args = {'learning_rate': 0.001, 'wd': 0.00001, 'momentum': 0.9}
+    eval_metrics = ['accuracy']
+    if use_module:
+        executor = mx.mod.Module(net, context=devs)
+        executor.fit(
+            train,
+            eval_data=val,
+            optimizer_params=optim_args,
+            eval_metric=eval_metrics,
+            num_epoch=num_epoch,
+            arg_params=None,
+            aux_params=None,
+            begin_epoch=0,
+            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
+            epoch_end_callback=None)
+    else:
+        executor = mx.model.FeedForward.create(
+            net,
+            train,
+            ctx=devs,
+            eval_data=val,
+            eval_metric=eval_metrics,
+            num_epoch=num_epoch,
+            arg_params=None,
+            aux_params=None,
+            begin_epoch=0,
+            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
+            epoch_end_callback=None,
+            **optim_args)
+
+    ret = executor.score(val, eval_metrics)
+    if use_module:
+        ret = list(ret)
+        logging.info('final accuracy = %f', ret[0][1])
+        assert (ret[0][1] > 0.08)
+    else:
+        logging.info('final accuracy = %f', ret[0])
+        assert (ret[0] > 0.08)
+
+class CustomDataIter(mx.io.DataIter):
+    def __init__(self, data):
+        super(CustomDataIter, self).__init__()
+        self.data = data
+        self.batch_size = data.provide_data[0][1][0]
+
+        # use legacy tuple
+        self.provide_data = [(n, s) for n, s in data.provide_data]
+        self.provide_label = [(n, s) for n, s in data.provide_label]
+
+    def reset(self):
+        self.data.reset()
+
+    def next(self):
+        return self.data.next()
+
+    def iter_next(self):
+        return self.data.iter_next()
+
+    def getdata(self):
+        return self.data.getdata()
+
+    def getlabel(self):
+        return self.data.getlable()
+
+    def getindex(self):
+        return self.data.getindex()
+
+    def getpad(self):
+        return self.data.getpad()
+
+def test_cifar10():
+    # print logging by default
+    logging.basicConfig(level=logging.DEBUG)
+    console = logging.StreamHandler()
+    console.setLevel(logging.DEBUG)
+    logging.getLogger('').addHandler(console)
+    kv = mx.kvstore.create("local")
+    # test float32 input
+    (train, val) = get_iterator(kv)
+    run_cifar10(train, val, use_module=False)
+    run_cifar10(train, val, use_module=True)
+
+    # test legecay tuple in provide_data and provide_label
+    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=False)
+    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=True)
+
+if __name__ == "__main__":
+    test_cifar10()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services