Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/11/26 05:41:32 UTC

[incubator-mxnet] branch master updated: Improving multi-processing reliability for gluon DataLoader (#13318)

This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b1e7a5  Improving multi-processing reliability for gluon DataLoader (#13318)
7b1e7a5 is described below

commit 7b1e7a5040aee0adffb826a2b4acf1380c93b07b
Author: Yuting Zhang <zh...@yuting.link>
AuthorDate: Sun Nov 25 21:41:19 2018 -0800

    Improving multi-processing reliability for gluon DataLoader (#13318)
    
    * improving multi-processing reliability for gluon dataloader
    
    I found some multi-processing-related issues in the Gluon DataLoader.
    
     1) Each time a _MultiWorkerIter shuts down, it can leave dangling worker processes behind. The shutdown mechanism does not guarantee that all worker processes are terminated, so after running for several epochs, more and more dangling processes accumulate.
    
      This problem rarely happens during training, where there is a decent time interval between prefetching the last batch and shutting down the _MultiWorkerIter.
      But the problem frequently happens 1) when I stop the iterator before the end of an epoch, and 2) when I use the DataLoader for a data loading service and load data as fast as possible. In both cases, the time interval between the most recent data prefetch and the iterator shutdown is short. I suspect the _MultiWorkerIter is unable to shut down properly while data prefetching is still active.
    
      To fix this, I explicitly terminate the worker processes inside the shutdown function.
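
      Putting the pieces of the patch together, the new shutdown sequence is roughly as follows (a consolidated, simplified sketch of the hunks in the diff below, not a drop-in implementation):

          def shutdown(self):
              if not self._shutdown:
                  # stop the fetcher thread first so it is not left blocked on the data queue
                  self._data_queue.put((None, None))
                  self._fetcher.join()
                  # ask each worker to exit, then force-terminate any that remain alive
                  for _ in range(self._num_workers):
                      self._key_queue.put((None, None))
                  for w in self._workers:
                      if w.is_alive():
                          w.terminate()
                  self._shutdown = True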
    
      2) When loading data fast (again, mostly during testing and data serving), there seems to be a risk of a data race. The _MultiWorkerIter caches prefetched data in a dict, but the dict does not appear to be safe for concurrent insertion and deletion of elements, so occasionally data can go missing from the dict.
    
      To prevent this, I use a scoped lock to guard access to the dict.
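
      Concretely, the fetcher thread and the consumer now acquire the same threading.Lock around every access to the buffer dict. A simplified, self-contained sketch (variable names follow the patch in the diff below; the index and batch values are placeholders for illustration):

          import threading

          data_buffer = {}
          data_buffer_lock = threading.Lock()

          # producer side (fetcher thread): insert a prefetched batch
          idx, batch = 0, "batch-0"   # placeholder values for illustration
          with data_buffer_lock:
              data_buffer[idx] = batch

          # consumer side (__next__): pop the batch once it is ready
          rcvd_idx = 0
          with data_buffer_lock:
              batch = data_buffer.pop(rcvd_idx)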
    
    * do not wait for the workers to join, and kill any workers that are still alive as soon as possible
---
 python/mxnet/gluon/data/dataloader.py | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 50e2ad9..86cb835f 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -189,7 +189,7 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
-def fetcher_loop(data_queue, data_buffer, pin_memory=False):
+def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
     """Fetcher loop for fetching data from queue and put in reorder dict."""
     while True:
         idx, batch = data_queue.get()
@@ -199,7 +199,11 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False):
             batch = _as_in_context(batch, context.cpu_pinned())
         else:
             batch = _as_in_context(batch, context.cpu())
-        data_buffer[idx] = batch
+        if data_buffer_lock is not None:
+            with data_buffer_lock:
+                data_buffer[idx] = batch
+        else:
+            data_buffer[idx] = batch
 
 
 class _MultiWorkerIter(object):
@@ -213,7 +217,10 @@ class _MultiWorkerIter(object):
         self._batch_sampler = batch_sampler
         self._key_queue = Queue()
         self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue()
+
         self._data_buffer = {}
+        self._data_buffer_lock = threading.Lock()
+
         self._rcvd_idx = 0
         self._sent_idx = 0
         self._iter = iter(self._batch_sampler)
@@ -227,10 +234,11 @@ class _MultiWorkerIter(object):
             worker.daemon = True
             worker.start()
             workers.append(worker)
+        self._workers = workers
 
         self._fetcher = threading.Thread(
             target=fetcher_loop,
-            args=(self._data_queue, self._data_buffer, pin_memory))
+            args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
         self._fetcher.daemon = True
         self._fetcher.start()
 
@@ -261,7 +269,8 @@ class _MultiWorkerIter(object):
 
         while True:
             if self._rcvd_idx in self._data_buffer:
-                batch = self._data_buffer.pop(self._rcvd_idx)
+                with self._data_buffer_lock:
+                    batch = self._data_buffer.pop(self._rcvd_idx)
                 self._rcvd_idx += 1
                 self._push_next()
                 return batch
@@ -275,9 +284,18 @@ class _MultiWorkerIter(object):
     def shutdown(self):
         """Shutdown internal workers by pushing terminate signals."""
         if not self._shutdown:
+            # send shutdown signal to the fetcher thread and join it first
+            # Remark: the fetcher needs to be joined prior to the workers;
+            #         otherwise, the fetcher may fail to get data
+            self._data_queue.put((None, None))
+            self._fetcher.join()
+            # send shutdown signal to all worker processes
             for _ in range(self._num_workers):
                 self._key_queue.put((None, None))
-            self._data_queue.put((None, None))
+            # force shut down any alive worker processes
+            for w in self._workers:
+                if w.is_alive():
+                    w.terminate()
             self._shutdown = True