You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/25 19:57:09 UTC

[incubator-mxnet] branch master updated: Restore save/load ndarray to 1.4.1 (#15073)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 49ae15b  Restore save/load ndarray to 1.4.1 (#15073)
49ae15b is described below

commit 49ae15bf390647f79b1ec1b5f3cedb2a8a91ddc1
Author: reminisce <wu...@gmail.com>
AuthorDate: Sat May 25 12:56:37 2019 -0700

    Restore save/load ndarray to 1.4.1 (#15073)
---
 src/ndarray/ndarray.cc                | 11 +++++++----
 tests/python/unittest/test_ndarray.py | 14 ++++++++++++++
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 81cf844..9474d0c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1581,6 +1581,9 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;
 static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;
 
 void NDArray::Save(dmlc::Stream *strm) const {
+  // TODO(junwu): Support this after NumPy operators are merged
+  CHECK(!Imperative::Get()->is_np_comp())
+      << "Saving ndarray within the scope of np_shape is not supported.";
   // write magic number to mark this version
   // for storage type
   strm->Write(NDARRAY_V2_MAGIC);
@@ -1698,6 +1701,9 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
 }
 
 bool NDArray::Load(dmlc::Stream *strm) {
+  // TODO(junwu): Support this after NumPy operators are merged
+  CHECK(!Imperative::Get()->is_np_comp())
+      << "Loading ndarray within the scope of np_shape is not supported.";
   uint32_t magic;
   if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
   if (magic != NDARRAY_V2_MAGIC) {
@@ -1718,10 +1724,7 @@ bool NDArray::Load(dmlc::Stream *strm) {
   // load shape
   mxnet::TShape shape;
   if (!shape.Load(strm)) return false;
-  if (!Imperative::Get()->is_np_comp()) {
-    common::ConvertToNumpyShape(&shape);
-  }
-  if (mxnet::op::shape_is_none(shape)) {
+  if (shape.ndim() == 0) {
     *this = NDArray(); return true;
   }
 
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index df50543..8998b21 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1701,6 +1701,20 @@ def test_zero_from_numpy():
         assert False
 
 
+@with_seed()
+def test_save_load_zero_size_ndarrays():
+    shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)]
+    array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
+    array_list = [mx.nd.array(arr) for arr in array_list]
+    with TemporaryDirectory() as work_dir:
+        fname = os.path.join(work_dir, 'dataset')
+        mx.nd.save(fname, array_list)
+        array_list_loaded = mx.nd.load(fname)
+        assert len(array_list) == len(array_list_loaded)
+        for a1, a2 in zip(array_list, array_list_loaded):
+            assert np.array_equal(a1.asnumpy(), a2.asnumpy())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()