You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/24 20:36:14 UTC

[GitHub] zhreshold closed pull request #12622: Gluon DataLoader: avoid recursionlimit error

zhreshold closed pull request #12622: Gluon DataLoader: avoid recursionlimit error
URL: https://github.com/apache/incubator-mxnet/pull/12622
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff once a foreign (fork-based) pull request is merged, the diff
is reproduced below for the sake of provenance:

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 1c54158a2ba..50e2ad9f784 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -175,7 +175,12 @@ def _recursive_fork_recordio(obj, depth, max_depth=1000):
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
     # re-fork a new recordio handler in new process if applicable
-    _recursive_fork_recordio(dataset, 0, 1000)
+    # for a dataset with transform function, the depth of MXRecordIO is 1
+    # for a lazy transformer, the depth is 2
+    # for a user defined transformer, the depth is unknown, try a reasonable depth
+    limit = sys.getrecursionlimit()
+    max_recursion_depth = min(limit - 5, max(10, limit // 2))
+    _recursive_fork_recordio(dataset, 0, max_recursion_depth)
 
     while True:
         idx, samples = key_queue.get()
diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py
index 2ebe657accb..6fc4d8e7bf5 100644
--- a/python/mxnet/recordio.py
+++ b/python/mxnet/recordio.py
@@ -83,6 +83,32 @@ def open(self):
     def __del__(self):
         self.close()
 
+    def __getstate__(self):
+        """Override pickling behavior."""
+        # pickling pointer is not allowed
+        is_open = self.is_open
+        self.close()
+        d = dict(self.__dict__)
+        d['is_open'] = is_open
+        uri = self.uri.value
+        try:
+            uri = uri.decode('utf-8')
+        except AttributeError:
+            pass
+        del d['handle']
+        d['uri'] = uri
+        return d
+
+    def __setstate__(self, d):
+        """Restore from pickled."""
+        self.__dict__ = d
+        is_open = d['is_open']
+        self.is_open = False
+        self.handle = RecordIOHandle()
+        self.uri = c_str(self.uri)
+        if is_open:
+            self.open()
+
     def close(self):
         """Closes the record file."""
         if not self.is_open:
@@ -217,6 +243,12 @@ def close(self):
         super(MXIndexedRecordIO, self).close()
         self.fidx.close()
 
+    def __getstate__(self):
+        """Override pickling behavior."""
+        d = super(MXIndexedRecordIO, self).__getstate__()
+        d['fidx'] = None
+        return d
+
     def seek(self, idx):
         """Sets the current read pointer position.
 
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index cc80aacb644..c731f8d782d 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -73,26 +73,39 @@ def test_recordimage_dataset():
         assert x.shape[0] == 1 and x.shape[3] == 3
         assert y.asscalar() == i
 
+def _dataset_transform_fn(x, y):
+    """Named transform function since lambda function cannot be pickled."""
+    return x, y
+
 @with_seed()
 def test_recordimage_dataset_with_data_loader_multiworker():
-    # This test is pointless on Windows because Windows doesn't fork
-    if platform.system() != 'Windows':
-        recfile = prepare_record()
-        dataset = gluon.data.vision.ImageRecordDataset(recfile)
-        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
-
-        for i, (x, y) in enumerate(loader):
-            assert x.shape[0] == 1 and x.shape[3] == 3
-            assert y.asscalar() == i
-
-        # with transform
-        fn = lambda x, y : (x, y)
-        dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
-        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
-
-        for i, (x, y) in enumerate(loader):
-            assert x.shape[0] == 1 and x.shape[3] == 3
-            assert y.asscalar() == i
+    recfile = prepare_record()
+    dataset = gluon.data.vision.ImageRecordDataset(recfile)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+
+    # with transform
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+
+    # try limit recursion depth
+    import sys
+    old_limit = sys.getrecursionlimit()
+    sys.setrecursionlimit(500)  # this should be smaller than any default value used in python
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+    sys.setrecursionlimit(old_limit)
 
 @with_seed()
 def test_sampler():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services