You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/16 06:40:30 UTC

[GitHub] piiswrong closed pull request #10096: Fix multi worker

piiswrong closed pull request #10096: Fix multi worker
URL: https://github.com/apache/incubator-mxnet/pull/10096
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff of a foreign (fork) pull request once it is merged, the
diff is reproduced below for the sake of provenance:

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 6efaf35e881..7f09e286742 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -119,6 +119,83 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
+class _MultiWorkerIter(object):
+    """Internal multi-worker iterator for DataLoader."""
+    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
+        assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
+        self._num_workers = num_workers
+        self._dataset = dataset
+        self._batchify_fn = batchify_fn
+        self._batch_sampler = batch_sampler
+        self._key_queue = Queue()  # (idx, sample indices) work items consumed by worker_loop
+        self._data_queue = Queue(2*self._num_workers)  # (idx, batch) results put by worker_loop
+        self._data_buffer = {}  # idx -> batch, holds results that arrived out of order
+        self._rcvd_idx = 0  # idx of the next batch to hand back to the caller
+        self._sent_idx = 0  # idx that will be assigned to the next work item
+        self._iter = iter(self._batch_sampler)
+        self._shutdown = False  # set to True by shutdown()
+
+        workers = []  # NOTE(review): local only; never read again after this loop
+        for _ in range(self._num_workers):
+            worker = multiprocessing.Process(
+                target=worker_loop,
+                args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn))
+            worker.daemon = True
+            worker.start()
+            workers.append(worker)
+
+        # pre-fetch: request up to 2 batches per worker before the first __next__ call
+        for _ in range(2 * self._num_workers):
+            self._push_next()
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+    def __del__(self):
+        self.shutdown()  # NOTE(review): if __init__ failed early, _shutdown is unset here
+
+    def _push_next(self):
+        """Assign next batch workload to workers."""
+        r = next(self._iter, None)  # None once the batch sampler is exhausted
+        if r is None:
+            return
+        self._key_queue.put((self._sent_idx, r))
+        self._sent_idx += 1
+
+    def __next__(self):
+        assert not self._shutdown, "call __next__ after shutdown is forbidden"
+        if self._rcvd_idx == self._sent_idx:  # every requested batch has been returned
+            assert not self._data_buffer, "Data buffer should be empty at this moment"
+            self.shutdown()
+            raise StopIteration
+
+        while True:  # keep receiving until the batch numbered _rcvd_idx is available
+            if self._rcvd_idx in self._data_buffer:
+                batch = self._data_buffer.pop(self._rcvd_idx)
+                self._rcvd_idx += 1
+                self._push_next()  # request one more batch for each batch consumed
+                return batch
+            idx, batch = self._data_queue.get()  # next result from any worker
+            self._data_buffer[idx] = batch
+
+    def next(self):
+        return self.__next__()
+
+    def __iter__(self):
+        return self
+
+    def shutdown(self):
+        """Shutdown internal workers by pushing terminate signals."""
+        if not self._shutdown:
+            for _ in range(self._num_workers):
+                self._key_queue.put((None, None))  # one terminate signal per worker
+            try:
+                while not self._data_queue.empty():  # discard results never consumed
+                    self._data_queue.get()
+            except IOError:
+                pass  # best-effort drain: an IOError here is deliberately ignored
+            self._shutdown = True
+
 
 class DataLoader(object):
     """Loads data from a dataset and returns mini-batches of data.
@@ -187,7 +264,7 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                              "not be specified if batch_sampler is specified.")
 
         self._batch_sampler = batch_sampler
-        self._num_workers = num_workers
+        self._num_workers = num_workers if num_workers >= 0 else 0
         if batchify_fn is None:
             if num_workers > 0:
                 self._batchify_fn = default_mp_batchify_fn
@@ -198,41 +275,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
 
     def __iter__(self):
         if self._num_workers == 0:
-            for batch in self._batch_sampler:
-                yield self._batchify_fn([self._dataset[idx] for idx in batch])
-            return
-
-        key_queue = Queue()
-        data_queue = Queue(2*self._num_workers)
-
-        workers = []
-        for _ in range(self._num_workers):
-            worker = multiprocessing.Process(
-                target=worker_loop,
-                args=(self._dataset, key_queue, data_queue, self._batchify_fn))
-            worker.daemon = True
-            worker.start()
-            workers.append(worker)
-
-        idx = 0
-        for idx, batch in enumerate(self._batch_sampler):
-            key_queue.put((idx, batch))
-        num_batches = idx + 1
-
-        data_buffer = {}
-        curr_idx = 0
-        for _ in range(num_batches):
-            idx, batch = data_queue.get()
-            data_buffer[idx] = batch
-            while curr_idx in data_buffer:
-                yield data_buffer.pop(curr_idx)
-                curr_idx += 1
-
-        for _ in range(self._num_workers):
-            key_queue.put((None, None))
+            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
+                                 for batch in self._batch_sampler]
+            return generator()
 
-        for worker in workers:
-            worker.join()
+        # multi-worker
+        return _MultiWorkerIter(self._num_workers, self._dataset,
+                                self._batchify_fn, self._batch_sampler)
 
     def __len__(self):
         return len(self._batch_sampler)
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
index 2a75a97df57..e2de30db923 100644
--- a/src/storage/cpu_shared_storage_manager.h
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -69,6 +69,7 @@ class CPUSharedStorageManager final : public StorageManager {
 
   void Alloc(Storage::Handle* handle) override;
   void Free(Storage::Handle handle) override {
+    std::lock_guard<std::recursive_mutex> lock(mutex_);
     pool_.erase(handle.dptr);
     FreeImpl(handle);
   }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services