You are viewing a plain text version of this content. The canonical link for it (a hyperlink in the original message) is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/06/20 20:04:13 UTC

[incubator-mxnet] branch master updated: Showing proper error when csr array is not 2D in shape. (#15242)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2de0db0  Showing proper error when csr array is not 2D in shape. (#15242)
2de0db0 is described below

commit 2de0db0911f2e71728fa85ab342bd99a10974fc9
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Thu Jun 20 13:03:48 2019 -0700

    Showing proper error when csr array is not 2D in shape. (#15242)
    
    * Showing proper error when csr array is not 2D in shape.
    
    * Fixed failing CI
    
    * Nudge to CI
---
 python/mxnet/ndarray/ndarray.py              | 4 ++++
 tests/python/unittest/test_sparse_ndarray.py | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7e21dae..3fb1af6 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2227,6 +2227,10 @@ fixed-size items.
         NDArray, CSRNDArray or RowSparseNDArray
             A copy of the array with the chosen storage stype
         """
+        if stype == 'csr' and len(self.shape) != 2:
+            raise ValueError("To convert to a CSR, the NDArray should be 2 Dimensional. Current "
+                             "shape is %s" % str(self.shape))
+
         return op.cast_storage(self, stype=stype)
 
     def to_dlpack_for_read(self):
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 3b4c684..9a1fce4 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -963,6 +963,11 @@ def test_sparse_nd_check_format():
     indptr_list = [0, -2, 2, 3]
     a = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list), shape=shape)
     assertRaises(mx.base.MXNetError, a.check_format)
+    # CSR format should be 2 Dimensional.
+    a = mx.nd.array([1, 2, 3])
+    assertRaises(ValueError, a.tostype, 'csr')
+    a = mx.nd.array([[[1, 2, 3]]])
+    assertRaises(ValueError, a.tostype, 'csr')
     # Row Sparse format indices should be less than the number of rows
     shape = (3, 2)
     data_list = [[1, 2], [3, 4]]