You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/06/20 20:04:13 UTC
[incubator-mxnet] branch master updated: Showing proper error when csr array is not 2D in shape. (#15242)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 2de0db0 Showing proper error when csr array is not 2D in shape. (#15242)
2de0db0 is described below
commit 2de0db0911f2e71728fa85ab342bd99a10974fc9
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Thu Jun 20 13:03:48 2019 -0700
Showing proper error when csr array is not 2D in shape. (#15242)
* Showing proper error when csr array is not 2D in shape.
* Fixed failing CI
* Nudge to CI
---
python/mxnet/ndarray/ndarray.py | 4 ++++
tests/python/unittest/test_sparse_ndarray.py | 5 +++++
2 files changed, 9 insertions(+)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7e21dae..3fb1af6 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2227,6 +2227,10 @@ fixed-size items.
NDArray, CSRNDArray or RowSparseNDArray
A copy of the array with the chosen storage stype
"""
+ if stype == 'csr' and len(self.shape) != 2:
+ raise ValueError("To convert to a CSR, the NDArray should be 2 Dimensional. Current "
+ "shape is %s" % str(self.shape))
+
return op.cast_storage(self, stype=stype)
def to_dlpack_for_read(self):
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 3b4c684..9a1fce4 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -963,6 +963,11 @@ def test_sparse_nd_check_format():
indptr_list = [0, -2, 2, 3]
a = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list), shape=shape)
assertRaises(mx.base.MXNetError, a.check_format)
+ # CSR format should be 2 Dimensional.
+ a = mx.nd.array([1, 2, 3])
+ assertRaises(ValueError, a.tostype, 'csr')
+ a = mx.nd.array([[[1, 2, 3]]])
+ assertRaises(ValueError, a.tostype, 'csr')
# Row Sparse format indices should be less than the number of rows
shape = (3, 2)
data_list = [[1, 2], [3, 4]]