Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/10 16:01:53 UTC

[incubator-mxnet] branch master updated: Fix shared memory with gluon dataloader, add option pin_memory (#11908)

This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a9c3af  Fix shared memory with gluon dataloader, add option pin_memory (#11908)
5a9c3af is described below

commit 5a9c3af4b101b85047b306575575ea9022a8474e
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Fri Aug 10 09:01:44 2018 -0700

    Fix shared memory with gluon dataloader, add option pin_memory (#11908)
    
    * use threading for mp dataloader fetching, allow pin_memory option
    
    * allow pin tuple of data into cpu_pinned
    
    * fix as_in_context if not cpu_pinned
    
    * fix cpu_pinned
    
    * fix unittest for windows, update doc that windows mp is available
    
    * fix pin_memory
    
    * fix lint
    
    * always use simplequeue for data queue
    
    * remove main thread clearing for data_queue
    
    * do not use outside folder as pythonpath but run nosetests inside
    
    * use $env:MXNET_LIBRARY_PATH to locate dll
    
    * fix dll path
    
    * correct dll path
---
 ci/windows/test_py2_cpu.ps1              |   1 +
 ci/windows/test_py2_gpu.ps1              |   1 +
 ci/windows/test_py3_cpu.ps1              |   1 +
 ci/windows/test_py3_gpu.ps1              |   1 +
 python/mxnet/gluon/data/dataloader.py    |  61 ++++++++---
 tests/python/unittest/test_gluon_data.py | 168 +++++++++++++++----------------
 6 files changed, 128 insertions(+), 105 deletions(-)
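
Not part of the commit, but for context, here is a minimal usage sketch of the new
pin_memory option. The toy ArrayDataset, shapes, and context choice below are
illustrative assumptions; only the DataLoader signature comes from this change.
pin_memory merely places returned batches in cpu_pinned memory, so the benefit
shows up in the subsequent copy to GPU.

    import mxnet as mx
    from mxnet.gluon.data import ArrayDataset, DataLoader

    def main():
        # Toy dataset: 100 samples with 40 features each, plus integer labels.
        X = mx.nd.random.uniform(shape=(100, 40))
        y = mx.nd.arange(100)
        dataset = ArrayDataset(X, y)

        # num_workers=2 exercises the multiprocessing path fixed here;
        # pin_memory=True returns batches in cpu_pinned memory.
        loader = DataLoader(dataset, batch_size=10, shuffle=True,
                            num_workers=2, pin_memory=True)

        ctx = mx.cpu()  # use mx.gpu(0) on a GPU machine to see the faster copy
        for data, label in loader:
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)

    if __name__ == '__main__':  # guard matters when worker processes are spawned
        main()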

diff --git a/ci/windows/test_py2_cpu.ps1 b/ci/windows/test_py2_cpu.ps1
index 1623d29..aa38b81 100644
--- a/ci/windows/test_py2_cpu.ps1
+++ b/ci/windows/test_py2_cpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py2\Scripts\pip install -r tests\requirements.txt
diff --git a/ci/windows/test_py2_gpu.ps1 b/ci/windows/test_py2_gpu.ps1
index 13cd536..5f8de5a 100644
--- a/ci/windows/test_py2_gpu.ps1
+++ b/ci/windows/test_py2_gpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py2\Scripts\pip install -r tests\requirements.txt
diff --git a/ci/windows/test_py3_cpu.ps1 b/ci/windows/test_py3_cpu.ps1
index 98d4e41..0dd48de 100644
--- a/ci/windows/test_py3_cpu.ps1
+++ b/ci/windows/test_py3_cpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py3\Scripts\pip install -r tests\requirements.txt
diff --git a/ci/windows/test_py3_gpu.ps1 b/ci/windows/test_py3_gpu.ps1
index b94b4f3..4a0feb1 100644
--- a/ci/windows/test_py3_gpu.ps1
+++ b/ci/windows/test_py3_gpu.ps1
@@ -16,6 +16,7 @@
 # under the License.
 
 7z x -y windows_package.7z
+$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
 $env:PYTHONPATH=join-path $pwd.Path windows_package\python
 $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 c:\Anaconda3\envs\py3\Scripts\pip install -r tests\requirements.txt
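
For anyone reproducing the Windows CI setup by hand: the PowerShell additions above
just point the bundled python package at libmxnet.dll before MXNet is imported. A
rough Python-side equivalent is sketched below; the windows_package path is a
placeholder, and MXNET_LIBRARY_PATH only takes effect if set before the first
import of mxnet.

    import os
    import sys

    # Placeholder: wherever windows_package.7z was extracted.
    pkg = r"C:\work\windows_package"

    # Same variables the CI scripts export, set before importing mxnet.
    os.environ["MXNET_LIBRARY_PATH"] = os.path.join(pkg, "lib", "libmxnet.dll")
    os.environ["MXNET_STORAGE_FALLBACK_LOG_VERBOSE"] = "0"

    # Stand-in for PYTHONPATH: make the bundled python package importable.
    sys.path.insert(0, os.path.join(pkg, "python"))

    import mxnet as mx
    print(mx.__version__)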
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index eb1eb41..13ab544 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=
+# pylint: disable=ungrouped-imports
 """Dataset generator."""
 __all__ = ['DataLoader']
 
@@ -26,6 +26,7 @@ import sys
 import multiprocessing
 import multiprocessing.queues
 from multiprocessing.reduction import ForkingPickler
+import threading
 import numpy as np
 
 try:
@@ -149,6 +150,14 @@ def default_mp_batchify_fn(data):
                         ctx=context.Context('cpu_shared', 0))
 
 
+def _as_in_context(data, ctx):
+    """Move data into new context."""
+    if isinstance(data, nd.NDArray):
+        return data.as_in_context(ctx)
+    elif isinstance(data, (list, tuple)):
+        return [_as_in_context(d, ctx) for d in data]
+    return data
+
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
     dataset._fork()
@@ -159,9 +168,21 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
+def fetcher_loop(data_queue, data_buffer, pin_memory=False):
+    """Fetcher loop: fetch batches from the data queue and put them into the reorder dict."""
+    while True:
+        idx, batch = data_queue.get()
+        if idx is None:
+            break
+        if pin_memory:
+            batch = _as_in_context(batch, context.cpu_pinned())
+        else:
+            batch = _as_in_context(batch, context.cpu())
+        data_buffer[idx] = batch
+
 class _MultiWorkerIter(object):
     """Interal multi-worker iterator for DataLoader."""
-    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
+    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False):
         assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
         self._num_workers = num_workers
         self._dataset = dataset
@@ -184,6 +205,12 @@ class _MultiWorkerIter(object):
             worker.start()
             workers.append(worker)
 
+        self._fetcher = threading.Thread(
+            target=fetcher_loop,
+            args=(self._data_queue, self._data_buffer, pin_memory))
+        self._fetcher.daemon = True
+        self._fetcher.start()
+
         # pre-fetch
         for _ in range(2 * self._num_workers):
             self._push_next()
@@ -210,13 +237,11 @@ class _MultiWorkerIter(object):
             raise StopIteration
 
         while True:
-            self._push_next()
             if self._rcvd_idx in self._data_buffer:
                 batch = self._data_buffer.pop(self._rcvd_idx)
                 self._rcvd_idx += 1
+                self._push_next()
                 return batch
-            idx, batch = self._data_queue.get()
-            self._data_buffer[idx] = batch
 
     def next(self):
         return self.__next__()
@@ -229,11 +254,7 @@ class _MultiWorkerIter(object):
         if not self._shutdown:
             for _ in range(self._num_workers):
                 self._key_queue.put((None, None))
-            try:
-                while not self._data_queue.empty():
-                    self._data_queue.get()
-            except IOError:
-                pass
+            self._data_queue.put((None, None))
             self._shutdown = True
 
 
@@ -277,12 +298,16 @@ class DataLoader(object):
 
     num_workers : int, default 0
         The number of multiprocessing workers to use for data preprocessing.
-        `num_workers > 0` is not supported on Windows yet.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                  last_batch=None, batch_sampler=None, batchify_fn=None,
-                 num_workers=0):
+                 num_workers=0, pin_memory=False):
         self._dataset = dataset
+        self._pin_memory = pin_memory
 
         if batch_sampler is None:
             if batch_size is None:
@@ -315,13 +340,17 @@ class DataLoader(object):
 
     def __iter__(self):
         if self._num_workers == 0:
-            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
-                                 for batch in self._batch_sampler]
-            return generator()
+            def same_process_iter():
+                for batch in self._batch_sampler:
+                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
+                    if self._pin_memory:
+                        ret = _as_in_context(ret, context.cpu_pinned())
+                    yield ret
+            return same_process_iter()
 
         # multi-worker
         return _MultiWorkerIter(self._num_workers, self._dataset,
-                                self._batchify_fn, self._batch_sampler)
+                                self._batchify_fn, self._batch_sampler, self._pin_memory)
 
     def __len__(self):
         return len(self._batch_sampler)
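
Stepping back from the diff: the core change in dataloader.py is that a dedicated
fetcher thread now drains the workers' SimpleQueue into an index-keyed buffer
(pinning batches there when requested), and __next__ simply pops batches in order.
Below is a stripped-down, MXNet-free sketch of that worker/fetcher pattern; the
function names and the toy "double the payload" work are illustrative only.

    import multiprocessing
    import threading
    import time

    def worker(key_queue, data_queue):
        # Worker process: receive (idx, payload) jobs, send back (idx, result).
        while True:
            idx, payload = key_queue.get()
            if idx is None:
                break
            data_queue.put((idx, payload * 2))  # stand-in for batchify work

    def fetcher(data_queue, buffer):
        # Fetcher thread: drain the data queue into an index-keyed buffer.
        # In the real loader this is where batches are moved to cpu_pinned.
        while True:
            idx, batch = data_queue.get()
            if idx is None:
                break
            buffer[idx] = batch

    if __name__ == '__main__':
        key_queue = multiprocessing.Queue()
        data_queue = multiprocessing.SimpleQueue()  # the commit standardizes on SimpleQueue
        buffer = {}

        workers = [multiprocessing.Process(target=worker, args=(key_queue, data_queue))
                   for _ in range(2)]
        for w in workers:
            w.daemon = True
            w.start()

        fetcher_thread = threading.Thread(target=fetcher, args=(data_queue, buffer))
        fetcher_thread.daemon = True
        fetcher_thread.start()

        for i in range(6):                 # push jobs
            key_queue.put((i, i))

        rcvd = 0
        while rcvd < 6:                    # return results strictly in order
            if rcvd in buffer:
                print(rcvd, buffer.pop(rcvd))
                rcvd += 1
            else:
                time.sleep(0.01)           # toy pacing; the real loader keeps polling

        for _ in workers:                  # shut down workers, then the fetcher
            key_queue.put((None, None))
        data_queue.put((None, None))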
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 0438044..4dc4f3a 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -130,99 +130,89 @@ def test_multi_worker():
     for i, batch in enumerate(loader):
         assert (batch.asnumpy() == i).all()
 
-@with_seed()
-def test_multi_worker_forked_data_loader():
+class _Dummy(Dataset):
+    """Dummy dataset for randomized shape arrays."""
+    def __init__(self, random_shape):
+        self.random_shape = random_shape
+
+    def __getitem__(self, idx):
+        key = idx
+        if self.random_shape:
+            out = np.random.uniform(size=(random.randint(1000, 1100), 40))
+            labels = np.random.uniform(size=(random.randint(10, 15)))
+        else:
+            out = np.random.uniform(size=(1000, 40))
+            labels = np.random.uniform(size=(10))
+        return key, out, labels
+
+    def __len__(self):
+        return 50
+
+def _batchify_list(data):
+    """
+    return list of ndarray without stack/concat/pad
     """
-    Test should successfully run its course of multi-process/forked data loader without errors
+    if isinstance(data, (tuple, list)):
+        return list(data)
+    if isinstance(data, mx.nd.NDArray):
+        return [data]
+    return data
+
+def _batchify(data):
+    """
+    Collate data into batch. Use shared memory for stacking.
+    :param data: a list of array, with layout of 'NTC'.
+    :return either x  and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths
+            if labels are not supplied.
     """
-    class Dummy(Dataset):
-        def __init__(self, random_shape):
-            self.random_shape = random_shape
-
-        def __getitem__(self, idx):
-            key = idx
-            if self.random_shape:
-                out = np.random.uniform(size=(random.randint(1000, 1100), 40))
-                labels = np.random.uniform(size=(random.randint(10, 15)))
-            else:
-                out = np.random.uniform(size=(1000, 40))
-                labels = np.random.uniform(size=(10))
-            return key, out, labels
-
-        def __len__(self):
-            return 50
-
-        def batchify_list(self, data):
-            """
-            return list of ndarray without stack/concat/pad
-            """
-            if isinstance(data, (tuple, list)):
-                return list(data)
-            if isinstance(data, mx.nd.NDArray):
-                return [data]
-            return data
-
-        def batchify(self, data):
-            """
-            Collate data into batch. Use shared memory for stacking.
-
-            :param data: a list of array, with layout of 'NTC'.
-            :return either x  and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths
-                    if labels are not supplied.
-            """
-
-            # input layout is NTC
-            keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
-                                   [item[2] for item in data]
-
-            if len(data) > 1:
-                max_data_len = max([seq.shape[0] for seq in inputs])
-                max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
-            else:
-                max_data_len = inputs[0].shape[0]
-                max_labels_len = 0 if not labels else labels[0].shape[0]
-
-            x_lens = [item.shape[0] for item in inputs]
-            y_lens = [item.shape[0] for item in labels]
-
-            for i, seq in enumerate(inputs):
-                pad_len = max_data_len - seq.shape[0]
-                inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
-                labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
-                                   'constant', constant_values=-1)
-
-            inputs = np.asarray(inputs, dtype=np.float32)
-            if labels is not None:
-                labels = np.asarray(labels, dtype=np.float32)
-            inputs = inputs.transpose((1, 0, 2))
-            labels = labels.transpose((1, 0))
-
-            return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
-                if labels is None else (
-                nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
-                nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
-                nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
 
+    # input layout is NTC
+    keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
+                           [item[2] for item in data]
+
+    if len(data) > 1:
+        max_data_len = max([seq.shape[0] for seq in inputs])
+        max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
+    else:
+        max_data_len = inputs[0].shape[0]
+        max_labels_len = 0 if not labels else labels[0].shape[0]
+
+    x_lens = [item.shape[0] for item in inputs]
+    y_lens = [item.shape[0] for item in labels]
+
+    for i, seq in enumerate(inputs):
+        pad_len = max_data_len - seq.shape[0]
+        inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
+        labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
+                           'constant', constant_values=-1)
+
+    inputs = np.asarray(inputs, dtype=np.float32)
+    if labels is not None:
+        labels = np.asarray(labels, dtype=np.float32)
+    inputs = inputs.transpose((1, 0, 2))
+    labels = labels.transpose((1, 0))
+
+    return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
+        if labels is None else (
+        nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+        nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
+        nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
+        nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
 
-    # This test is pointless on Windows because Windows doesn't fork
-    if platform.system() != 'Windows':
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
-
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
+@with_seed()
+def test_multi_worker_forked_data_loader():
+    data = _Dummy(False)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass
+
+    data = _Dummy(True)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify_list, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass
 
 if __name__ == '__main__':
     import nose
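
The test module ends with the usual nose entry point. For reference, one way to run
just the DataLoader tests touched here, using nose's file:function selection syntax
from inside tests/python/unittest (assuming nose is installed), would be roughly:

    import nose

    # Select only the two DataLoader tests from this module.
    nose.run(argv=['nosetests', '-v',
                   'test_gluon_data.py:test_multi_worker',
                   'test_gluon_data.py:test_multi_worker_forked_data_loader'])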