You are viewing a plain-text version of this content. The canonical link to it appeared as a hyperlink in the original page and is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/10 20:03:22 UTC

[incubator-mxnet] branch master updated: take custom dataset into consideration (#12093)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2fc4248  take custom dataset into consideration (#12093)
2fc4248 is described below

commit 2fc4248550c325b02a76f67b1cec32161a32dc4f
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Fri Aug 10 13:03:11 2018 -0700

    take custom dataset into consideration (#12093)
---
 python/mxnet/gluon/data/dataloader.py    | 3 ++-
 tests/python/unittest/test_gluon_data.py | 7 +++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 13ab544..e0b6aec 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -160,7 +160,8 @@ def _as_in_context(data, ctx):
 
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
-    dataset._fork()
+    if hasattr(dataset, '_fork') and callable(dataset._fork):
+        dataset._fork()
     while True:
         idx, samples = key_queue.get()
         if idx is None:
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 4dc4f3a..53ce600 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -116,6 +116,13 @@ def test_image_folder_dataset():
     assert dataset.synsets == ['test_images']
     assert len(dataset.items) == 16
 
+@with_seed()
+def test_list_dataset():  # regression test: DataLoader over a plain list, which has no _fork method
+    for num_worker in range(0, 3):  # covers num_workers = 0, 1 and 2
+        data = mx.gluon.data.DataLoader([([1,2], 0), ([3, 4], 1)], batch_size=1, num_workers=num_worker)
+        for d, l in data:  # fully iterating is the check: loading must not raise
+            pass
+
 
 class Dataset(gluon.data.Dataset):
     def __len__(self):