You are viewing a plain-text version of this content. (The canonical link to the original, shown as "here" in the HTML version, is not preserved in this copy.)
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/08/18 00:18:40 UTC

[incubator-mxnet] branch master updated: [MXNET-737]Add last batch handle for imageiter (#12131)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new afb77f8  [MXNET-737]Add last batch handle for imageiter (#12131)
afb77f8 is described below

commit afb77f89a5a883a13b676d90b4363181311da02a
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Aug 17 17:18:30 2018 -0700

    [MXNET-737]Add last batch handle for imageiter (#12131)
    
    [MXNET-737]Add last batch handle for imageiter
---
 python/mxnet/image/image.py         | 100 ++++++++++++++++++++++++++++++------
 tests/python/unittest/test_image.py |  84 ++++++++++++++++++++++++------
 2 files changed, 152 insertions(+), 32 deletions(-)

diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index c2a1906..24f5309 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -1059,6 +1059,12 @@ class ImageIter(io.DataIter):
         Label name for provided symbols.
     dtype : str
         Label data type. Default: float32. Other options: int32, int64, float64
+    last_batch_handle : str, optional
+        How to handle the last batch.
+        This parameter can be 'pad'(default), 'discard' or 'roll_over'.
+        If 'pad', the last batch will be padded with data starting from the begining
+        If 'discard', the last batch will be discarded
+        If 'roll_over', the remaining elements will be rolled over to the next iteration
     kwargs : ...
         More arguments for creating augmenter. See mx.image.CreateAugmenter.
     """
@@ -1066,7 +1072,8 @@ class ImageIter(io.DataIter):
     def __init__(self, batch_size, data_shape, label_width=1,
                  path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
                  shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
-                 data_name='data', label_name='softmax_label', dtype='float32', **kwargs):
+                 data_name='data', label_name='softmax_label', dtype='float32',
+                 last_batch_handle='pad', **kwargs):
         super(ImageIter, self).__init__()
         assert path_imgrec or path_imglist or (isinstance(imglist, list))
         assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported'
@@ -1129,7 +1136,6 @@ class ImageIter(io.DataIter):
         self.batch_size = batch_size
         self.data_shape = data_shape
         self.label_width = label_width
-
         self.shuffle = shuffle
         if self.imgrec is None:
             self.seq = imgkeys
@@ -1149,22 +1155,49 @@ class ImageIter(io.DataIter):
         else:
             self.auglist = aug_list
         self.cur = 0
+        self._allow_read = True
+        self.last_batch_handle = last_batch_handle
+        self.num_image = len(self.seq) if self.seq is not None else None
+        self._cache_data = None
+        self._cache_label = None
+        self._cache_idx = None
         self.reset()
 
     def reset(self):
         """Resets the iterator to the beginning of the data."""
-        if self.shuffle:
+        if self.seq is not None and self.shuffle:
+            random.shuffle(self.seq)
+        if self.last_batch_handle != 'roll_over' or \
+            self._cache_data is None:
+            if self.imgrec is not None:
+                self.imgrec.reset()
+            self.cur = 0
+            if self._allow_read is False:
+                self._allow_read = True
+
+    def hard_reset(self):
+        """Resets the iterator and ignore roll over data"""
+        if self.seq is not None and self.shuffle:
             random.shuffle(self.seq)
         if self.imgrec is not None:
             self.imgrec.reset()
         self.cur = 0
+        self._allow_read = True
+        self._cache_data = None
+        self._cache_label = None
+        self._cache_idx = None
 
     def next_sample(self):
         """Helper function for reading in next sample."""
+        if self._allow_read is False:
+            raise StopIteration
         if self.seq is not None:
-            if self.cur >= len(self.seq):
+            if self.cur < self.num_image:
+                idx = self.seq[self.cur]
+            else:
+                if self.last_batch_handle != 'discard':
+                    self.cur = 0
                 raise StopIteration
-            idx = self.seq[self.cur]
             self.cur += 1
             if self.imgrec is not None:
                 s = self.imgrec.read_idx(idx)
@@ -1179,17 +1212,16 @@ class ImageIter(io.DataIter):
         else:
             s = self.imgrec.read()
             if s is None:
+                if self.last_batch_handle != 'discard':
+                    self.imgrec.reset()
                 raise StopIteration
             header, img = recordio.unpack(s)
             return header.label, img
 
-    def next(self):
-        """Returns the next batch of data."""
+    def _batchify(self, batch_data, batch_label, start=0):
+        """Helper function for batchifying data"""
+        i = start
         batch_size = self.batch_size
-        c, h, w = self.data_shape
-        batch_data = nd.empty((batch_size, c, h, w))
-        batch_label = nd.empty(self.provide_label[0][1])
-        i = 0
         try:
             while i < batch_size:
                 label, s = self.next_sample()
@@ -1207,8 +1239,47 @@ class ImageIter(io.DataIter):
         except StopIteration:
             if not i:
                 raise StopIteration
+        return i
 
-        return io.DataBatch([batch_data], [batch_label], batch_size - i)
+    def next(self):
+        """Returns the next batch of data."""
+        batch_size = self.batch_size
+        c, h, w = self.data_shape
+        # if last batch data is rolled over
+        if self._cache_data is not None:
+            # check both the data and label have values
+            assert self._cache_label is not None, "_cache_label didn't have values"
+            assert self._cache_idx is not None, "_cache_idx didn't have values"
+            batch_data = self._cache_data
+            batch_label = self._cache_label
+            i = self._cache_idx
+            # clear the cache data
+        else:
+            batch_data = nd.empty((batch_size, c, h, w))
+            batch_label = nd.empty(self.provide_label[0][1])
+            i = self._batchify(batch_data, batch_label)
+        # calculate the padding
+        pad = batch_size - i
+        # handle padding for the last batch
+        if pad != 0:
+            if self.last_batch_handle == 'discard':
+                raise StopIteration
+            # if the option is 'roll_over', throw StopIteration and cache the data
+            elif self.last_batch_handle == 'roll_over' and \
+                self._cache_data is None:
+                self._cache_data = batch_data
+                self._cache_label = batch_label
+                self._cache_idx = i
+                raise StopIteration
+            else:
+                _ = self._batchify(batch_data, batch_label, i)
+                if self.last_batch_handle == 'pad':
+                    self._allow_read = False
+                else:
+                    self._cache_data = None
+                    self._cache_label = None
+                    self._cache_idx = None
+        return io.DataBatch([batch_data], [batch_label], pad=pad)
 
     def check_data_shape(self, data_shape):
         """Checks if the input data shape is valid"""
@@ -1228,9 +1299,9 @@ class ImageIter(io.DataIter):
         def locate():
             """Locate the image file/index if decode fails."""
             if self.seq is not None:
-                idx = self.seq[self.cur - 1]
+                idx = self.seq[(self.cur % self.num_image) - 1]
             else:
-                idx = self.cur - 1
+                idx = (self.cur % self.num_image) - 1
             if self.imglist is not None:
                 _, fname = self.imglist[idx]
                 msg = "filename: {}".format(fname)
@@ -1245,7 +1316,6 @@ class ImageIter(io.DataIter):
 
     def read_image(self, fname):
         """Reads an input image `fname` and returns the decoded raw bytes.
-
         Example usage:
         ----------
         >>> dataIter.read_image('Face.jpg') # returns decoded raw bytes.
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 9eec183..0df08af 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -80,7 +80,6 @@ class TestImage(unittest.TestCase):
                 image_read = mx.img.image.imread(img)
                 same(image.asnumpy(), image_read.asnumpy())
 
-
     def test_imdecode(self):
         try:
             import cv2
@@ -130,29 +129,81 @@ class TestImage(unittest.TestCase):
                 mx.nd.array(mean), mx.nd.array(std))
             assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)
 
