You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/12/08 01:06:19 UTC

[incubator-mxnet] branch master updated: fix the situation where idx didn't align with rec (#13550)

This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 186a746  fix the situation where idx didn't align with rec (#13550)
186a746 is described below

commit 186a746e557d8ef9551f52c6e4c1175394c323c0
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Dec 7 17:06:05 2018 -0800

    fix the situation where idx didn't align with rec (#13550)
    
    minor fix the image.py
    
    add last_batch_handle for ImageDetIter
    
    remove the label type
    
    refactor the imageiter unit test
    
    fix the trailing whitespace
    
    fix coding style
    
    add new line
    
    move helper function to the top of the file
---
 python/mxnet/image/detection.py     |  64 +++++++++++--
 python/mxnet/image/image.py         |   5 +-
 tests/python/unittest/test_image.py | 184 ++++++++++++++++++++----------------
 3 files changed, 157 insertions(+), 96 deletions(-)

diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index b27917c..d5b5eca 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -658,19 +658,26 @@ class ImageDetIter(ImageIter):
         Data name for provided symbols.
     label_name : str
         Name for detection labels
+    last_batch_handle : str, optional
+        How to handle the last batch.
+        This parameter can be 'pad'(default), 'discard' or 'roll_over'.
+        If 'pad', the last batch will be padded with data starting from the beginning
+        If 'discard', the last batch will be discarded
+        If 'roll_over', the remaining elements will be rolled over to the next iteration
     kwargs : ...
         More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
     """
     def __init__(self, batch_size, data_shape,
                  path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
                  shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
-                 data_name='data', label_name='label', **kwargs):
+                 data_name='data', label_name='label', last_batch_handle='pad', **kwargs):
         super(ImageDetIter, self).__init__(batch_size=batch_size, data_shape=data_shape,
                                            path_imgrec=path_imgrec, path_imglist=path_imglist,
                                            path_root=path_root, path_imgidx=path_imgidx,
                                            shuffle=shuffle, part_index=part_index,
                                            num_parts=num_parts, aug_list=[], imglist=imglist,
-                                           data_name=data_name, label_name=label_name)
+                                           data_name=data_name, label_name=label_name,
+                                           last_batch_handle=last_batch_handle)
 
         if aug_list is None:
             self.auglist = CreateDetAugmenter(data_shape, **kwargs)
@@ -751,14 +758,10 @@ class ImageDetIter(ImageIter):
             self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)]
             self.label_shape = label_shape
 
-    def next(self):
-        """Override the function for returning next batch."""
+    def _batchify(self, batch_data, batch_label, start=0):
+        """Override the helper function for batchifying data"""
+        i = start
         batch_size = self.batch_size
-        c, h, w = self.data_shape
-        batch_data = nd.zeros((batch_size, c, h, w))
-        batch_label = nd.empty(self.provide_label[0][1])
-        batch_label[:] = -1
-        i = 0
         try:
             while i < batch_size:
                 label, s = self.next_sample()
@@ -783,7 +786,48 @@ class ImageDetIter(ImageIter):
             if not i:
                 raise StopIteration
 
-        return io.DataBatch([batch_data], [batch_label], batch_size - i)
+        return i
+
+    def next(self):
+        """Override the function for returning next batch."""
+        batch_size = self.batch_size
+        c, h, w = self.data_shape
+        # if last batch data is rolled over
+        if self._cache_data is not None:
+            # check both the data and label have values
+            assert self._cache_label is not None, "_cache_label didn't have values"
+            assert self._cache_idx is not None, "_cache_idx didn't have values"
+            batch_data = self._cache_data
+            batch_label = self._cache_label
+            i = self._cache_idx
+        else:
+            batch_data = nd.zeros((batch_size, c, h, w))
+            batch_label = nd.empty(self.provide_label[0][1])
+            batch_label[:] = -1
+            i = self._batchify(batch_data, batch_label)
+        # calculate the padding
+        pad = batch_size - i
+        # handle padding for the last batch
+        if pad != 0:
+            if self.last_batch_handle == 'discard':
+                raise StopIteration
+            # if the option is 'roll_over', throw StopIteration and cache the data
+            elif self.last_batch_handle == 'roll_over' and \
+                self._cache_data is None:
+                self._cache_data = batch_data
+                self._cache_label = batch_label
+                self._cache_idx = i
+                raise StopIteration
+            else:
+                _ = self._batchify(batch_data, batch_label, i)
+                if self.last_batch_handle == 'pad':
+                    self._allow_read = False
+                else:
+                    self._cache_data = None
+                    self._cache_label = None
+                    self._cache_idx = None
+
+        return io.DataBatch([batch_data], [batch_label], pad=pad)
 
     def augmentation_transform(self, data, label):  # pylint: disable=arguments-differ
         """Override Transforms input data with specified augmentations."""
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index c9a457f..9c2a1cb 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -1145,7 +1145,7 @@ class ImageIter(io.DataIter):
         self.shuffle = shuffle
         if self.imgrec is None:
             self.seq = imgkeys
-        elif shuffle or num_parts > 1:
+        elif shuffle or num_parts > 1 or path_imgidx:
             assert self.imgidx is not None
             self.seq = self.imgidx
         else:
@@ -1261,7 +1261,7 @@ class ImageIter(io.DataIter):
             i = self._cache_idx
             # clear the cache data
         else:
-            batch_data = nd.empty((batch_size, c, h, w))
+            batch_data = nd.zeros((batch_size, c, h, w))
             batch_label = nd.empty(self.provide_label[0][1])
             i = self._batchify(batch_data, batch_label)
         # calculate the padding
@@ -1285,6 +1285,7 @@ class ImageIter(io.DataIter):
                     self._cache_data = None
                     self._cache_label = None
                     self._cache_idx = None
+
         return io.DataBatch([batch_data], [batch_label], pad=pad)
 
     def check_data_shape(self, data_shape):
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 4f66823..4063027 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -25,6 +25,7 @@ import unittest
 
 from nose.tools import raises
 
+
 def _get_data(url, dirname):
     import os, tarfile
     download(url, dirname=dirname, overwrite=False)
@@ -50,6 +51,62 @@ def _generate_objects():
     label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
     return [2, 5] + label
 
+def _test_imageiter_last_batch(imageiter_list, assert_data_shape):
+    test_iter = imageiter_list[0]
+    # test batch data shape
+    for _ in range(3):
+        for batch in test_iter:
+            assert batch.data[0].shape == assert_data_shape
+        test_iter.reset()
+    # test last batch handle(discard)
+    test_iter = imageiter_list[1]
+    i = 0
+    for batch in test_iter:
+        i += 1
+    assert i == 5
+    # test last_batch_handle(pad)
+    test_iter = imageiter_list[2]
+    i = 0
+    for batch in test_iter:
+        if i == 0:
+            first_three_data = batch.data[0][:2]
+        if i == 5:
+            last_three_data = batch.data[0][1:]
+        i += 1
+    assert i == 6
+    assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
+    # test last_batch_handle(roll_over)
+    test_iter = imageiter_list[3]
+    i = 0
+    for batch in test_iter:
+        if i == 0:
+            first_image = batch.data[0][0]
+        i += 1
+    assert i == 5
+    test_iter.reset()
+    first_batch_roll_over = test_iter.next()
+    assert np.array_equal(
+        first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
+    assert first_batch_roll_over.pad == 2
+    # test that the iterator works properly after calling reset several times when last_batch_handle is roll_over
+    for _ in test_iter:
+        pass
+    test_iter.reset()
+    first_batch_roll_over_twice = test_iter.next()
+    assert np.array_equal(
+        first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
+    assert first_batch_roll_over_twice.pad == 1
+    # we've called next once
+    i = 1
+    for _ in test_iter:
+        i += 1
+    # test the third epoch with size 6
+    assert i == 6
+    # test shuffle option for sanity test
+    test_iter = imageiter_list[4]
+    for _ in test_iter:
+        pass
+
 
 class TestImage(unittest.TestCase):
     IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz"
@@ -151,86 +208,32 @@ class TestImage(unittest.TestCase):
             assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)
 
     def test_imageiter(self):
-        def check_imageiter(dtype='float32'):
-            im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
-            fname = './data/test_imageiter.lst'
-            file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
-                         for k, x in enumerate(TestImage.IMAGES)]
-            with open(fname, 'w') as f:
-                for line in file_list:
-                    f.write(line + '\n')
-
-            test_list = ['imglist', 'path_imglist']
+        im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
+        fname = './data/test_imageiter.lst'
+        file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
+                        for k, x in enumerate(TestImage.IMAGES)]
+        with open(fname, 'w') as f:
+            for line in file_list:
+                f.write(line + '\n')
 
+        test_list = ['imglist', 'path_imglist']
+        for dtype in ['int32', 'float32', 'int64', 'float64']:
             for test in test_list:
                 imglist = im_list if test == 'imglist' else None
                 path_imglist = fname if test == 'path_imglist' else None
-
-                test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype)
-                # test batch data shape
-                for _ in range(3):
-                    for batch in test_iter:
-                        assert batch.data[0].shape == (2, 3, 224, 224)
-                    test_iter.reset()
-                # test last batch handle(discard)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
-                i = 0
-                for batch in test_iter:
-                    i += 1
-                assert i == 5
-                # test last_batch_handle(pad)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
-                i = 0
-                for batch in test_iter:
-                    if i == 0:
-                        first_three_data = batch.data[0][:2]
-                    if i == 5:
-                        last_three_data = batch.data[0][1:]
-                    i += 1
-                assert i == 6
-                assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
-                # test last_batch_handle(roll_over)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
-                i = 0
-                for batch in test_iter:
-                    if i == 0:
-                        first_image = batch.data[0][0]
-                    i += 1
-                assert i == 5
-                test_iter.reset()
-                first_batch_roll_over = test_iter.next()
-                assert np.array_equal(
-                    first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
-                assert first_batch_roll_over.pad == 2
-                # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over
-                for _ in test_iter:
-                    pass
-                test_iter.reset()
-                first_batch_roll_over_twice = test_iter.next()
-                assert np.array_equal(
-                    first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
-                assert first_batch_roll_over_twice.pad == 1
-                # we've called next once
-                i = 1
-                for _ in test_iter:
-                    i += 1
-                # test the third epoch with size 6
-                assert i == 6
-                # test shuffle option for sanity test
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
-                                               path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
-                for _ in test_iter:
-                    pass
-
-        for dtype in ['int32', 'float32', 'int64', 'float64']:
-            check_imageiter(dtype)
-
-        # test with default dtype
-        check_imageiter()
+                imageiter_list = [
+                    mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard'),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad'),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over'),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
+                ]
+                _test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224))
 
     @with_seed()
     def test_augmenters(self):
@@ -259,16 +262,20 @@ class TestImage(unittest.TestCase):
         im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
         det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
         for _ in range(3):
-            for batch in det_iter:
+            for _ in det_iter:
                 pass
-            det_iter.reset()
-
+        det_iter.reset()
         val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
         det_iter = val_iter.sync_label_shape(det_iter)
         assert det_iter.data_shape == val_iter.data_shape
         assert det_iter.label_shape == val_iter.label_shape
 
-        # test file list
+        # test batch_size is not divisible by number of images
+        det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list, path_root='')
+        for _ in det_iter:
+            pass
+
+        # test file list with last batch handle
         fname = './data/test_imagedetiter.lst'
         im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)]
         with open(fname, 'w') as f:
@@ -276,10 +283,19 @@ class TestImage(unittest.TestCase):
                 line = '\t'.join([str(k) for k in line])
                 f.write(line + '\n')
 
-        det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
-            path_root='')
-        for batch in det_iter:
-            pass
+        imageiter_list = [
+            mx.image.ImageDetIter(2, (3, 400, 400),
+                path_imglist=fname, path_root=''),
+            mx.image.ImageDetIter(3, (3, 400, 400),
+                path_imglist=fname, path_root='', last_batch_handle='discard'),
+            mx.image.ImageDetIter(3, (3, 400, 400),
+                path_imglist=fname, path_root='', last_batch_handle='pad'),
+            mx.image.ImageDetIter(3, (3, 400, 400),
+                path_imglist=fname, path_root='', last_batch_handle='roll_over'),
+            mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True,
+                path_imglist=fname, path_root='', last_batch_handle='pad')
+        ]
+        _test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400))
 
     def test_det_augmenters(self):
         # only test if all augmenters will work