You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/12/08 01:06:19 UTC
[incubator-mxnet] branch master updated: fix the situation where
idx didn't align with rec (#13550)
This is an automated email from the ASF dual-hosted git repository.
zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 186a746 fix the situation where idx didn't align with rec (#13550)
186a746 is described below
commit 186a746e557d8ef9551f52c6e4c1175394c323c0
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Dec 7 17:06:05 2018 -0800
fix the situation where idx didn't align with rec (#13550)
minor fix the image.py
add last_batch_handle for ImageDetIter
remove the label type
refactor the imageiter unit test
fix the trailing whitespace
fix coding style
add new line
move helper function to the top of the file
---
python/mxnet/image/detection.py | 64 +++++++++++--
python/mxnet/image/image.py | 5 +-
tests/python/unittest/test_image.py | 184 ++++++++++++++++++++----------------
3 files changed, 157 insertions(+), 96 deletions(-)
diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index b27917c..d5b5eca 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -658,19 +658,26 @@ class ImageDetIter(ImageIter):
Data name for provided symbols.
label_name : str
Name for detection labels
+ last_batch_handle : str, optional
+ How to handle the last batch.
+ This parameter can be 'pad' (default), 'discard' or 'roll_over'.
+ If 'pad', the last batch will be padded with data starting from the beginning.
+ If 'discard', the last batch will be discarded.
+ If 'roll_over', the remaining elements will be rolled over to the next iteration.
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
"""
def __init__(self, batch_size, data_shape,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
- data_name='data', label_name='label', **kwargs):
+ data_name='data', label_name='label', last_batch_handle='pad', **kwargs):
super(ImageDetIter, self).__init__(batch_size=batch_size, data_shape=data_shape,
path_imgrec=path_imgrec, path_imglist=path_imglist,
path_root=path_root, path_imgidx=path_imgidx,
shuffle=shuffle, part_index=part_index,
num_parts=num_parts, aug_list=[], imglist=imglist,
- data_name=data_name, label_name=label_name)
+ data_name=data_name, label_name=label_name,
+ last_batch_handle=last_batch_handle)
if aug_list is None:
self.auglist = CreateDetAugmenter(data_shape, **kwargs)
@@ -751,14 +758,10 @@ class ImageDetIter(ImageIter):
self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)]
self.label_shape = label_shape
- def next(self):
- """Override the function for returning next batch."""
+ def _batchify(self, batch_data, batch_label, start=0):
+ """Override the helper function for batchifying data"""
+ i = start
batch_size = self.batch_size
- c, h, w = self.data_shape
- batch_data = nd.zeros((batch_size, c, h, w))
- batch_label = nd.empty(self.provide_label[0][1])
- batch_label[:] = -1
- i = 0
try:
while i < batch_size:
label, s = self.next_sample()
@@ -783,7 +786,48 @@ class ImageDetIter(ImageIter):
if not i:
raise StopIteration
- return io.DataBatch([batch_data], [batch_label], batch_size - i)
+ return i
+
+ def next(self):
+ """Override the function for returning next batch."""
+ batch_size = self.batch_size
+ c, h, w = self.data_shape
+ # if last batch data is rolled over
+ if self._cache_data is not None:
+ # check both the data and label have values
+ assert self._cache_label is not None, "_cache_label didn't have values"
+ assert self._cache_idx is not None, "_cache_idx didn't have values"
+ batch_data = self._cache_data
+ batch_label = self._cache_label
+ i = self._cache_idx
+ else:
+ batch_data = nd.zeros((batch_size, c, h, w))
+ batch_label = nd.empty(self.provide_label[0][1])
+ batch_label[:] = -1
+ i = self._batchify(batch_data, batch_label)
+ # calculate the padding
+ pad = batch_size - i
+ # handle padding for the last batch
+ if pad != 0:
+ if self.last_batch_handle == 'discard':
+ raise StopIteration
+ # if the option is 'roll_over', throw StopIteration and cache the data
+ elif self.last_batch_handle == 'roll_over' and \
+ self._cache_data is None:
+ self._cache_data = batch_data
+ self._cache_label = batch_label
+ self._cache_idx = i
+ raise StopIteration
+ else:
+ _ = self._batchify(batch_data, batch_label, i)
+ if self.last_batch_handle == 'pad':
+ self._allow_read = False
+ else:
+ self._cache_data = None
+ self._cache_label = None
+ self._cache_idx = None
+
+ return io.DataBatch([batch_data], [batch_label], pad=pad)
def augmentation_transform(self, data, label): # pylint: disable=arguments-differ
"""Override Transforms input data with specified augmentations."""
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index c9a457f..9c2a1cb 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -1145,7 +1145,7 @@ class ImageIter(io.DataIter):
self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
- elif shuffle or num_parts > 1:
+ elif shuffle or num_parts > 1 or path_imgidx:
assert self.imgidx is not None
self.seq = self.imgidx
else:
@@ -1261,7 +1261,7 @@ class ImageIter(io.DataIter):
i = self._cache_idx
# clear the cache data
else:
- batch_data = nd.empty((batch_size, c, h, w))
+ batch_data = nd.zeros((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = self._batchify(batch_data, batch_label)
# calculate the padding
@@ -1285,6 +1285,7 @@ class ImageIter(io.DataIter):
self._cache_data = None
self._cache_label = None
self._cache_idx = None
+
return io.DataBatch([batch_data], [batch_label], pad=pad)
def check_data_shape(self, data_shape):
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 4f66823..4063027 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -25,6 +25,7 @@ import unittest
from nose.tools import raises
+
def _get_data(url, dirname):
import os, tarfile
download(url, dirname=dirname, overwrite=False)
@@ -50,6 +51,62 @@ def _generate_objects():
label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
return [2, 5] + label
+def _test_imageiter_last_batch(imageiter_list, assert_data_shape):
+ test_iter = imageiter_list[0]
+ # test batch data shape
+ for _ in range(3):
+ for batch in test_iter:
+ assert batch.data[0].shape == assert_data_shape
+ test_iter.reset()
+ # test last_batch_handle(discard)
+ test_iter = imageiter_list[1]
+ i = 0
+ for batch in test_iter:
+ i += 1
+ assert i == 5
+ # test last_batch_handle(pad)
+ test_iter = imageiter_list[2]
+ i = 0
+ for batch in test_iter:
+ if i == 0:
+ first_three_data = batch.data[0][:2]
+ if i == 5:
+ last_three_data = batch.data[0][1:]
+ i += 1
+ assert i == 6
+ assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
+ # test last_batch_handle(roll_over)
+ test_iter = imageiter_list[3]
+ i = 0
+ for batch in test_iter:
+ if i == 0:
+ first_image = batch.data[0][0]
+ i += 1
+ assert i == 5
+ test_iter.reset()
+ first_batch_roll_over = test_iter.next()
+ assert np.array_equal(
+ first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
+ assert first_batch_roll_over.pad == 2
+ # test that the iterator works properly after calling reset several times when last_batch_handle is roll_over
+ for _ in test_iter:
+ pass
+ test_iter.reset()
+ first_batch_roll_over_twice = test_iter.next()
+ assert np.array_equal(
+ first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
+ assert first_batch_roll_over_twice.pad == 1
+ # we've called next once
+ i = 1
+ for _ in test_iter:
+ i += 1
+ # test the third epoch with size 6
+ assert i == 6
+ # test shuffle option for sanity test
+ test_iter = imageiter_list[4]
+ for _ in test_iter:
+ pass
+
class TestImage(unittest.TestCase):
IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz"
@@ -151,86 +208,32 @@ class TestImage(unittest.TestCase):
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)
def test_imageiter(self):
- def check_imageiter(dtype='float32'):
- im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
- fname = './data/test_imageiter.lst'
- file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
- for k, x in enumerate(TestImage.IMAGES)]
- with open(fname, 'w') as f:
- for line in file_list:
- f.write(line + '\n')
-
- test_list = ['imglist', 'path_imglist']
+ im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
+ fname = './data/test_imageiter.lst'
+ file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
+ for k, x in enumerate(TestImage.IMAGES)]
+ with open(fname, 'w') as f:
+ for line in file_list:
+ f.write(line + '\n')
+ test_list = ['imglist', 'path_imglist']
+ for dtype in ['int32', 'float32', 'int64', 'float64']:
for test in test_list:
imglist = im_list if test == 'imglist' else None
path_imglist = fname if test == 'path_imglist' else None
-
- test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype)
- # test batch data shape
- for _ in range(3):
- for batch in test_iter:
- assert batch.data[0].shape == (2, 3, 224, 224)
- test_iter.reset()
- # test last batch handle(discard)
- test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
- i = 0
- for batch in test_iter:
- i += 1
- assert i == 5
- # test last_batch_handle(pad)
- test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
- i = 0
- for batch in test_iter:
- if i == 0:
- first_three_data = batch.data[0][:2]
- if i == 5:
- last_three_data = batch.data[0][1:]
- i += 1
- assert i == 6
- assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
- # test last_batch_handle(roll_over)
- test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
- i = 0
- for batch in test_iter:
- if i == 0:
- first_image = batch.data[0][0]
- i += 1
- assert i == 5
- test_iter.reset()
- first_batch_roll_over = test_iter.next()
- assert np.array_equal(
- first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
- assert first_batch_roll_over.pad == 2
- # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over
- for _ in test_iter:
- pass
- test_iter.reset()
- first_batch_roll_over_twice = test_iter.next()
- assert np.array_equal(
- first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
- assert first_batch_roll_over_twice.pad == 1
- # we've called next once
- i = 1
- for _ in test_iter:
- i += 1
- # test the third epoch with size 6
- assert i == 6
- # test shuffle option for sanity test
- test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
- path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
- for _ in test_iter:
- pass
-
- for dtype in ['int32', 'float32', 'int64', 'float64']:
- check_imageiter(dtype)
-
- # test with default dtype
- check_imageiter()
+ imageiter_list = [
+ mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard'),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad'),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over'),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
+ path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
+ ]
+ _test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224))
@with_seed()
def test_augmenters(self):
@@ -259,16 +262,20 @@ class TestImage(unittest.TestCase):
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
for _ in range(3):
- for batch in det_iter:
+ for _ in det_iter:
pass
- det_iter.reset()
-
+ det_iter.reset()
val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
det_iter = val_iter.sync_label_shape(det_iter)
assert det_iter.data_shape == val_iter.data_shape
assert det_iter.label_shape == val_iter.label_shape
- # test file list
+ # test the case where the number of images is not divisible by batch_size
+ det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list, path_root='')
+ for _ in det_iter:
+ pass
+
+ # test file list with last batch handle
fname = './data/test_imagedetiter.lst'
im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
@@ -276,10 +283,19 @@ class TestImage(unittest.TestCase):
line = '\t'.join([str(k) for k in line])
f.write(line + '\n')
- det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
- path_root='')
- for batch in det_iter:
- pass
+ imageiter_list = [
+ mx.image.ImageDetIter(2, (3, 400, 400),
+ path_imglist=fname, path_root=''),
+ mx.image.ImageDetIter(3, (3, 400, 400),
+ path_imglist=fname, path_root='', last_batch_handle='discard'),
+ mx.image.ImageDetIter(3, (3, 400, 400),
+ path_imglist=fname, path_root='', last_batch_handle='pad'),
+ mx.image.ImageDetIter(3, (3, 400, 400),
+ path_imglist=fname, path_root='', last_batch_handle='roll_over'),
+ mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True,
+ path_imglist=fname, path_root='', last_batch_handle='pad')
+ ]
+ _test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400))
def test_det_augmenters(self):
# only test if all augmenters will work