You are viewing a plain-text version of this content; the link to the canonical (HTML) version was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2017/12/30 18:47:19 UTC
[incubator-mxnet] branch master updated: raise err in io.NDArrayIter for invalid usecase (#9228)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e3e1f23 raise err in io.NDArrayIter for invalid usecase (#9228)
e3e1f23 is described below
commit e3e1f235ba72ad903cb6201ac71d281f597fb2cf
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Sun Dec 31 02:47:08 2017 +0800
raise err in io.NDArrayIter for invalid usecase (#9228)
* raise err in io.NDArrayIter for invalid usecase
* address comments
---
python/mxnet/io.py | 27 ++++++++++++---------------
tests/python/unittest/test_io.py | 11 ++++++++---
2 files changed, 20 insertions(+), 18 deletions(-)
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 25a95be..b07f7c1 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -515,17 +515,13 @@ def _init_data(data, allow_empty, default_name):
return list(data.items())
def _has_instance(data, dtype):
- """return True if data has instance of dtype"""
- if isinstance(data, dtype):
- return True
- if isinstance(data, list):
- for v in data:
- if isinstance(v, dtype):
- return True
- if isinstance(data, dict):
- for v in data.values():
- if isinstance(v, dtype):
- return True
+ """Return True if ``data`` has instance of ``dtype``.
+ This function is called after _init_data.
+ ``data`` is a list of (str, NDArray)"""
+ for item in data:
+ _, arr = item
+ if isinstance(arr, dtype):
+ return True
return False
def _shuffle(data, idx):
@@ -544,7 +540,7 @@ def _shuffle(data, idx):
class NDArrayIter(DataIter):
"""Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
- or ``mx.nd.sparse.CSRNDArray``.
+ ``mx.nd.sparse.CSRNDArray`` or ``scipy.sparse.csr_matrix``.
Example usage:
----------
@@ -644,12 +640,13 @@ class NDArrayIter(DataIter):
label_name='softmax_label'):
super(NDArrayIter, self).__init__(batch_size)
- if ((_has_instance(data, CSRNDArray) or _has_instance(label, CSRNDArray)) and
+ self.data = _init_data(data, allow_empty=False, default_name=data_name)
+ self.label = _init_data(label, allow_empty=True, default_name=label_name)
+
+ if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
(last_batch_handle != 'discard')):
raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
" with `last_batch_handle` set to `discard`.")
- self.data = _init_data(data, allow_empty=False, default_name=data_name)
- self.label = _init_data(label, allow_empty=True, default_name=label_name)
self.idx = np.arange(self.data[0][1].shape[0])
# shuffle data
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index fa314e0..fd3c964 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -162,9 +162,14 @@ def test_NDArrayIter_csr():
csr, _ = rand_sparse_ndarray(shape, 'csr')
dns = csr.asnumpy()
- # CSRNDArray with last_batch_handle not equal to 'discard' will throw NotImplementedError
- assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': csr}, dns, batch_size,
- last_batch_handle='pad')
+ # CSRNDArray or scipy.sparse.csr_matrix with last_batch_handle not equal to 'discard' will throw NotImplementedError
+ assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': csr}, dns, batch_size)
+ try:
+ import scipy.sparse as spsp
+ train_data = spsp.csr_matrix(dns)
+ assertRaises(NotImplementedError, mx.io.NDArrayIter, {'data': train_data}, dns, batch_size)
+ except ImportError:
+ pass
# CSRNDArray with shuffle
csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size,
--
To stop receiving notification emails like this one, please contact commits@mxnet.apache.org.