You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/20 23:11:27 UTC

[incubator-mxnet] branch master updated: Add seed_aug parameter for ImageRecordItr to fix random seed for default augmentation (#11247)

This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ffc03c  Add seed_aug parameter for ImageRecordItr to fix random seed for default augmentation (#11247)
9ffc03c is described below

commit 9ffc03c3b45fd76a1efe49263cea861f781dca5e
Author: Wen-Yang Chu <we...@gmail.com>
AuthorDate: Thu Jun 21 01:11:21 2018 +0200

    Add seed_aug parameter for ImageRecordItr to fix random seed for default augmentation (#11247)
    
    * add seed_aug parameter for ImageRecordItr to fix random seed for augmentation
    
    * remove white space
    
    * add test
    
    * fix test
    
    * improve according to review: using dmlc::optional and has_value()
    
    * missing header
    
    * change data type and way to get value
---
 src/io/image_aug_default.cc      |  12 +++
 tests/python/unittest/test_io.py | 167 +++++++++++++++++++++++++++++++--------
 2 files changed, 148 insertions(+), 31 deletions(-)

diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc
index f7d08b9..ce9c79c 100644
--- a/src/io/image_aug_default.cc
+++ b/src/io/image_aug_default.cc
@@ -23,6 +23,7 @@
  * \brief Default augmenter.
  */
 #include <mxnet/base.h>
+#include <dmlc/optional.h>
 #include <utility>
 #include <string>
 #include <algorithm>
@@ -96,6 +97,9 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
   int pad;
   /*! \brief shape of the image data*/
   TShape data_shape;
+  /*! \brief random seed for augmentations */
+  dmlc::optional<int> seed_aug;
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(DefaultImageAugmentParam) {
     DMLC_DECLARE_FIELD(resize).set_default(-1)
@@ -184,6 +188,8 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
     DMLC_DECLARE_FIELD(pad).set_default(0)
         .describe("Change size from ``[width, height]`` into "
                   "``[pad + width + pad, pad + height + pad]`` by padding pixes");
+    DMLC_DECLARE_FIELD(seed_aug).set_default(dmlc::optional<int>())
+        .describe("Random seed for augmentations.");
   }
 };
 
@@ -204,6 +210,7 @@ class DefaultImageAugmenter : public ImageAugmenter {
   // contructor
   DefaultImageAugmenter() {
     rotateM_ = cv::Mat(2, 3, CV_32F);
+    seed_init_state = false;
   }
   void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
     std::vector<std::pair<std::string, std::string> > kwargs_left;
@@ -244,6 +251,10 @@ class DefaultImageAugmenter : public ImageAugmenter {
   }
   cv::Mat Process(const cv::Mat &src, std::vector<float> *label,
                   common::RANDOM_ENGINE *prnd) override {
+    if (!seed_init_state && param_.seed_aug.has_value()) {
+      prnd->seed(param_.seed_aug.value());
+      seed_init_state = true;
+    }
     using mshadow::index_t;
     bool is_cropped = false;
 
@@ -550,6 +561,7 @@ class DefaultImageAugmenter : public ImageAugmenter {
   DefaultImageAugmentParam param_;
   /*! \brief list of possible rotate angle */
   std::vector<int> rotate_list_;
+  bool seed_init_state;
 };
 
 ImageAugmenter* ImageAugmenter::Create(const std::string& name) {
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 7e6ef1a..c758208 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -20,7 +20,8 @@ import mxnet as mx
 from mxnet.test_utils import *
 from mxnet.base import MXNetError
 import numpy as np
-import os, gzip
+import os
+import gzip
 import pickle as pickle
 import time
 try:
@@ -31,16 +32,17 @@ import sys
 from common import assertRaises
 import unittest
 
+
 def test_MNISTIter():
     # prepare data
     get_mnist_ubyte()
 
     batch_size = 100
     train_dataiter = mx.io.MNISTIter(
-            image="data/train-images-idx3-ubyte",
-            label="data/train-labels-idx1-ubyte",
-            data_shape=(784,),
-            batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10)
+        image="data/train-images-idx3-ubyte",
+        label="data/train-labels-idx1-ubyte",
+        data_shape=(784,),
+        batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10)
     # test_loop
     nbatch = 60000 / batch_size
     batch_count = 0
@@ -60,18 +62,19 @@ def test_MNISTIter():
     label_1 = train_dataiter.getlabel().asnumpy().flatten()
     assert(sum(label_0 - label_1) == 0)
 
+
 def test_Cifar10Rec():
     get_cifar10()
     dataiter = mx.io.ImageRecordIter(
-            path_imgrec="data/cifar/train.rec",
-            mean_img="data/cifar/cifar10_mean.bin",
-            rand_crop=False,
-            and_mirror=False,
-            shuffle=False,
-            data_shape=(3,28,28),
-            batch_size=100,
-            preprocess_threads=4,
-            prefetch_buffer=1)
+        path_imgrec="data/cifar/train.rec",
+        mean_img="data/cifar/cifar10_mean.bin",
+        rand_crop=False,
+        and_mirror=False,
+        shuffle=False,
+        data_shape=(3, 28, 28),
+        batch_size=100,
+        preprocess_threads=4,
+        prefetch_buffer=1)
     labelcount = [0 for i in range(10)]
     batchcount = 0
     for batch in dataiter:
@@ -84,23 +87,26 @@ def test_Cifar10Rec():
     for i in range(10):
         assert(labelcount[i] == 5000)
 
+
 def test_NDArrayIter():
     data = np.ones([1000, 2, 2])
     label = np.ones([1000, 1])
     for i in range(1000):
         data[i] = i / 100
         label[i] = i / 100
-    dataiter = mx.io.NDArrayIter(data, label, 128, True, last_batch_handle='pad')
+    dataiter = mx.io.NDArrayIter(
+        data, label, 128, True, last_batch_handle='pad')
     batchidx = 0
     for batch in dataiter:
         batchidx += 1
     assert(batchidx == 8)
-    dataiter = mx.io.NDArrayIter(data, label, 128, False, last_batch_handle='pad')
+    dataiter = mx.io.NDArrayIter(
+        data, label, 128, False, last_batch_handle='pad')
     batchidx = 0
     labelcount = [0 for i in range(10)]
     for batch in dataiter:
         label = batch.label[0].asnumpy().flatten()
-        assert((batch.data[0].asnumpy()[:,0,0] == label).all())
+        assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
         for i in range(label.shape[0]):
             labelcount[int(label[i])] += 1
 
@@ -110,6 +116,7 @@ def test_NDArrayIter():
         else:
             assert(labelcount[i] == 100)
 
+
 def test_NDArrayIter_h5py():
     if not h5py:
         return
@@ -128,17 +135,19 @@ def test_NDArrayIter_h5py():
         f.create_dataset("data", data=data)
         f.create_dataset("label", data=label)
 
-        dataiter = mx.io.NDArrayIter(f["data"], f["label"], 128, True, last_batch_handle='pad')
+        dataiter = mx.io.NDArrayIter(
+            f["data"], f["label"], 128, True, last_batch_handle='pad')
         batchidx = 0
         for batch in dataiter:
             batchidx += 1
         assert(batchidx == 8)
 
-        dataiter = mx.io.NDArrayIter(f["data"], f["label"], 128, False, last_batch_handle='pad')
+        dataiter = mx.io.NDArrayIter(
+            f["data"], f["label"], 128, False, last_batch_handle='pad')
         labelcount = [0 for i in range(10)]
         for batch in dataiter:
             label = batch.label[0].asnumpy().flatten()
-            assert((batch.data[0].asnumpy()[:,0,0] == label).all())
+            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
             for i in range(label.shape[0]):
                 labelcount[int(label[i])] += 1
 
@@ -153,6 +162,7 @@ def test_NDArrayIter_h5py():
         else:
             assert(labelcount[i] == 100)
 
+
 def test_NDArrayIter_csr():
     # creating toy data
     num_rows = rnd.randint(5, 15)
@@ -163,17 +173,19 @@ def test_NDArrayIter_csr():
     dns = csr.asnumpy()
 
     # CSRNDArray or scipy.sparse.csr_matrix with last_batch_handle not equal to 'discard' will throw NotImplementedError
-    assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': csr}, dns, batch_size)
+    assertRaises(NotImplementedError, mx.io.NDArrayIter,
+                 {'data': csr}, dns, batch_size)
     try:
         import scipy.sparse as spsp
         train_data = spsp.csr_matrix(dns)
-        assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': train_data}, dns, batch_size)
+        assertRaises(NotImplementedError, mx.io.NDArrayIter,
+                     {'data': train_data}, dns, batch_size)
     except ImportError:
         pass
 
     # CSRNDArray with shuffle
     csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size,
-                    shuffle=True, last_batch_handle='discard'))
+                                      shuffle=True, last_batch_handle='discard'))
     num_batch = 0
     for batch in csr_iter:
         num_batch += 1
