You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/03/16 06:40:38 UTC

[incubator-mxnet] branch master updated: Fix multi worker (#10096)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 24a8b78  Fix multi worker (#10096)
24a8b78 is described below

commit 24a8b78ca4032a2bc9a1b66187e786e7122d8285
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Fri Mar 16 01:40:28 2018 -0500

    Fix multi worker (#10096)
    
    * improve multi worker iterator
    
    * debug
    
    * debug
    
    * fix python2
    
    * fix
    
    * update
    
    * fix race condition in cpu shared storage free
    
    * fix docstring
    
    * update
    
    * push workload in next
---
 python/mxnet/gluon/data/dataloader.py    | 119 ++++++++++++++++++++++---------
 src/storage/cpu_shared_storage_manager.h |   1 +
 2 files changed, 85 insertions(+), 35 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 6efaf35..7f09e28 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -119,6 +119,83 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
+class _MultiWorkerIter(object):
+    """Internal multi-worker iterator for DataLoader."""
+    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
+        assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
+        self._num_workers = num_workers
+        self._dataset = dataset
+        self._batchify_fn = batchify_fn
+        self._batch_sampler = batch_sampler
+        self._key_queue = Queue()
+        self._data_queue = Queue(2*self._num_workers)
+        self._data_buffer = {}
+        self._rcvd_idx = 0
+        self._sent_idx = 0
+        self._iter = iter(self._batch_sampler)
+        self._shutdown = False
+
+        workers = []
+        for _ in range(self._num_workers):
+            worker = multiprocessing.Process(
+                target=worker_loop,
+                args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn))
+            worker.daemon = True
+            worker.start()
+            workers.append(worker)
+
+        # pre-fetch
+        for _ in range(2 * self._num_workers):
+            self._push_next()
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+    def __del__(self):
+        self.shutdown()
+
+    def _push_next(self):
+        """Assign next batch workload to workers."""
+        r = next(self._iter, None)
+        if r is None:
+            return
+        self._key_queue.put((self._sent_idx, r))
+        self._sent_idx += 1
+
+    def __next__(self):
+        assert not self._shutdown, "call __next__ after shutdown is forbidden"
+        if self._rcvd_idx == self._sent_idx:
+            assert not self._data_buffer, "Data buffer should be empty at this moment"
+            self.shutdown()
+            raise StopIteration
+
+        while True:
+            if self._rcvd_idx in self._data_buffer:
+                batch = self._data_buffer.pop(self._rcvd_idx)
+                self._rcvd_idx += 1
+                self._push_next()
+                return batch
+            idx, batch = self._data_queue.get()
+            self._data_buffer[idx] = batch
+
+    def next(self):
+        return self.__next__()
+
+    def __iter__(self):
+        return self
+
+    def shutdown(self):
+        """Shutdown internal workers by pushing terminate signals."""
+        if not self._shutdown:
+            for _ in range(self._num_workers):
+                self._key_queue.put((None, None))
+            try:
+                while not self._data_queue.empty():
+                    self._data_queue.get()
+            except IOError:
+                pass
+            self._shutdown = True
+
 
 class DataLoader(object):
     """Loads data from a dataset and returns mini-batches of data.
@@ -187,7 +264,7 @@ class DataLoader(object):
                              "not be specified if batch_sampler is specified.")
 
         self._batch_sampler = batch_sampler
-        self._num_workers = num_workers
+        self._num_workers = num_workers if num_workers >= 0 else 0
         if batchify_fn is None:
             if num_workers > 0:
                 self._batchify_fn = default_mp_batchify_fn
@@ -198,41 +275,13 @@ class DataLoader(object):
 
     def __iter__(self):
         if self._num_workers == 0:
-            for batch in self._batch_sampler:
-                yield self._batchify_fn([self._dataset[idx] for idx in batch])
-            return
-
-        key_queue = Queue()
-        data_queue = Queue(2*self._num_workers)
-
-        workers = []
-        for _ in range(self._num_workers):
-            worker = multiprocessing.Process(
-                target=worker_loop,
-                args=(self._dataset, key_queue, data_queue, self._batchify_fn))
-            worker.daemon = True
-            worker.start()
-            workers.append(worker)
-
-        idx = 0
-        for idx, batch in enumerate(self._batch_sampler):
-            key_queue.put((idx, batch))
-        num_batches = idx + 1
-
-        data_buffer = {}
-        curr_idx = 0
-        for _ in range(num_batches):
-            idx, batch = data_queue.get()
-            data_buffer[idx] = batch
-            while curr_idx in data_buffer:
-                yield data_buffer.pop(curr_idx)
-                curr_idx += 1
-
-        for _ in range(self._num_workers):
-            key_queue.put((None, None))
+            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
+                                 for batch in self._batch_sampler]
+            return generator()
 
-        for worker in workers:
-            worker.join()
+        # multi-worker
+        return _MultiWorkerIter(self._num_workers, self._dataset,
+                                self._batchify_fn, self._batch_sampler)
 
     def __len__(self):
         return len(self._batch_sampler)
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
index 2a75a97..e2de30d 100644
--- a/src/storage/cpu_shared_storage_manager.h
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -69,6 +69,7 @@ class CPUSharedStorageManager final : public StorageManager {
 
   void Alloc(Storage::Handle* handle) override;
   void Free(Storage::Handle handle) override {
+    std::lock_guard<std::recursive_mutex> lock(mutex_);
     pool_.erase(handle.dptr);
     FreeImpl(handle);
   }

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.