You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/08/18 00:18:40 UTC
[incubator-mxnet] branch master updated: [MXNET-737]Add last batch
handle for imageiter (#12131)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new afb77f8 [MXNET-737]Add last batch handle for imageiter (#12131)
afb77f8 is described below
commit afb77f89a5a883a13b676d90b4363181311da02a
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Aug 17 17:18:30 2018 -0700
[MXNET-737]Add last batch handle for imageiter (#12131)
[MXNET-737]Add last batch handle for imageiter
---
python/mxnet/image/image.py | 100 ++++++++++++++++++++++++++++++------
tests/python/unittest/test_image.py | 84 ++++++++++++++++++++++++------
2 files changed, 152 insertions(+), 32 deletions(-)
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index c2a1906..24f5309 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -1059,6 +1059,12 @@ class ImageIter(io.DataIter):
Label name for provided symbols.
dtype : str
Label data type. Default: float32. Other options: int32, int64, float64
+ last_batch_handle : str, optional
+ How to handle the last batch.
+ This parameter can be 'pad'(default), 'discard' or 'roll_over'.
+ If 'pad', the last batch will be padded with data starting from the beginning
+ If 'discard', the last batch will be discarded
+ If 'roll_over', the remaining elements will be rolled over to the next iteration
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateAugmenter.
"""
@@ -1066,7 +1072,8 @@ class ImageIter(io.DataIter):
def __init__(self, batch_size, data_shape, label_width=1,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
- data_name='data', label_name='softmax_label', dtype='float32', **kwargs):
+ data_name='data', label_name='softmax_label', dtype='float32',
+ last_batch_handle='pad', **kwargs):
super(ImageIter, self).__init__()
assert path_imgrec or path_imglist or (isinstance(imglist, list))
assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported'
@@ -1129,7 +1136,6 @@ class ImageIter(io.DataIter):
self.batch_size = batch_size
self.data_shape = data_shape
self.label_width = label_width
-
self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
@@ -1149,22 +1155,49 @@ class ImageIter(io.DataIter):
else:
self.auglist = aug_list
self.cur = 0
+ self._allow_read = True
+ self.last_batch_handle = last_batch_handle
+ self.num_image = len(self.seq) if self.seq is not None else None
+ self._cache_data = None
+ self._cache_label = None
+ self._cache_idx = None
self.reset()
def reset(self):
"""Resets the iterator to the beginning of the data."""
- if self.shuffle:
+ if self.seq is not None and self.shuffle:
+ random.shuffle(self.seq)
+ if self.last_batch_handle != 'roll_over' or \
+ self._cache_data is None:
+ if self.imgrec is not None:
+ self.imgrec.reset()
+ self.cur = 0
+ if self._allow_read is False:
+ self._allow_read = True
+
+ def hard_reset(self):
+ """Resets the iterator and ignore roll over data"""
+ if self.seq is not None and self.shuffle:
random.shuffle(self.seq)
if self.imgrec is not None:
self.imgrec.reset()
self.cur = 0
+ self._allow_read = True
+ self._cache_data = None
+ self._cache_label = None
+ self._cache_idx = None
def next_sample(self):
"""Helper function for reading in next sample."""
+ if self._allow_read is False:
+ raise StopIteration
if self.seq is not None:
- if self.cur >= len(self.seq):
+ if self.cur < self.num_image:
+ idx = self.seq[self.cur]
+ else:
+ if self.last_batch_handle != 'discard':
+ self.cur = 0
raise StopIteration
- idx = self.seq[self.cur]
self.cur += 1
if self.imgrec is not None:
s = self.imgrec.read_idx(idx)
@@ -1179,17 +1212,16 @@ class ImageIter(io.DataIter):
else:
s = self.imgrec.read()
if s is None:
+ if self.last_batch_handle != 'discard':
+ self.imgrec.reset()
raise StopIteration
header, img = recordio.unpack(s)
return header.label, img
- def next(self):
- """Returns the next batch of data."""
+ def _batchify(self, batch_data, batch_label, start=0):
+ """Helper function for batchifying data"""
+ i = start
batch_size = self.batch_size
- c, h, w = self.data_shape
- batch_data = nd.empty((batch_size, c, h, w))
- batch_label = nd.empty(self.provide_label[0][1])
- i = 0
try:
while i < batch_size:
label, s = self.next_sample()
@@ -1207,8 +1239,47 @@ class ImageIter(io.DataIter):
except StopIteration:
if not i:
raise StopIteration
+ return i
- return io.DataBatch([batch_data], [batch_label], batch_size - i)
+ def next(self):
+ """Returns the next batch of data."""
+ batch_size = self.batch_size
+ c, h, w = self.data_shape
+ # if last batch data is rolled over
+ if self._cache_data is not None:
+ # check both the data and label have values
+ assert self._cache_label is not None, "_cache_label didn't have values"
+ assert self._cache_idx is not None, "_cache_idx didn't have values"
+ batch_data = self._cache_data
+ batch_label = self._cache_label
+ i = self._cache_idx
+ # clear the cache data
+ else:
+ batch_data = nd.empty((batch_size, c, h, w))
+ batch_label = nd.empty(self.provide_label[0][1])
+ i = self._batchify(batch_data, batch_label)
+ # calculate the padding
+ pad = batch_size - i
+ # handle padding for the last batch
+ if pad != 0:
+ if self.last_batch_handle == 'discard':
+ raise StopIteration
+ # if the option is 'roll_over', throw StopIteration and cache the data
+ elif self.last_batch_handle == 'roll_over' and \
+ self._cache_data is None:
+ self._cache_data = batch_data
+ self._cache_label = batch_label
+ self._cache_idx = i
+ raise StopIteration
+ else:
+ _ = self._batchify(batch_data, batch_label, i)
+ if self.last_batch_handle == 'pad':
+ self._allow_read = False
+ else:
+ self._cache_data = None
+ self._cache_label = None
+ self._cache_idx = None
+ return io.DataBatch([batch_data], [batch_label], pad=pad)
def check_data_shape(self, data_shape):
"""Checks if the input data shape is valid"""
@@ -1228,9 +1299,9 @@ class ImageIter(io.DataIter):
def locate():
"""Locate the image file/index if decode fails."""
if self.seq is not None:
- idx = self.seq[self.cur - 1]
+ idx = self.seq[(self.cur % self.num_image) - 1]
else:
- idx = self.cur - 1
+ idx = (self.cur % self.num_image) - 1
if self.imglist is not None:
_, fname = self.imglist[idx]
msg = "filename: {}".format(fname)
@@ -1245,7 +1316,6 @@ class ImageIter(io.DataIter):
def read_image(self, fname):
"""Reads an input image `fname` and returns the decoded raw bytes.
-
Example usage:
----------
>>> dataIter.read_image('Face.jpg') # returns decoded raw bytes.
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 9eec183..0df08af 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -80,7 +80,6 @@ class TestImage(unittest.TestCase):
image_read = mx.img.image.imread(img)
same(image.asnumpy(), image_read.asnumpy())
-
def test_imdecode(self):
try:
import cv2
@@ -130,29 +129,81 @@ class TestImage(unittest.TestCase):
mx.nd.array(mean), mx.nd.array(std))
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)
-
def test_imageiter(self):
def check_imageiter(dtype='float32'):
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
- test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
- path_root='', dtype=dtype)
- for _ in range(3):
- for batch in test_iter:
- pass
- test_iter.reset()
-
- # test with list file
fname = './data/test_imageiter.lst'
- file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
- for k, x in enumerate(TestImage.IMAGES)]
+ file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
+ for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')
+
+ test_list = ['imglist', 'path_imglist']
- test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
- path_root='', dtype=dtype)
- for batch in test_iter:
- pass
+ for test in test_list:
+ imglist = im_list if test == 'imglist' else None
+ path_imglist = fname if test == 'path_imglist' else None
+
+ test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype)
+ # test batch data shape
+ for _ in range(3):
+ for batch in test_iter:
+ assert batch.data[0].shape == (2, 3, 224, 224)
+ test_iter.reset()
+ # test last batch handle(discard)
+ test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
+ i = 0
+ for batch in test_iter:
+ i += 1
+ assert i == 5
+ # test last_batch_handle(pad)
+ test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
+ i = 0
+ for batch in test_iter:
+ if i == 0:
+ first_three_data = batch.data[0][:2]
+ if i == 5:
+ last_three_data = batch.data[0][1:]
+ i += 1
+ assert i == 6
+ assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
+ # test last_batch_handle(roll_over)
+ test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
+ i = 0
+ for batch in test_iter:
+ if i == 0:
+ first_image = batch.data[0][0]
+ i += 1
+ assert i == 5
+ test_iter.reset()
+ first_batch_roll_over = test_iter.next()
+ assert np.array_equal(
+ first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
+ assert first_batch_roll_over.pad == 2
+ # test that the iterator works properly after calling reset several times when last_batch_handle is roll_over
+ for _ in test_iter:
+ pass
+ test_iter.reset()
+ first_batch_roll_over_twice = test_iter.next()
+ assert np.array_equal(
+ first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
+ assert first_batch_roll_over_twice.pad == 1
+ # we've called next once
+ i = 1
+ for _ in test_iter:
+ i += 1
+ # test the third epoch with size 6
+ assert i == 6
+ # test shuffle option for sanity test
+ test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
+ for _ in test_iter:
+ pass
for dtype in ['int32', 'float32', 'int64', 'float64']:
check_imageiter(dtype)
@@ -183,7 +234,6 @@ class TestImage(unittest.TestCase):
for batch in test_iter:
pass
-
def test_image_detiter(self):
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')