You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/19 19:10:35 UTC

[GitHub] eric-haibin-lin closed pull request #13686: [gluon][transform]fix unpicklable transform_first on windows

eric-haibin-lin closed pull request #13686: [gluon][transform]fix unpicklable transform_first on windows
URL: https://github.com/apache/incubator-mxnet/pull/13686
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index c93a4b1cd6b..28d19c9fe37 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -88,11 +88,7 @@ def transform_first(self, fn, lazy=True):
         Dataset
             The transformed dataset.
         """
-        def base_fn(x, *args):
-            if args:
-                return (fn(x),) + args
-            return fn(x)
-        return self.transform(base_fn, lazy)
+        return self.transform(_TransformFirstClosure(fn), lazy)
 
 
 class SimpleDataset(Dataset):
@@ -129,6 +125,16 @@ def __getitem__(self, idx):
         return self._fn(item)
 
 
+class _TransformFirstClosure(object):
+    """Use callable object instead of nested function, it can be pickled."""
+    def __init__(self, fn):
+        self._fn = fn
+
+    def __call__(self, x, *args):
+        if args:
+            return (self._fn(x),) + args
+        return self._fn(x)
+
 class ArrayDataset(Dataset):
     """A dataset that combines multiple dataset-like objects, e.g.
     Datasets, lists, arrays, etc.
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 6a5322616e2..353a819ddbf 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -77,6 +77,10 @@ def _dataset_transform_fn(x, y):
     """Named transform function since lambda function cannot be pickled."""
     return x, y
 
+def _dataset_transform_first_fn(x):
+    """Named transform function since lambda function cannot be pickled."""
+    return x
+
 @with_seed()
 def test_recordimage_dataset_with_data_loader_multiworker():
     recfile = prepare_record()
@@ -95,17 +99,13 @@ def test_recordimage_dataset_with_data_loader_multiworker():
         assert x.shape[0] == 1 and x.shape[3] == 3
         assert y.asscalar() == i
 
-    # try limit recursion depth
-    import sys
-    old_limit = sys.getrecursionlimit()
-    sys.setrecursionlimit(500)  # this should be smaller than any default value used in python
-    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    # with transform_first
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform_first(_dataset_transform_first_fn)
     loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
 
     for i, (x, y) in enumerate(loader):
         assert x.shape[0] == 1 and x.shape[3] == 3
         assert y.asscalar() == i
-    sys.setrecursionlimit(old_limit)
 
 @with_seed()
 def test_sampler():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services