You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/11 19:56:37 UTC

[GitHub] piiswrong closed pull request #8612: generalize array dataset

piiswrong closed pull request #8612: generalize array dataset
URL: https://github.com/apache/incubator-mxnet/pull/8612
 
 
   

This pull request was merged from a forked repository.
Because GitHub does not show the original diff of a foreign (forked)
pull request once it has been merged, the diff is reproduced below for
the sake of provenance:

diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index cbc73bc401..059c2a61c7 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -40,30 +40,35 @@ def __len__(self):
 
 
 class ArrayDataset(Dataset):
-    """A dataset with a data array and a label array.
+    """A dataset of multiple arrays.
 
-    The i-th sample is `(data[i], lable[i])`.
+    The i-th sample is `(x1[i], x2[i], ...)`.
 
     Parameters
     ----------
-    data : array-like object
-        The data array. Can be mxnet or numpy array.
-    label : array-like object
-        The label array. Can be mxnet or numpy array.
+    *args : one or more arrays
+        The data arrays.
     """
-    def __init__(self, data, label):
-        assert len(data) == len(label)
-        self._data = data
-        if isinstance(label, ndarray.NDArray) and len(label.shape) == 1:
-            self._label = label.asnumpy()
-        else:
-            self._label = label
+    def __init__(self, *args):
+        assert len(args) > 0, "Needs at least 1 arrays"
+        self._length = len(args[0])
+        self._data = []
+        for i, data in enumerate(args):
+            assert len(data) == self._length, \
+                "All arrays must have the same length. But the first has %s " \
+                "while the %d-th has %d."%(self._length, i+1, len(data))
+            if isinstance(data, ndarray.NDArray) and len(data.shape) == 1:
+                data = data.asnumpy()
+            self._data.append(data)
 
     def __getitem__(self, idx):
-        return self._data[idx], self._label[idx]
+        if len(self._data) == 1:
+            return self._data[0][idx]
+        else:
+            return tuple(data[idx] for data in self._data)
 
     def __len__(self):
-        return len(self._data)
+        return self._length
 
 
 class RecordFileDataset(Dataset):
 
 
 class RecordFileDataset(Dataset):
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 341e6d1716..397fbbd33e 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -26,11 +26,16 @@ def test_array_dataset():
     Y = np.random.uniform(size=(10,))
     dataset = gluon.data.ArrayDataset(X, Y)
     loader = gluon.data.DataLoader(dataset, 2)
-
     for i, (x, y) in enumerate(loader):
         assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
         assert mx.test_utils.almost_equal(y.asnumpy(), Y[i*2:(i+1)*2])
 
+    dataset = gluon.data.ArrayDataset(X)
+    loader = gluon.data.DataLoader(dataset, 2)
+
+    for i, x in enumerate(loader):
+        assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
+
 
 def prepare_record():
     if not os.path.isdir("data/test_images"):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services