@@ -181,7 +193,8 @@ def test_NDArrayIter_csr():
     assert(num_batch == num_rows // batch_size)
 
     # make iterators
-    csr_iter = iter(mx.io.NDArrayIter(csr, csr, batch_size, last_batch_handle='discard'))
+    csr_iter = iter(mx.io.NDArrayIter(
+        csr, csr, batch_size, last_batch_handle='discard'))
     begin = 0
     for batch in csr_iter:
         expected = np.zeros((batch_size, num_cols))
@@ -192,6 +205,7 @@ def test_NDArrayIter_csr():
         assert_almost_equal(batch.data[0].asnumpy(), expected)
         begin += batch_size
 
+
 def test_LibSVMIter():
 
     def check_libSVMIter_synthetic():
@@ -214,8 +228,8 @@ def test_LibSVMIter():
         data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path,
                                       data_shape=(3, ), label_shape=(3, ), batch_size=3)
 
-        first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]])
-        second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]])
+        first = mx.nd.array([[0.5, 0., 1.2], [0., 0., 0.], [0.6, 2.4, 1.2]])
+        second = mx.nd.array([[0., 0., -1.2], [0.5, 0., 1.2], [0., 0., 0.]])
         i = 0
         for batch in iter(data_train):
             expected = first.asnumpy() if i == 0 else second.asnumpy()
