You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/03/16 06:40:38 UTC
[incubator-mxnet] branch master updated: Fix multi worker (#10096)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 24a8b78 Fix multi worker (#10096)
24a8b78 is described below
commit 24a8b78ca4032a2bc9a1b66187e786e7122d8285
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Fri Mar 16 01:40:28 2018 -0500
Fix multi worker (#10096)
* improve multi worker iterator
* debug
* debug
* fix python2
* fix
* update
* fix race condition in cpu shared storage free
* fix docstring
* update
* push workload in next
---
python/mxnet/gluon/data/dataloader.py | 119 ++++++++++++++++++++++---------
src/storage/cpu_shared_storage_manager.h | 1 +
2 files changed, 85 insertions(+), 35 deletions(-)
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 6efaf35..7f09e28 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -119,6 +119,83 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
batch = batchify_fn([dataset[i] for i in samples])
data_queue.put((idx, batch))
+class _MultiWorkerIter(object):
+ """Internal multi-worker iterator for DataLoader."""
+ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
+ assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
+ self._num_workers = num_workers
+ self._dataset = dataset
+ self._batchify_fn = batchify_fn
+ self._batch_sampler = batch_sampler
+ self._key_queue = Queue()
+ self._data_queue = Queue(2*self._num_workers)
+ self._data_buffer = {}
+ self._rcvd_idx = 0
+ self._sent_idx = 0
+ self._iter = iter(self._batch_sampler)
+ self._shutdown = False
+
+ workers = []
+ for _ in range(self._num_workers):
+ worker = multiprocessing.Process(
+ target=worker_loop,
+ args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn))
+ worker.daemon = True
+ worker.start()
+ workers.append(worker)
+
+ # pre-fetch
+ for _ in range(2 * self._num_workers):
+ self._push_next()
+
+ def __len__(self):
+ return len(self._batch_sampler)
+
+ def __del__(self):
+ self.shutdown()
+
+ def _push_next(self):
+ """Assign next batch workload to workers."""
+ r = next(self._iter, None)
+ if r is None:
+ return
+ self._key_queue.put((self._sent_idx, r))
+ self._sent_idx += 1
+
+ def __next__(self):
+ assert not self._shutdown, "call __next__ after shutdown is forbidden"
+ if self._rcvd_idx == self._sent_idx:
+ assert not self._data_buffer, "Data buffer should be empty at this moment"
+ self.shutdown()
+ raise StopIteration
+
+ while True:
+ if self._rcvd_idx in self._data_buffer:
+ batch = self._data_buffer.pop(self._rcvd_idx)
+ self._rcvd_idx += 1
+ self._push_next()
+ return batch
+ idx, batch = self._data_queue.get()
+ self._data_buffer[idx] = batch
+
+ def next(self):
+ return self.__next__()
+
+ def __iter__(self):
+ return self
+
+ def shutdown(self):
+ """Shutdown internal workers by pushing terminate signals."""
+ if not self._shutdown:
+ for _ in range(self._num_workers):
+ self._key_queue.put((None, None))
+ try:
+ while not self._data_queue.empty():
+ self._data_queue.get()
+ except IOError:
+ pass
+ self._shutdown = True
+
class DataLoader(object):
"""Loads data from a dataset and returns mini-batches of data.
@@ -187,7 +264,7 @@ class DataLoader(object):
"not be specified if batch_sampler is specified.")
self._batch_sampler = batch_sampler
- self._num_workers = num_workers
+ self._num_workers = num_workers if num_workers >= 0 else 0
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
@@ -198,41 +275,13 @@ class DataLoader(object):
def __iter__(self):
if self._num_workers == 0:
- for batch in self._batch_sampler:
- yield self._batchify_fn([self._dataset[idx] for idx in batch])
- return
-
- key_queue = Queue()
- data_queue = Queue(2*self._num_workers)
-
- workers = []
- for _ in range(self._num_workers):
- worker = multiprocessing.Process(
- target=worker_loop,
- args=(self._dataset, key_queue, data_queue, self._batchify_fn))
- worker.daemon = True
- worker.start()
- workers.append(worker)
-
- idx = 0
- for idx, batch in enumerate(self._batch_sampler):
- key_queue.put((idx, batch))
- num_batches = idx + 1
-
- data_buffer = {}
- curr_idx = 0
- for _ in range(num_batches):
- idx, batch = data_queue.get()
- data_buffer[idx] = batch
- while curr_idx in data_buffer:
- yield data_buffer.pop(curr_idx)
- curr_idx += 1
-
- for _ in range(self._num_workers):
- key_queue.put((None, None))
+ generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
+ for batch in self._batch_sampler]
+ return generator()
- for worker in workers:
- worker.join()
+ # multi-worker
+ return _MultiWorkerIter(self._num_workers, self._dataset,
+ self._batchify_fn, self._batch_sampler)
def __len__(self):
return len(self._batch_sampler)
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
index 2a75a97..e2de30d 100644
--- a/src/storage/cpu_shared_storage_manager.h
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -69,6 +69,7 @@ class CPUSharedStorageManager final : public StorageManager {
void Alloc(Storage::Handle* handle) override;
void Free(Storage::Handle handle) override {
+ std::lock_guard<std::recursive_mutex> lock(mutex_);
pool_.erase(handle.dptr);
FreeImpl(handle);
}
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.