You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/25 19:57:09 UTC
[incubator-mxnet] branch master updated: Restore save/load ndarray to 1.4.1 (#15073)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 49ae15b Restore save/load ndarray to 1.4.1 (#15073)
49ae15b is described below
commit 49ae15bf390647f79b1ec1b5f3cedb2a8a91ddc1
Author: reminisce <wu...@gmail.com>
AuthorDate: Sat May 25 12:56:37 2019 -0700
Restore save/load ndarray to 1.4.1 (#15073)
---
src/ndarray/ndarray.cc | 11 +++++++----
tests/python/unittest/test_ndarray.py | 14 ++++++++++++++
2 files changed, 21 insertions(+), 4 deletions(-)
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 81cf844..9474d0c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1581,6 +1581,9 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;
void NDArray::Save(dmlc::Stream *strm) const {
+ // TODO(junwu): Support this after NumPy operators are merged
+ CHECK(!Imperative::Get()->is_np_comp())
+ << "Saving ndarray within the scope of np_shape is not supported.";
// write magic number to mark this version
// for storage type
strm->Write(NDARRAY_V2_MAGIC);
@@ -1698,6 +1701,9 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
}
bool NDArray::Load(dmlc::Stream *strm) {
+ // TODO(junwu): Support this after NumPy operators are merged
+ CHECK(!Imperative::Get()->is_np_comp())
+ << "Loading ndarray within the scope of np_shape is not supported.";
uint32_t magic;
if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
if (magic != NDARRAY_V2_MAGIC) {
@@ -1718,10 +1724,7 @@ bool NDArray::Load(dmlc::Stream *strm) {
// load shape
mxnet::TShape shape;
if (!shape.Load(strm)) return false;
- if (!Imperative::Get()->is_np_comp()) {
- common::ConvertToNumpyShape(&shape);
- }
- if (mxnet::op::shape_is_none(shape)) {
+ if (shape.ndim() == 0) {
*this = NDArray(); return true;
}
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index df50543..8998b21 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1701,6 +1701,20 @@ def test_zero_from_numpy():
assert False
+@with_seed()
+def test_save_load_zero_size_ndarrays():
+ shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)]
+ array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
+ array_list = [mx.nd.array(arr) for arr in array_list]
+ with TemporaryDirectory() as work_dir:
+ fname = os.path.join(work_dir, 'dataset')
+ mx.nd.save(fname, array_list)
+ array_list_loaded = mx.nd.load(fname)
+ assert len(array_list) == len(array_list_loaded)
+ for a1, a2 in zip(array_list, array_list_loaded):
+ assert np.array_equal(a1.asnumpy(), a2.asnumpy())
+
+
if __name__ == '__main__':
import nose
nose.runmodule()