You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@mxnet.apache.org by wk...@apache.org on 2019/06/12 06:41:15 UTC

[incubator-mxnet] branch master updated: Fixed a bug in Gluon DataLoader. (#15195)

This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e20094  Fixed a bug in Gluon DataLoader. (#15195)
2e20094 is described below

commit 2e2009457abeec70afc0ffebb74ab2815c376bc6
Author: Chandana Satya Prakash <ch...@gmail.com>
AuthorDate: Wed Jun 12 02:40:54 2019 -0400

    Fixed a bug in Gluon DataLoader. (#15195)
    
    * Fixed a bug in Gluon DataLoader.
        Issue: https://github.com/apache/incubator-mxnet/issues/15025
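        Fix: Extended the worker pool's lifetime to match its iterators by passing a reference to the DataLoader into the multi-worker iterator, so the DataLoader is not garbage collected while the iterator is still in use.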
    
    * Fixed a bug in Gluon DataLoader.
        Issue: https://github.com/apache/incubator-mxnet/issues/15025
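        Fix: Extended the worker pool's lifetime to match its iterators by passing a reference to the DataLoader into the multi-worker iterator, so the DataLoader is not garbage collected while the iterator is still in use.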
    
    * Fixed a bug in Gluon DataLoader.
        Issue: https://github.com/apache/incubator-mxnet/issues/15025
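        Fix: Extended the worker pool's lifetime to match its iterators by passing a reference to the DataLoader into the multi-worker iterator, so the DataLoader is not garbage collected while the iterator is still in use.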
    
    * Fixed a bug in Gluon DataLoader.
        Issue: https://github.com/apache/incubator-mxnet/issues/15025
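        Fix: Extended the worker pool's lifetime to match its iterators by passing a reference to the DataLoader into the multi-worker iterator, so the DataLoader is not garbage collected while the iterator is still in use.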
---
 python/mxnet/gluon/data/dataloader.py    |  6 ++++--
 tests/python/unittest/test_gluon_data.py | 25 +++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 934f2d5..accd968 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -409,7 +409,7 @@ def _thread_worker_fn(samples, batchify_fn, dataset):
 class _MultiWorkerIter(object):
     """Internal multi-worker iterator for DataLoader."""
     def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
-                 pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None):
+                 pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, data_loader=None):
         self._worker_pool = worker_pool
         self._batchify_fn = batchify_fn
         self._batch_sampler = batch_sampler
@@ -421,6 +421,7 @@ class _MultiWorkerIter(object):
         self._pin_memory = pin_memory
         self._pin_device_id = pin_device_id
         self._dataset = dataset
+        self._data_loader = data_loader
         # pre-fetch
         for _ in range(prefetch):
             self._push_next()
@@ -582,7 +583,8 @@ class DataLoader(object):
                                 pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
                                 worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
                                 prefetch=self._prefetch,
-                                dataset=self._dataset if self._thread_pool else None)
+                                dataset=self._dataset if self._thread_pool else None,
+                                data_loader=self)
 
     def __len__(self):
         return len(self._batch_sampler)
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 1939de8..58e241b 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -28,6 +28,7 @@ from mxnet.gluon.data import DataLoader
 import mxnet.ndarray as nd
 from mxnet import context
 from mxnet.gluon.data.dataset import Dataset
+from mxnet.gluon.data.dataset import ArrayDataset
 
 @with_seed()
 def test_array_dataset():
@@ -279,6 +280,30 @@ def test_dataloader_context():
     for _, x in enumerate(loader3):
         assert x.context == context.cpu_pinned(custom_dev_id)
 
+def batchify(a):
+    return a
+
+def test_dataloader_scope():
+    """
+    Bug: Gluon DataLoader terminates the process pool early while
+    _MultiWorkerIter is operating on the pool.
+
+    Tests that DataLoader is not garbage collected while the iterator is
+    in use.
+    """
+    args = {'num_workers': 1, 'batch_size': 2}
+    dataset = nd.ones(5)
+    iterator = iter(DataLoader(
+            dataset,
+            batchify_fn=batchify,
+            **args
+        )
+    )
+
+    item = next(iterator)
+
+    assert item is not None
+
 
 if __name__ == '__main__':
     import nose