You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/11 19:56:42 UTC

[incubator-mxnet] branch master updated: generalize array dataset (#8612)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9572340  generalize array dataset (#8612)
9572340 is described below

commit 957234075c535dc6dbee95ce9416d37b67025db0
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Sat Nov 11 11:56:33 2017 -0800

    generalize array dataset (#8612)
    
    * generalize array dataset
    
    * Update test_gluon_data.py
---
 python/mxnet/gluon/data/dataset.py       | 35 ++++++++++++++++++--------------
 tests/python/unittest/test_gluon_data.py |  7 ++++++-
 2 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index cbc73bc..059c2a6 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -40,30 +40,35 @@ class Dataset(object):
 
 
 class ArrayDataset(Dataset):
-    """A dataset with a data array and a label array.
+    """A dataset of multiple arrays.
 
-    The i-th sample is `(data[i], lable[i])`.
+    The i-th sample is `(x1[i], x2[i], ...)`.
 
     Parameters
     ----------
-    data : array-like object
-        The data array. Can be mxnet or numpy array.
-    label : array-like object
-        The label array. Can be mxnet or numpy array.
+    *args : one or more arrays
+        The data arrays.
     """
-    def __init__(self, data, label):
-        assert len(data) == len(label)
-        self._data = data
-        if isinstance(label, ndarray.NDArray) and len(label.shape) == 1:
-            self._label = label.asnumpy()
-        else:
-            self._label = label
+    def __init__(self, *args):
+        assert len(args) > 0, "Needs at least 1 arrays"
+        self._length = len(args[0])
+        self._data = []
+        for i, data in enumerate(args):
+            assert len(data) == self._length, \
+                "All arrays must have the same length. But the first has %s " \
+                "while the %d-th has %d."%(self._length, i+1, len(data))
+            if isinstance(data, ndarray.NDArray) and len(data.shape) == 1:
+                data = data.asnumpy()
+            self._data.append(data)
 
     def __getitem__(self, idx):
-        return self._data[idx], self._label[idx]
+        if len(self._data) == 1:
+            return self._data[0][idx]
+        else:
+            return tuple(data[idx] for data in self._data)
 
     def __len__(self):
-        return len(self._data)
+        return self._length
 
 
 class RecordFileDataset(Dataset):
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 341e6d1..397fbbd 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -26,11 +26,16 @@ def test_array_dataset():
     Y = np.random.uniform(size=(10,))
     dataset = gluon.data.ArrayDataset(X, Y)
     loader = gluon.data.DataLoader(dataset, 2)
-
     for i, (x, y) in enumerate(loader):
         assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
         assert mx.test_utils.almost_equal(y.asnumpy(), Y[i*2:(i+1)*2])
 
+    dataset = gluon.data.ArrayDataset(X)
+    loader = gluon.data.DataLoader(dataset, 2)
+
+    for i, x in enumerate(loader):
+        assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
+
 
 def prepare_record():
     if not os.path.isdir("data/test_images"):

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].