You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/10 16:01:46 UTC

[GitHub] zhreshold closed pull request #11908: Fix shared memory with gluon dataloader, add option pin_memory

zhreshold closed pull request #11908: Fix shared memory with gluon dataloader, add option pin_memory
URL: https://github.com/apache/incubator-mxnet/pull/11908
 
 
   

This is a PR merged from a forked repository.
Because GitHub does not show the original diff of a pull request from a
fork once it has been merged, the diff is reproduced below for the sake
of provenance:

diff --git a/ci/windows/test_py2_cpu.ps1 b/ci/windows/test_py2_cpu.ps1
index 1623d295610..aa38b81e392 100644
--- a/ci/windows/test_py2_cpu.ps1
+++ b/ci/windows/test_py2_cpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py2\Scripts\pip install -r tests\requirements.txt
diff --git a/ci/windows/test_py2_gpu.ps1 b/ci/windows/test_py2_gpu.ps1
index 13cd5366e0d..5f8de5ac4f9 100644
--- a/ci/windows/test_py2_gpu.ps1
+++ b/ci/windows/test_py2_gpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py2\Scripts\pip install -r tests\requirements.txt
diff --git a/ci/windows/test_py3_cpu.ps1 b/ci/windows/test_py3_cpu.ps1
index 98d4e410e8f..0dd48de26b3 100644
--- a/ci/windows/test_py3_cpu.ps1
+++ b/ci/windows/test_py3_cpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py3\Scripts\pip install -r tests\requirements.txt
diff --git a/ci/windows/test_py3_gpu.ps1 b/ci/windows/test_py3_gpu.ps1
index b94b4f389be..4a0feb1ede8 100644
--- a/ci/windows/test_py3_gpu.ps1
+++ b/ci/windows/test_py3_gpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py3\Scripts\pip install -r tests\requirements.txt
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index eb1eb419cd0..13ab544a03d 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=
+# pylint: disable=ungrouped-imports
 """Dataset generator."""
 __all__ = ['DataLoader']
 
@@ -26,6 +26,7 @@
 import multiprocessing
 import multiprocessing.queues
 from multiprocessing.reduction import ForkingPickler
+import threading
 import numpy as np
 
 try:
@@ -149,6 +150,14 @@ def default_mp_batchify_fn(data):
                         ctx=context.Context('cpu_shared', 0))
 
 
+def _as_in_context(data, ctx):
+    """Move data into new context."""
+    if isinstance(data, nd.NDArray):
+        return data.as_in_context(ctx)
+    elif isinstance(data, (list, tuple)):
+        return [_as_in_context(d, ctx) for d in data]
+    return data
+
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
     dataset._fork()
@@ -159,9 +168,21 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
+def fetcher_loop(data_queue, data_buffer, pin_memory=False):
+    """Fetcher loop for fetching data from queue and put in reorder dict."""
+    while True:
+        idx, batch = data_queue.get()
+        if idx is None:
+            break
+        if pin_memory:
+            batch = _as_in_context(batch, context.cpu_pinned())
+        else:
+            batch = _as_in_context(batch, context.cpu())
+        data_buffer[idx] = batch
+
 class _MultiWorkerIter(object):
     """Interal multi-worker iterator for DataLoader."""
-    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
+    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False):
         assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
         self._num_workers = num_workers
         self._dataset = dataset
@@ -184,6 +205,12 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
             worker.start()
             workers.append(worker)
 
+        self._fetcher = threading.Thread(
+            target=fetcher_loop,
+            args=(self._data_queue, self._data_buffer, pin_memory))
+        self._fetcher.daemon = True
+        self._fetcher.start()
+
         # pre-fetch
         for _ in range(2 * self._num_workers):
             self._push_next()
@@ -210,13 +237,11 @@ def __next__(self):
             raise StopIteration
 
         while True:
-            self._push_next()
             if self._rcvd_idx in self._data_buffer:
                 batch = self._data_buffer.pop(self._rcvd_idx)
                 self._rcvd_idx += 1
+                self._push_next()
                 return batch
-            idx, batch = self._data_queue.get()
-            self._data_buffer[idx] = batch
 
     def next(self):
         return self.__next__()
@@ -229,11 +254,7 @@ def shutdown(self):
         if not self._shutdown:
             for _ in range(self._num_workers):
                 self._key_queue.put((None, None))
-            try:
-                while not self._data_queue.empty():
-                    self._data_queue.get()
-            except IOError:
-                pass
+            self._data_queue.put((None, None))
             self._shutdown = True
 
 
@@ -277,12 +298,16 @@ def default_batchify_fn(data):
 
     num_workers : int, default 0
         The number of multiprocessing workers to use for data preprocessing.
-        `num_workers > 0` is not supported on Windows yet.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                  last_batch=None, batch_sampler=None, batchify_fn=None,
-                 num_workers=0):
+                 num_workers=0, pin_memory=False):
         self._dataset = dataset
+        self._pin_memory = pin_memory
 
         if batch_sampler is None:
             if batch_size is None:
@@ -315,13 +340,17 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
 
     def __iter__(self):
         if self._num_workers == 0:
