Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/25 23:28:11 UTC

[incubator-mxnet] branch master updated: fix recordfile dataset with multi worker (#11370)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 72fa6a9  fix recordfile dataset with multi worker (#11370)
72fa6a9 is described below

commit 72fa6a9349143a08c23e3949932ca53218804163
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Mon Jun 25 16:28:05 2018 -0700

    fix recordfile dataset with multi worker (#11370)
    
    * fix recordfile dataset with multi worker
    
    * fix another test
    
    * fix
---
 python/mxnet/gluon/data/dataloader.py    |  1 +
 python/mxnet/gluon/data/dataset.py       | 13 +++++++++++--
 tests/python/unittest/test_gluon_data.py | 12 ++++++++++++
 tests/python/unittest/test_io.py         |  3 ++-
 4 files changed, 26 insertions(+), 3 deletions(-)
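
Why the fix is needed: RecordFileDataset keeps an open MXIndexedRecordIO
handle, and a multiprocessing DataLoader forks its workers, so every worker
inherits the parent's file descriptor. Concurrent reads through that shared
descriptor interleave seeks and can return corrupted records. The commit adds
a _fork() hook that each worker calls to reopen the file with a private
descriptor. A minimal, self-contained sketch of the pattern (the class and
file names here are hypothetical, not MXNet code):

    import multiprocessing as mp
    import os

    class FileBackedDataset:
        """Toy stand-in for a dataset backed by an open file handle."""
        def __init__(self, path):
            self.path = path
            self._fork()

        def _fork(self):
            # Reopen so this process owns a private descriptor and offset;
            # a descriptor inherited across fork() shares its offset with
            # the parent and the other workers.
            self._fh = open(self.path, 'rb')

        def read_at(self, offset, size):
            self._fh.seek(offset)
            return self._fh.read(size)

    def worker(dataset, offset):
        dataset._fork()  # same idea as the hook added in this commit
        print(dataset.read_at(offset, 2))

    if __name__ == '__main__':
        with open('toy.bin', 'wb') as f:
            f.write(b'abcdefgh')
        ds = FileBackedDataset('toy.bin')
        # fork is POSIX-only, hence the Windows skip in the test below
        ctx = mp.get_context('fork')
        procs = [ctx.Process(target=worker, args=(ds, i * 2)) for i in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        os.remove('toy.bin')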

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 29b9b81..eb1eb41 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -151,6 +151,7 @@ def default_mp_batchify_fn(data):
 
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
+    dataset._fork()
     while True:
         idx, samples = key_queue.get()
         if idx is None:
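
The hunk above places the call at the top of the worker loop, so a worker
reopens its file-backed resources exactly once, before serving any batch.
A simplified sketch of the surrounding loop (paraphrased, not the verbatim
MXNet source):

    def worker_loop(dataset, key_queue, data_queue, batchify_fn):
        """Runs inside each forked DataLoader worker."""
        dataset._fork()  # reopen file descriptors owned by the parent
        while True:
            idx, samples = key_queue.get()
            if idx is None:  # sentinel: the parent is shutting us down
                break
            batch = batchify_fn([dataset[i] for i in samples])
            data_queue.put((idx, batch))
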
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index bf5fa0a..13e2b57 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -94,6 +94,11 @@ class Dataset(object):
             return fn(x)
         return self.transform(base_fn, lazy)
 
+    def _fork(self):
+        """Protective operations required when launching multiprocess workers."""
+        # for non file descriptor related datasets, just skip
+        pass
+
 
 class SimpleDataset(Dataset):
     """Simple Dataset wrapper for lists and arrays.
@@ -173,8 +178,12 @@ class RecordFileDataset(Dataset):
         Path to rec file.
     """
     def __init__(self, filename):
-        idx_file = os.path.splitext(filename)[0] + '.idx'
-        self._record = recordio.MXIndexedRecordIO(idx_file, filename, 'r')
+        self.idx_file = os.path.splitext(filename)[0] + '.idx'
+        self.filename = filename
+        self._fork()
+
+    def _fork(self):
+        self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')
 
     def __getitem__(self, idx):
         return self._record.read_idx(self._record.keys[idx])
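
With this change a RecordFileDataset (and subclasses such as
ImageRecordDataset) reopens its .rec/.idx pair inside every worker, so
multi-worker loading works out of the box. Typical usage, with placeholder
file names:

    from mxnet import gluon

    # 'data.rec' needs its matching 'data.idx' in the same directory.
    dataset = gluon.data.vision.ImageRecordDataset('data.rec')
    loader = gluon.data.DataLoader(dataset, batch_size=32, num_workers=4)
    for x, y in loader:
        pass  # each worker has reopened the record file via _fork()
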
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index ef2ba2a..0438044 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -73,6 +73,18 @@ def test_recordimage_dataset():
         assert y.asscalar() == i
 
 @with_seed()
+def test_recordimage_dataset_with_data_loader_multiworker():
+    # This test is pointless on Windows because Windows doesn't fork
+    if platform.system() != 'Windows':
+        recfile = prepare_record()
+        dataset = gluon.data.vision.ImageRecordDataset(recfile)
+        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+        for i, (x, y) in enumerate(loader):
+            assert x.shape[0] == 1 and x.shape[3] == 3
+            assert y.asscalar() == i
+
+@with_seed()
 def test_sampler():
     seq_sampler = gluon.data.SequentialSampler(10)
     assert list(seq_sampler) == list(range(10))
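
The prepare_record() helper used by the new test builds a small .rec/.idx
pair of labeled images. For readers unfamiliar with the format, such a pair
can be written directly with mxnet.recordio; a rough sketch (file names and
labels are illustrative, and pack_img requires OpenCV):

    import mxnet as mx
    import numpy as np

    record = mx.recordio.MXIndexedRecordIO('sample.idx', 'sample.rec', 'w')
    for i in range(5):
        img = np.random.randint(0, 255, (32, 32, 3)).astype('uint8')
        header = mx.recordio.IRHeader(0, float(i), i, 0)  # flag, label, id, id2
        record.write_idx(i, mx.recordio.pack_img(header, img, quality=95,
                                                 img_fmt='.jpg'))
    record.close()
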
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index f0928a6..dbd327d 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -407,7 +407,8 @@ def test_ImageRecordIter_seed_augmentation():
         mean_img="data/cifar/cifar10_mean.bin",
         shuffle=False,
         data_shape=(3, 28, 28),
-        batch_size=3)
+        batch_size=3,
+        seed_aug=seed_aug)
     batch = dataiter.next()
     data = batch.data[0].asnumpy().astype(np.uint8)
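
The test_io.py change is the "fix another test" item from the commit message:
this second ImageRecordIter in test_ImageRecordIter_seed_augmentation was
constructed without its seed_aug argument, so the two passes being compared
were not actually sharing an augmentation seed. Two iterators given the same
seed_aug should produce identical random augmentations; roughly (paths are
placeholders for the prepared CIFAR files):

    import mxnet as mx
    import numpy as np

    def first_batch(seed_aug):
        it = mx.io.ImageRecordIter(
            path_imgrec="data/cifar/train.rec",
            mean_img="data/cifar/cifar10_mean.bin",
            shuffle=False,
            data_shape=(3, 28, 28),
            batch_size=3,
            rand_crop=True,       # a seeded random augmentation
            seed_aug=seed_aug)
        return it.next().data[0].asnumpy()

    assert np.array_equal(first_batch(172), first_batch(172))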