You are viewing a plain-text version of this content. The canonical link to it was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/20 05:19:12 UTC
[incubator-mxnet] branch master updated: [MXNET-1403] Disable numpy's writability of NDArray once it is zero-copied to MXNet (#14948)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 96b1cde [MXNET-1403] Disable numpy's writability of NDArray once it is zero-copied to MXNet (#14948)
96b1cde is described below
commit 96b1cde15fc6b7492fed44d941b39e5d97d0022f
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sun May 19 22:18:36 2019 -0700
[MXNET-1403] Disable numpy's writability of NDArray once it is zero-copied to MXNet (#14948)
* Initial commit
* update
* Update test_ndarray.py
* Retrigger
---
python/mxnet/ndarray/ndarray.py | 8 +++++++-
tests/python/unittest/test_ndarray.py | 4 ++--
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 1c18273..2325890 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -4212,7 +4212,12 @@ def dl_managed_tensor_deleter(dl_managed_tensor_handle):
def from_numpy(ndarray, zero_copy=True):
- """Returns an MXNet's NDArray backed by Numpy's ndarray.
+ """Returns an MXNet's ndarray backed by numpy's ndarray.
+ When `zero_copy` is set to be true,
+ this API consumes numpy's ndarray and produces MXNet's ndarray
+ without having to copy the content. In this case, we disallow
+ users to modify the given numpy ndarray, and it is suggested
+ not to read the numpy ndarray as well for internal correctness.
Parameters
----------
@@ -4261,6 +4266,7 @@ def from_numpy(ndarray, zero_copy=True):
if not ndarray.flags['C_CONTIGUOUS']:
raise ValueError("Only c-contiguous arrays are supported for zero-copy")
+ ndarray.flags['WRITEABLE'] = False
c_obj = _make_dl_managed_tensor(ndarray)
address = ctypes.addressof(c_obj)
address = ctypes.cast(address, ctypes.c_void_p)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index c62bd19..df50543 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1687,8 +1687,8 @@ def test_zero_from_numpy():
mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy())
np_array = arrays[0]
mx_array = mx.nd.from_numpy(np_array)
- np_array[2, 1] = 0
- mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy())
+ assertRaises(ValueError, np_array.__setitem__, (2, 1), 0)
+
mx_array[2, 1] = 100
mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy())
np_array = np.array([[1, 2], [3, 4], [5, 6]]).transpose()