-            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
-                                 for batch in self._batch_sampler]
-            return generator()
+            def same_process_iter():
+                for batch in self._batch_sampler:
+                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
+                    if self._pin_memory:
+                        ret = _as_in_context(ret, context.cpu_pinned())
+                    yield ret
+            return same_process_iter()
 
         # multi-worker
         return _MultiWorkerIter(self._num_workers, self._dataset,
-                                self._batchify_fn, self._batch_sampler)
+                                self._batchify_fn, self._batch_sampler, self._pin_memory)
 
     def __len__(self):
         return len(self._batch_sampler)
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 043804487b5..4dc4f3ac881 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -130,99 +130,89 @@ def test_multi_worker():
     for i, batch in enumerate(loader):
         assert (batch.asnumpy() == i).all()
 
-@with_seed()
-def test_multi_worker_forked_data_loader():
+class _Dummy(Dataset):
+    """Dummy dataset for randomized shape arrays."""
+    def __init__(self, random_shape):
+        self.random_shape = random_shape
+
+    def __getitem__(self, idx):
+        key = idx
+        if self.random_shape:
+            out = np.random.uniform(size=(random.randint(1000, 1100), 40))
+            labels = np.random.uniform(size=(random.randint(10, 15)))
+        else:
+            out = np.random.uniform(size=(1000, 40))
+            labels = np.random.uniform(size=(10))
+        return key, out, labels
+
+    def __len__(self):
+        return 50
+
+def _batchify_list(data):
+    """
+    return list of ndarray without stack/concat/pad
     """
-    Test should successfully run its course of multi-process/forked data loader without errors
+    if isinstance(data, (tuple, list)):
+        return list(data)
+    if isinstance(data, mx.nd.NDArray):
+        return [data]
+    return data
+
+def _batchify(data):
+    """
+    Collate data into batch. Use shared memory for stacking.
+    :param data: a list of array, with layout of 'NTC'.
+    :return either x  and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths
+            if labels are not supplied.
     """
-    class Dummy(Dataset):
-        def __init__(self, random_shape):
-            self.random_shape = random_shape
-
-        def __getitem__(self, idx):
-            key = idx
-            if self.random_shape:
-                out = np.random.uniform(size=(random.randint(1000, 1100), 40))
-                labels = np.random.uniform(size=(random.randint(10, 15)))
-            else:
-                out = np.random.uniform(size=(1000, 40))
-                labels = np.random.uniform(size=(10))
-            return key, out, labels
-
-        def __len__(self):
-            return 50
-
-        def batchify_list(self, data):
-            """
-            return list of ndarray without stack/concat/pad
-            """
-            if isinstance(data, (tuple, list)):
-                return list(data)
-            if isinstance(data, mx.nd.NDArray):
-                return [data]
-            return data
-
-        def batchify(self, data):
-            """
-            Collate data into batch. Use shared memory for stacking.
-
-            :param data: a list of array, with layout of 'NTC'.
-            :return either x  and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths
-                    if labels are not supplied.
-            """
-
-            # input layout is NTC
-            keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
-                                   [item[2] for item in data]
-
-            if len(data) > 1:
-                max_data_len = max([seq.shape[0] for seq in inputs])
-                max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
-            else:
-                max_data_len = inputs[0].shape[0]
-                max_labels_len = 0 if not labels else labels[0].shape[0]
-
-            x_lens = [item.shape[0] for item in inputs]
-            y_lens = [item.shape[0] for item in labels]
-
-            for i, seq in enumerate(inputs):
-                pad_len = max_data_len - seq.shape[0]
-                inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
-                labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
-                                   'constant', constant_values=-1)
-
-            inputs = np.asarray(inputs, dtype=np.float32)
-            if labels is not None:
-                labels = np.asarray(labels, dtype=np.float32)
-            inputs = inputs.transpose((1, 0, 2))
-            labels = labels.transpose((1, 0))
-
-            return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
-                if labels is None else (
-                nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
-                nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
-                nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
 
+    # input layout is NTC
+    keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
+                           [item[2] for item in data]
+
+    if len(data) > 1:
+        max_data_len = max([seq.shape[0] for seq in inputs])
+        max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
+    else:
+        max_data_len = inputs[0].shape[0]
+        max_labels_len = 0 if not labels else labels[0].shape[0]
+
+    x_lens = [item.shape[0] for item in inputs]
+    y_lens = [item.shape[0] for item in labels]
+
+    for i, seq in enumerate(inputs):
+        pad_len = max_data_len - seq.shape[0]
+        inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
+        labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
+                           'constant', constant_values=-1)
+
+    inputs = np.asarray(inputs, dtype=np.float32)
+    if labels is not None:
+        labels = np.asarray(labels, dtype=np.float32)
+    inputs = inputs.transpose((1, 0, 2))
+    labels = labels.transpose((1, 0))
+
+    return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
+        if labels is None else (
+        nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+        nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
+        nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
+        nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
 
-    # This test is pointless on Windows because Windows doesn't fork
-    if platform.system() != 'Windows':
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
-
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
+@with_seed()
+def test_multi_worker_forked_data_loader():
+    data = _Dummy(False)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass
+
+    data = _Dummy(True)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify_list, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services