You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/11 19:56:42 UTC
[incubator-mxnet] branch master updated: generalize array dataset (#8612)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9572340 generalize array dataset (#8612)
9572340 is described below
commit 957234075c535dc6dbee95ce9416d37b67025db0
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Sat Nov 11 11:56:33 2017 -0800
generalize array dataset (#8612)
* generalize array dataset
* Update test_gluon_data.py
---
python/mxnet/gluon/data/dataset.py | 35 ++++++++++++++++++--------------
tests/python/unittest/test_gluon_data.py | 7 ++++++-
2 files changed, 26 insertions(+), 16 deletions(-)
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index cbc73bc..059c2a6 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -40,30 +40,35 @@ class Dataset(object):
class ArrayDataset(Dataset):
- """A dataset with a data array and a label array.
+ """A dataset of multiple arrays.
- The i-th sample is `(data[i], lable[i])`.
+ The i-th sample is `(x1[i], x2[i], ...)`; with a single array it is just `x1[i]`.
Parameters
----------
- data : array-like object
- The data array. Can be mxnet or numpy array.
- label : array-like object
- The label array. Can be mxnet or numpy array.
+ *args : one or more arrays
+ The data arrays. All of them must have the same length.
"""
- def __init__(self, data, label):
- assert len(data) == len(label)
- self._data = data
- if isinstance(label, ndarray.NDArray) and len(label.shape) == 1:
- self._label = label.asnumpy()
- else:
- self._label = label
+ def __init__(self, *args):
+ assert len(args) > 0, "Needs at least 1 arrays"
+ self._length = len(args[0])
+ self._data = []
+ for i, data in enumerate(args):
+ assert len(data) == self._length, \
+ "All arrays must have the same length. But the first has %d " \
+ "while the %d-th has %d."%(self._length, i+1, len(data))
+ if isinstance(data, ndarray.NDArray) and len(data.shape) == 1:
+ data = data.asnumpy()
+ self._data.append(data)
def __getitem__(self, idx):
- return self._data[idx], self._label[idx]
+ if len(self._data) == 1:
+ return self._data[0][idx]
+ else:
+ return tuple(data[idx] for data in self._data)
def __len__(self):
- return len(self._data)
+ return self._length
class RecordFileDataset(Dataset):
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 341e6d1..397fbbd 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -26,11 +26,16 @@ def test_array_dataset():
Y = np.random.uniform(size=(10,))
dataset = gluon.data.ArrayDataset(X, Y)
loader = gluon.data.DataLoader(dataset, 2)
-
for i, (x, y) in enumerate(loader):
assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
assert mx.test_utils.almost_equal(y.asnumpy(), Y[i*2:(i+1)*2])
+ dataset = gluon.data.ArrayDataset(X)
+ loader = gluon.data.DataLoader(dataset, 2)
+
+ for i, x in enumerate(loader):
+ assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
+
def prepare_record():
if not os.path.isdir("data/test_images"):
def prepare_record():
if not os.path.isdir("data/test_images"):
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.