You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2017/12/30 18:47:19 UTC

[incubator-mxnet] branch master updated: raise err in io.NDArrayIter for invalid usecase (#9228)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e3e1f23  raise err in io.NDArrayIter for invalid usecase (#9228)
e3e1f23 is described below

commit e3e1f235ba72ad903cb6201ac71d281f597fb2cf
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Sun Dec 31 02:47:08 2017 +0800

    raise err in io.NDArrayIter for invalid usecase (#9228)
    
    * raise err in io.NDArrayIter for invalid usecase
    
    * address comments
---
 python/mxnet/io.py               | 27 ++++++++++++---------------
 tests/python/unittest/test_io.py | 11 ++++++++---
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 25a95be..b07f7c1 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -515,17 +515,13 @@ def _init_data(data, allow_empty, default_name):
     return list(data.items())
 
 def _has_instance(data, dtype):
-    """return True if data has instance of dtype"""
-    if isinstance(data, dtype):
-        return True
-    if isinstance(data, list):
-        for v in data:
-            if isinstance(v, dtype):
-                return True
-    if isinstance(data, dict):
-        for v in data.values():
-            if isinstance(v, dtype):
-                return True
+    """Return True if ``data`` has instance of ``dtype``.
+    This function is called after _init_data.
+    ``data`` is a list of (str, NDArray)"""
+    for item in data:
+        _, arr = item
+        if isinstance(arr, dtype):
+            return True
     return False
 
 def _shuffle(data, idx):
@@ -544,7 +540,7 @@ def _shuffle(data, idx):
 
 class NDArrayIter(DataIter):
     """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
-    or ``mx.nd.sparse.CSRNDArray``.
+    ``mx.nd.sparse.CSRNDArray`` or ``scipy.sparse.csr_matrix``.
 
     Example usage:
     ----------
@@ -644,12 +640,13 @@ class NDArrayIter(DataIter):
                  label_name='softmax_label'):
         super(NDArrayIter, self).__init__(batch_size)
 
-        if ((_has_instance(data, CSRNDArray) or _has_instance(label, CSRNDArray)) and
+        self.data = _init_data(data, allow_empty=False, default_name=data_name)
+        self.label = _init_data(label, allow_empty=True, default_name=label_name)
+
+        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
                 (last_batch_handle != 'discard')):
             raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                       " with `last_batch_handle` set to `discard`.")
-        self.data = _init_data(data, allow_empty=False, default_name=data_name)
-        self.label = _init_data(label, allow_empty=True, default_name=label_name)
 
         self.idx = np.arange(self.data[0][1].shape[0])
         # shuffle data
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index fa314e0..fd3c964 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -162,9 +162,14 @@ def test_NDArrayIter_csr():
     csr, _ = rand_sparse_ndarray(shape, 'csr')
     dns = csr.asnumpy()
 
-    # CSRNDArray with last_batch_handle not equal to 'discard' will throw NotImplementedError
-    assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': csr}, dns, batch_size,
-                 last_batch_handle='pad')
+    # CSRNDArray or scipy.sparse.csr_matrix with last_batch_handle not equal to 'discard' will throw NotImplementedError
+    assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': csr}, dns, batch_size)
+    try:
+        import scipy.sparse as spsp
+        train_data = spsp.csr_matrix(dns)
+        assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': train_data}, dns, batch_size)
+    except ImportError:
+        pass
 
     # CSRNDArray with shuffle
     csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size,

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.