-
     def test_imageiter(self):
         def check_imageiter(dtype='float32'):
             im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
-            test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
-                path_root='', dtype=dtype)
-            for _ in range(3):
-                for batch in test_iter:
-                    pass
-                test_iter.reset()
-
-            # test with list file
             fname = './data/test_imageiter.lst'
-            file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
-                for k, x in enumerate(TestImage.IMAGES)]
+            file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
+                         for k, x in enumerate(TestImage.IMAGES)]
             with open(fname, 'w') as f:
                 for line in file_list:
                     f.write(line + '\n')
+            
+            test_list = ['imglist', 'path_imglist']
 
-            test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
-                path_root='', dtype=dtype)
-            for batch in test_iter:
-                pass
+            for test in test_list:
+                imglist = im_list if test == 'imglist' else None
+                path_imglist = fname if test == 'path_imglist' else None
+                
+                test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist, 
+                    path_imglist=path_imglist, path_root='', dtype=dtype)
+                # test batch data shape
+                for _ in range(3):
+                    for batch in test_iter:
+                        assert batch.data[0].shape == (2, 3, 224, 224)
+                    test_iter.reset()
+                # test last batch handle(discard)
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+                    path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
+                i = 0
+                for batch in test_iter:
+                    i += 1
+                assert i == 5
+                # test last_batch_handle(pad)
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, 
+                    path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
+                i = 0
+                for batch in test_iter:
+                    if i == 0:
+                        first_three_data = batch.data[0][:2]
+                    if i == 5:
+                        last_three_data = batch.data[0][1:]
+                    i += 1
+                assert i == 6
+                assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
+                # test last_batch_handle(roll_over)
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+                    path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
+                i = 0
+                for batch in test_iter:
+                    if i == 0:
+                        first_image = batch.data[0][0]
+                    i += 1
+                assert i == 5
+                test_iter.reset()
+                first_batch_roll_over = test_iter.next()
+                assert np.array_equal(
+                    first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
+                assert first_batch_roll_over.pad == 2
+                # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over
+                for _ in test_iter:
+                    pass
+                test_iter.reset()
+                first_batch_roll_over_twice = test_iter.next()
+                assert np.array_equal(
+                    first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
+                assert first_batch_roll_over_twice.pad == 1
+                # we've called next once
+                i = 1
+                for _ in test_iter:
+                    i += 1
+                # test the third epoch with size 6
+                assert i == 6
+                # test shuffle option for sanity test
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
+                                               path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
+                for _ in test_iter:
+                    pass
 
         for dtype in ['int32', 'float32', 'int64', 'float64']:
             check_imageiter(dtype)
@@ -183,7 +234,6 @@ class TestImage(unittest.TestCase):
         for batch in test_iter:
             pass
 
-
     def test_image_detiter(self):
         im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
         det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')