You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this text.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/12/19 19:10:51 UTC
[incubator-mxnet] branch master updated: fix unpicklable transform_first on windows (#13686)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 61744b5 fix unpicklable transform_first on windows (#13686)
61744b5 is described below
commit 61744b5009e354842df01563a52ad658fe26dd1a
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Wed Dec 19 11:10:33 2018 -0800
fix unpicklable transform_first on windows (#13686)
---
python/mxnet/gluon/data/dataset.py | 16 +++++++++++-----
tests/python/unittest/test_gluon_data.py | 12 ++++++------
2 files changed, 17 insertions(+), 11 deletions(-)
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index c93a4b1..28d19c9 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -88,11 +88,7 @@ class Dataset(object):
Dataset
The transformed dataset.
"""
- def base_fn(x, *args):
- if args:
- return (fn(x),) + args
- return fn(x)
- return self.transform(base_fn, lazy)
+ return self.transform(_TransformFirstClosure(fn), lazy)
class SimpleDataset(Dataset):
@@ -129,6 +125,16 @@ class _LazyTransformDataset(Dataset):
return self._fn(item)
+class _TransformFirstClosure(object):
+ """Use callable object instead of nested function, it can be pickled."""
+ def __init__(self, fn):
+ self._fn = fn
+
+ def __call__(self, x, *args):
+ if args:
+ return (self._fn(x),) + args
+ return self._fn(x)
+
class ArrayDataset(Dataset):
"""A dataset that combines multiple dataset-like objects, e.g.
Datasets, lists, arrays, etc.
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 6a53226..353a819 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -77,6 +77,10 @@ def _dataset_transform_fn(x, y):
"""Named transform function since lambda function cannot be pickled."""
return x, y
+def _dataset_transform_first_fn(x):
+ """Named transform function since lambda function cannot be pickled."""
+ return x
+
@with_seed()
def test_recordimage_dataset_with_data_loader_multiworker():
recfile = prepare_record()
@@ -95,17 +99,13 @@ def test_recordimage_dataset_with_data_loader_multiworker():
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i
- # try limit recursion depth
- import sys
- old_limit = sys.getrecursionlimit()
- sys.setrecursionlimit(500) # this should be smaller than any default value used in python
- dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+ # with transform_first
+ dataset = gluon.data.vision.ImageRecordDataset(recfile).transform_first(_dataset_transform_first_fn)
loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
for i, (x, y) in enumerate(loader):
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i
- sys.setrecursionlimit(old_limit)
@with_seed()
def test_sampler():