@@ -286,10 +300,13 @@ def test_DataBatch():
     from nose.tools import ok_
     from mxnet.io import DataBatch
     import re
-    batch = DataBatch(data=[mx.nd.ones((2,3))])
-    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
-    batch = DataBatch(data=[mx.nd.ones((2,3)), mx.nd.ones((7,8))], label=[mx.nd.ones((4,5))])
-    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
+    batch = DataBatch(data=[mx.nd.ones((2, 3))])
+    ok_(re.match(
+        'DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
+    batch = DataBatch(data=[mx.nd.ones((2, 3)), mx.nd.ones(
+        (7, 8))], label=[mx.nd.ones((4, 5))])
+    ok_(re.match(
+        'DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
 
 
 def test_CSVIter():
@@ -307,7 +324,7 @@ def test_CSVIter():
             for i in range(1000):
                 fout.write('0\n')
 
-        data_train = mx.io.CSVIter(data_csv=data_path, data_shape=(8,8),
+        data_train = mx.io.CSVIter(data_csv=data_path, data_shape=(8, 8),
                                    label_csv=label_path, batch_size=100, dtype=dtype)
         expected = mx.nd.ones((100, 8, 8), dtype=dtype) * int(entry_str)
         for batch in iter(data_train):
@@ -318,6 +335,93 @@ def test_CSVIter():
     for dtype in ['int32', 'float32']:
         check_CSVIter_synthetic(dtype=dtype)
 
+
+def test_ImageRecordIter_seed_augmentation():
+    get_cifar10()
+    seed_aug = 3
+
+    # check that identical images are produced when seed_aug is fixed
+    dataiter = mx.io.ImageRecordIter(
+        path_imgrec="data/cifar/train.rec",
+        mean_img="data/cifar/cifar10_mean.bin",
+        shuffle=False,
+        data_shape=(3, 28, 28),
+        batch_size=3,
+        rand_crop=True,
+        rand_mirror=True,
+        max_random_scale=1.3,
+        max_random_illumination=3,
+        max_rotate_angle=10,
+        random_l=50,
+        random_s=40,
+        random_h=10,
+        max_shear_ratio=2,
+        seed_aug=seed_aug)
+    batch = dataiter.next()
+    data = batch.data[0].asnumpy().astype(np.uint8)
+
+    dataiter = mx.io.ImageRecordIter(
+        path_imgrec="data/cifar/train.rec",
+        mean_img="data/cifar/cifar10_mean.bin",
+        shuffle=False,
+        data_shape=(3, 28, 28),
+        batch_size=3,
+        rand_crop=True,
+        rand_mirror=True,
+        max_random_scale=1.3,
+        max_random_illumination=3,
+        max_rotate_angle=10,
+        random_l=50,
+        random_s=40,
+        random_h=10,
+        max_shear_ratio=2,
+        seed_aug=seed_aug)
+    batch = dataiter.next()
+    data2 = batch.data[0].asnumpy().astype(np.uint8)
+    assert(np.array_equal(data,data2))
+
+    # check that different images are produced after changing seed_aug
+    dataiter = mx.io.ImageRecordIter(
+        path_imgrec="data/cifar/train.rec",
+        mean_img="data/cifar/cifar10_mean.bin",
+        shuffle=False,
+        data_shape=(3, 28, 28),
+        batch_size=3,
+        rand_crop=True,
+        rand_mirror=True,
+        max_random_scale=1.3,
+        max_random_illumination=3,
+        max_rotate_angle=10,
+        random_l=50,
+        random_s=40,
+        random_h=10,
+        max_shear_ratio=2,
+        seed_aug=seed_aug+1)
+    batch = dataiter.next()
+    data2 = batch.data[0].asnumpy().astype(np.uint8)
+    assert(not np.array_equal(data,data2))
+
+    # check that seed_aug alone (no augmentation options set) leaves the iterator output unchanged
+    dataiter = mx.io.ImageRecordIter(
+        path_imgrec="data/cifar/train.rec",
+        mean_img="data/cifar/cifar10_mean.bin",
+        shuffle=False,
+        data_shape=(3, 28, 28),
+        batch_size=3)
+    batch = dataiter.next()
+    data = batch.data[0].asnumpy().astype(np.uint8)
+
+    dataiter = mx.io.ImageRecordIter(
+        path_imgrec="data/cifar/train.rec",
+        mean_img="data/cifar/cifar10_mean.bin",
+        shuffle=False,
+        data_shape=(3, 28, 28),
+        batch_size=3,
+        seed_aug=seed_aug)
+    batch = dataiter.next()
+    data2 = batch.data[0].asnumpy().astype(np.uint8)
+    assert(np.array_equal(data,data2))
+
 if __name__ == "__main__":
     test_NDArrayIter()
     if h5py:
@@ -327,3 +431,4 @@ if __name__ == "__main__":
     test_LibSVMIter()
     test_NDArrayIter_csr()
     test_CSVIter()
+    test_ImageRecordIter_seed_augmentation()