Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/05 15:36:49 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] provide a faster PrefetchedDataLoader (#19748)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 7d934a7  [v1.x] provide a faster PrefetchedDataLoader (#19748)
7d934a7 is described below

commit 7d934a72a0c536d10de71e6bd19cfcd3b8222143
Author: Neutron3529 <qw...@163.com>
AuthorDate: Fri Feb 5 23:34:43 2021 +0800

    [v1.x] provide a faster PrefetchedDataLoader (#19748)
    
    * provide a faster PrefetchedDataLoader
    
    For now (a limitation of the current implementation), this `PrefetchedDataLoader` only allows generating a single iter at a time.
    The benefit of `PrefetchedDataLoader` is that it provides better performance as a simple drop-in replacement in most existing code (see the import sketch after the benchmark below).
    test:
    ```python
    $ cat iternew.py && python iternew.py
    import mxnet as mx
    from mxnet.gluon.data import PrefetchedDataLoader as DataLoader,ArrayDataset
    from time import sleep,perf_counter_ns
    train_data=ArrayDataset(mx.nd.array([[i] for i in range(50000)]),mx.nd.array([[99-i] for i in range(50000)]))
    test_data=ArrayDataset(mx.nd.array([[i] for i in range(10000)]),mx.nd.array([[99-i] for i in range(10000)]))
    def transform_train(sample):
      sleep(0.0016)
      return sample
    
    def transform_test(sample):
      sleep(0.0008)
      return sample
    
    train_iter=DataLoader(train_data.transform_first(transform_train),batch_size=500,num_workers=10)
    test_iter =DataLoader(test_data .transform_first(transform_test ),batch_size=500,num_workers=10)
    if True:
      tic=perf_counter_ns()
      for epoch in range(10):
        print("epoch"+str(epoch)+" start at "+str(round((perf_counter_ns()-tic)*1e-9,2))+"s")
        for i in train_iter:
          sleep(0.1)
        print("       finished train phase at "+str(round((perf_counter_ns()-tic)*1e-9,2))+"s")
        for i in test_iter:
          sleep(0.05)
        print("        finished test phase at "+str(round((perf_counter_ns()-tic)*1e-9,2))+"s")
      print("cost="+str((perf_counter_ns()-tic)*1e-9)+"s")
    
    epoch0 start at 0.0s
           finished train phase at 11.25s
            finished test phase at 12.31s
    epoch1 start at 12.31s
           finished train phase at 22.62s
            finished test phase at 23.68s
    epoch2 start at 23.68s
           finished train phase at 34.03s
            finished test phase at 35.09s
    epoch3 start at 35.09s
           finished train phase at 45.41s
            finished test phase at 46.48s
    epoch4 start at 46.48s
           finished train phase at 56.82s
            finished test phase at 57.88s
    epoch5 start at 57.88s
           finished train phase at 68.24s
            finished test phase at 69.3s
    epoch6 start at 69.3s
           finished train phase at 79.65s
            finished test phase at 80.71s
    epoch7 start at 80.71s
           finished train phase at 91.04s
            finished test phase at 92.11s
    epoch8 start at 92.11s
           finished train phase at 102.46s
            finished test phase at 103.53s
    epoch9 start at 103.53s
           finished train phase at 113.89s
            finished test phase at 114.95s
    cost=114.94954171600001s
    ```
    (the cost is ~`129.67192333600002s` if we use the stock `DataLoader` rather than `PrefetchedDataLoader`)
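    
    As a rough illustration of the "simple replacement" point above, the only change a typical
    training script needs is the import; a minimal sketch reusing `train_data` and
    `transform_train` from the test above:
    ```python
    # before: the stock loader
    # from mxnet.gluon.data import DataLoader, ArrayDataset
    
    # after: the prefetching variant, aliased so the rest of the script stays unchanged
    from mxnet.gluon.data import PrefetchedDataLoader as DataLoader, ArrayDataset
    
    train_iter = DataLoader(train_data.transform_first(transform_train),
                            batch_size=500, num_workers=10)  # same call signature as before
    ```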
    
    * provide a faster PrefetchedDataLoader
    
    A faster dataloader already exists in MXNet 2.0, but in v1.x the existing dataloader is slower and can be improved by changing its prefetch behavior to match what 2.0 does.
    ```python
    $ cat iternew.py && python iternew.py
    import mxnet as mx
    from mxnet.gluon.data import PrefetchedDataLoader as DataLoader,ArrayDataset
    from time import sleep,perf_counter_ns
    train_data=ArrayDataset(mx.nd.array([[i] for i in range(50000)]),mx.nd.array([[99-i] for i in range(50000)]))
    test_data=ArrayDataset(mx.nd.array([[i] for i in range(10000)]),mx.nd.array([[99-i] for i in range(10000)]))
    def transform_train(sample):
      sleep(0.0016)
      return sample
    
    def transform_test(sample):
      sleep(0.0008)
      return sample
    
    train_iter=DataLoader(train_data.transform_first(transform_train),batch_size=500,num_workers=10)
    test_iter =DataLoader(test_data .transform_first(transform_test ),batch_size=500,num_workers=10)
    if True:
      tic=perf_counter_ns()
      for epoch in range(10):
        print("epoch"+str(epoch)+" start at "+str(round((perf_counter_ns()-tic)*1e-9,2))+"s")
        for i in train_iter:
          sleep(0.1)
        print("       finished train phase at "+str(round((perf_counter_ns()-tic)*1e-9,2))+"s")
        for i in test_iter:
          sleep(0.05)
        print("        finished test phase at "+str(round((perf_counter_ns()-tic)*1e-9,2))+"s")
      print("cost="+str((perf_counter_ns()-tic)*1e-9)+"s")
    epoch0 start at 0.0s
           finished train phase at 11.28s
            finished test phase at 12.35s
    epoch1 start at 12.35s
           finished train phase at 22.73s
            finished test phase at 23.79s
    epoch2 start at 23.79s
           finished train phase at 34.15s
            finished test phase at 35.21s
    epoch3 start at 35.22s
           finished train phase at 45.59s
            finished test phase at 46.66s
    epoch4 start at 46.66s
           finished train phase at 57.01s
            finished test phase at 58.07s
    epoch5 start at 58.07s
           finished train phase at 68.43s
            finished test phase at 69.5s
    epoch6 start at 69.5s
           finished train phase at 79.87s
            finished test phase at 80.93s
    epoch7 start at 80.93s
           finished train phase at 91.3s
            finished test phase at 92.37s
    epoch8 start at 92.37s
           finished train phase at 102.74s
            finished test phase at 103.8s
    epoch9 start at 103.8s
           finished train phase at 114.17s
            finished test phase at 115.23s
    cost=115.23376344s
    ```
    
    * Update test_gluon_data.py
    
    add unittest for PrefetchedDataLoader
    
    * Update dataloader.py
    
    update documentation
    
    * delete trailing-whitespace
    
    * remove the modification of num_workers.
    
    * Update dataloader.py
    
    A previous test showed that there may be something wrong with `_MultiWorkerIter` because `__iter__()` was called inappropriately; I tried to fix it by moving the call here.
    
    * add an auto_reload flag into the dataloader
    
    The added flag is set to `True` rather than the default `False`, since in MXNet 2.0 the default `nopython` mode prefetches data and automatically reloads it (see the sketch below).
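    
    A minimal sketch of how the flag is meant to be used once it is merged into `DataLoader`
    (see the diff below); `dataset` here is just a small placeholder dataset:
    ```python
    from mxnet.gluon.data import DataLoader, ArrayDataset
    
    dataset = ArrayDataset([i for i in range(10)], [9 - i for i in range(10)])
    
    # auto_reload=True: as soon as one pass over the loader finishes (and also at
    # construction time), the next pass is prefetched in the background.
    eager_loader = DataLoader(dataset, batch_size=2, num_workers=1, auto_reload=True)
    
    # auto_reload=False: prefetching starts only when __iter__() is called.
    lazy_loader = DataLoader(dataset, batch_size=2, num_workers=1, auto_reload=False)
    ```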
    
    * Update dataloader.py
    
    merge `PrefetchedDataLoader` into `DataLoader`
    
    * Update dataloader.py
    
    remove whitespace
    
    * Update dataloader.py
    
    solve the warning in L726
    
    * Update dataloader.py
    
    fix typo
    
    * Update dataloader.py
    
    fix the outdated `PrefetchedDataLoader`
    
    * using pytest for nested loop
    
    * change the auto_reload default to False.
    
    * Revert "using pytest for nested loop"
    
    This reverts commit 2c8d8582491b98ffa28cd8364edada4da20ffbc4.
    
    Co-authored-by: Leonard Lausen <la...@amazon.com>
---
 python/mxnet/gluon/data/dataloader.py    | 70 +++++++++++++++++++++++++++++++-
 tests/python/unittest/test_gluon_data.py |  9 ++--
 2 files changed, 74 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index d341484..15c37cb 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -572,11 +572,55 @@ class DataLoader(object):
         unless you are experiencing timeout and you know it's due to slow data loading.
         Sometimes full `shared_memory` will cause all workers to hang and causes timeout. In these
         cases please reduce `num_workers` or increase system `shared_memory` size instead.
+    auto_reload : bool, default is False
+        Controls whether a new prefetch is started automatically after a full pass over the loader ends.
+
+    Example:
+    >>> from mxnet.gluon.data import DataLoader, ArrayDataset
+    >>> train_data = ArrayDataset([i for i in range(10)],[9-i for i in range(10)])
+    >>> def transform_train(sample):
+    ...   if sample == 0 : print('(pre)fetching data here')
+    ...   return sample
+    ...
+    >>> train_iter = DataLoader(train_data.transform_first(transform_train),
+    ...                         auto_reload=False, batch_size=1,num_workers=1)
+    >>> # no prefetch is performed; prefetching and auto-reloading start after
+    >>> # train_iter.__iter__() is called.
+    >>> for i in train_iter:pass
+    (pre)fetching data here
+    >>> train_iter = DataLoader(train_data.transform_first(transform_train),
+    ...                         batch_size=1,num_workers=1)
+    (pre)fetching data here
+    >>> it = iter(train_iter) # nothing is generated since lazy-evaluation occurs
+    >>> it2 = iter(train_iter)
+    >>> it3 = iter(train_iter)
+    >>> it4 = iter(train_iter)
+    >>> _ = next(it2) # the first iter we are using is the prefetched iter.
+    >>> _ = next(it) # since the prefetched iter is consumed, we have to fetch data for `it`.
+    (pre)fetching data here
+    >>> _ = [None for _ in it3]
+    (pre)fetching data here
+    (pre)fetching data here
+    >>> # Here, 2 prefetches are triggered: one fetches the first batch of `it3`, and
+    >>> # another is performed automatically when `it3` yields its last item.
+    >>> _ = [None for _ in it]
+    >>> # no prefetch happens since `train_iter` has already prefetched data.
+    >>> _ = next(it4)
+    >>> # since a prefetch was performed, `it4` becomes the prefetched iter.
+    >>>
+    >>> test_data = ArrayDataset([i for i in range(10)],[9-i for i in range(10)])
+    >>> test_iter = DataLoader(test_data, batch_size=1,num_workers=1)
+    >>> for epoch in range(200):
+    ...   # this behaves almost the same as the default DataLoader
+    ...   for data, label in train_iter:
+    ...     pass # training...
+    ...   for data, label in test_iter:
+    ...     pass # testing...
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                  last_batch=None, batch_sampler=None, batchify_fn=None,
                  num_workers=0, pin_memory=False, pin_device_id=0,
-                 prefetch=None, thread_pool=False, timeout=120):
+                 prefetch=None, thread_pool=False, timeout=120, auto_reload=False):
         self._dataset = dataset
         self._pin_memory = pin_memory
         self._pin_device_id = pin_device_id
@@ -627,8 +671,24 @@ class DataLoader(object):
                 self._batchify_fn = default_batchify_fn
         else:
             self._batchify_fn = batchify_fn
+        self.auto_reload = auto_reload
+        if self.auto_reload:
+            self.refresh()
+        else:
+            self.clean() # ensure the self._iter attribute exists.
 
     def __iter__(self):
+        if self._iter is None:
+            self.refresh()
+        t = self._iter
+        self._iter = None # ensure a single iter is not used twice.
+        for item in t:
+            yield item
+        if self._iter is None and self.auto_reload:
+            # ensure we do not waste any existing iter by mistake
+            self.refresh()
+
+    def _prefetch_iter(self):
         if self._num_workers == 0:
             def same_process_iter():
                 for batch in self._batch_sampler:
@@ -655,3 +715,11 @@ class DataLoader(object):
             # https://bugs.python.org/issue34172
             assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
             self._worker_pool.terminate()
+
+    def refresh(self):
+        """Refresh its iter, fetch data again from its dataset"""
+        self._iter = self._prefetch_iter()
+
+    def clean(self):
+        """Remove its prefetched iter, the prefetch step will start after call its __iter__()"""
+        self._iter = None
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index ef27a7f..a2b8164 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -158,10 +158,11 @@ class Dataset(gluon.data.Dataset):
 def test_multi_worker():
     data = Dataset()
     for thread_pool in [True, False]:
-        loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5, thread_pool=thread_pool)
-        for i, batch in enumerate(loader):
-            assert (batch.asnumpy() == i).all()
-
+        for auto_reload in [True, False]:
+            loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5,
+                                           thread_pool=thread_pool,auto_reload=auto_reload)
+            for i, batch in enumerate(loader):
+                assert (batch.asnumpy() == i).all()
 
 @with_seed()
 def test_multi_worker_shape():
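
A minimal usage sketch of the new `refresh()` / `clean()` helpers, assuming the diff above is applied as-is (the dataset and batch size here are placeholders):
```python
from mxnet.gluon.data import DataLoader, ArrayDataset

data = ArrayDataset([i for i in range(10)], [9 - i for i in range(10)])
loader = DataLoader(data, batch_size=2, num_workers=1, auto_reload=True)

# auto_reload=True means an iter was already prefetched in __init__();
# clean() drops it, so the next __iter__() call fetches from scratch.
loader.clean()

# refresh() explicitly starts a new prefetch from the dataset.
loader.refresh()

for batch_data, batch_label in loader:  # consumes the prefetched iter
    pass
```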