You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/12/19 19:10:51 UTC

[incubator-mxnet] branch master updated: fix unpicklable transform_first on windows (#13686)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 61744b5  fix unpicklable transform_first on windows (#13686)
61744b5 is described below

commit 61744b5009e354842df01563a52ad658fe26dd1a
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Wed Dec 19 11:10:33 2018 -0800

    fix unpicklable transform_first on windows (#13686)
---
 python/mxnet/gluon/data/dataset.py       | 16 +++++++++++-----
 tests/python/unittest/test_gluon_data.py | 12 ++++++------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index c93a4b1..28d19c9 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -88,11 +88,7 @@ class Dataset(object):
         Dataset
             The transformed dataset.
         """
-        def base_fn(x, *args):
-            if args:
-                return (fn(x),) + args
-            return fn(x)
-        return self.transform(base_fn, lazy)
+        return self.transform(_TransformFirstClosure(fn), lazy)
 
 
 class SimpleDataset(Dataset):
@@ -129,6 +125,16 @@ class _LazyTransformDataset(Dataset):
         return self._fn(item)
 
 
+class _TransformFirstClosure(object):
+    """Use callable object instead of nested function, it can be pickled."""
+    def __init__(self, fn):
+        self._fn = fn
+
+    def __call__(self, x, *args):
+        if args:
+            return (self._fn(x),) + args
+        return self._fn(x)
+
 class ArrayDataset(Dataset):
     """A dataset that combines multiple dataset-like objects, e.g.
     Datasets, lists, arrays, etc.
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 6a53226..353a819 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -77,6 +77,10 @@ def _dataset_transform_fn(x, y):
     """Named transform function since lambda function cannot be pickled."""
     return x, y
 
+def _dataset_transform_first_fn(x):
+    """Named transform function since lambda function cannot be pickled."""
+    return x
+
 @with_seed()
 def test_recordimage_dataset_with_data_loader_multiworker():
     recfile = prepare_record()
@@ -95,17 +99,13 @@ def test_recordimage_dataset_with_data_loader_multiworker():
         assert x.shape[0] == 1 and x.shape[3] == 3
         assert y.asscalar() == i
 
-    # try limit recursion depth
-    import sys
-    old_limit = sys.getrecursionlimit()
-    sys.setrecursionlimit(500)  # this should be smaller than any default value used in python
-    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    # with transform_first
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform_first(_dataset_transform_first_fn)
     loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
 
     for i, (x, y) in enumerate(loader):
         assert x.shape[0] == 1 and x.shape[3] == 3
         assert y.asscalar() == i
-    sys.setrecursionlimit(old_limit)
 
 @with_seed()
 def test_sampler():