You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/25 23:28:11 UTC
[incubator-mxnet] branch master updated: fix recordfile dataset with multi worker (#11370)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 72fa6a9 fix recordfile dataset with multi worker (#11370)
72fa6a9 is described below
commit 72fa6a9349143a08c23e3949932ca53218804163
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Mon Jun 25 16:28:05 2018 -0700
fix recordfile dataset with multi worker (#11370)
* fix recordfile dataset with multi worker
* fix another test
* fix
---
python/mxnet/gluon/data/dataloader.py | 1 +
python/mxnet/gluon/data/dataset.py | 13 +++++++++++--
tests/python/unittest/test_gluon_data.py | 12 ++++++++++++
tests/python/unittest/test_io.py | 3 ++-
4 files changed, 26 insertions(+), 3 deletions(-)
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 29b9b81..eb1eb41 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -151,6 +151,7 @@ def default_mp_batchify_fn(data):
def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
+ dataset._fork()
while True:
idx, samples = key_queue.get()
if idx is None:
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index bf5fa0a..13e2b57 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -94,6 +94,11 @@ class Dataset(object):
return fn(x)
return self.transform(base_fn, lazy)
+ def _fork(self):
+ """Protective operations required when launching multiprocess workers."""
+ # for non file descriptor related datasets, just skip
+ pass
+
class SimpleDataset(Dataset):
"""Simple Dataset wrapper for lists and arrays.
@@ -173,8 +178,12 @@ class RecordFileDataset(Dataset):
Path to rec file.
"""
def __init__(self, filename):
- idx_file = os.path.splitext(filename)[0] + '.idx'
- self._record = recordio.MXIndexedRecordIO(idx_file, filename, 'r')
+ self.idx_file = os.path.splitext(filename)[0] + '.idx'
+ self.filename = filename
+ self._fork()
+
+ def _fork(self):
+ self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')
def __getitem__(self, idx):
return self._record.read_idx(self._record.keys[idx])
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index ef2ba2a..0438044 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -73,6 +73,18 @@ def test_recordimage_dataset():
assert y.asscalar() == i
@with_seed()
+def test_recordimage_dataset_with_data_loader_multiworker():
+ # This test is pointless on Windows because Windows doesn't fork
+ if platform.system() != 'Windows':
+ recfile = prepare_record()
+ dataset = gluon.data.vision.ImageRecordDataset(recfile)
+ loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+ for i, (x, y) in enumerate(loader):
+ assert x.shape[0] == 1 and x.shape[3] == 3
+ assert y.asscalar() == i
+
+@with_seed()
def test_sampler():
seq_sampler = gluon.data.SequentialSampler(10)
assert list(seq_sampler) == list(range(10))
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index f0928a6..dbd327d 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -407,7 +407,8 @@ def test_ImageRecordIter_seed_augmentation():
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
- batch_size=3)
+ batch_size=3,
+ seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)