You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/06/07 05:21:05 UTC
[incubator-mxnet] branch master updated: add cast_dtype option to load parameters (#15168)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 77da1e2 add cast_dtype option to load parameters (#15168)
77da1e2 is described below
commit 77da1e27a4077359c963e3c8bf57128c3bc3cb34
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Thu Jun 6 22:20:24 2019 -0700
add cast_dtype option to load parameters (#15168)
---
python/mxnet/gluon/block.py | 9 ++++++---
python/mxnet/gluon/parameter.py | 14 ++++++++++----
tests/python/unittest/test_gluon.py | 9 +++++++++
3 files changed, 25 insertions(+), 7 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 20f0a32..e78d11c 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -354,7 +354,7 @@ class Block(object):
'save_parameters may resolve this error.'%e.message)
def load_parameters(self, filename, ctx=None, allow_missing=False,
- ignore_extra=False):
+ ignore_extra=False, cast_dtype=False):
"""Load parameters from file previously saved by `save_parameters`.
Parameters
@@ -368,6 +368,9 @@ class Block(object):
ignore_extra : bool, default False
Whether to silently ignore parameters from the file that are not
present in this Block.
+ cast_dtype : bool, default False
+ Cast the data type of the NDArray loaded from the checkpoint to the dtype
+ provided by the Parameter if any.
References
----------
@@ -383,7 +386,7 @@ class Block(object):
# legacy loading
del loaded
self.collect_params().load(
- filename, ctx, allow_missing, ignore_extra, self.prefix)
+ filename, ctx, allow_missing, ignore_extra, self.prefix, cast_dtype=cast_dtype)
return
if not allow_missing:
@@ -399,7 +402,7 @@ class Block(object):
"which contains parameters %s. Set ignore_extra=True to ignore. "%(
name, filename, _brief_print_list(self._params.keys())))
if name in params:
- params[name]._load_init(loaded[name], ctx)
+ params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype)
def load_params(self, filename, ctx=None, allow_missing=False,
ignore_extra=False):
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 83edbaf..19f3883 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -241,7 +241,7 @@ class Parameter(object):
self._trainer._row_sparse_pull(self, results, row_id)
return results
- def _load_init(self, data, ctx):
+ def _load_init(self, data, ctx, cast_dtype=False):
"""(Re)initializes by loading from data."""
if self.shape:
for self_dim, data_dim in zip(self.shape, data.shape):
@@ -251,9 +251,12 @@ class Parameter(object):
self.name, str(self.shape), str(data.shape))
self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
if self.dtype:
+ if cast_dtype and np.dtype(self.dtype).type != data.dtype:
+ data = data.astype(self.dtype, copy=False)
assert np.dtype(self.dtype).type == data.dtype, \
"Failed loading Parameter '%s' from saved params: " \
- "dtype incompatible expected %s vs saved %s"%(
+ "dtype incompatible expected %s vs saved %s. " \
+ "Set cast_dtype=True to cast the dtype of saved params."%(
self.name, str(self.dtype), str(data.dtype))
if self._stype != data.stype:
data = data.tostype(self._stype)
@@ -891,7 +894,7 @@ class ParameterDict(object):
ndarray.save(filename, arg_dict)
def load(self, filename, ctx=None, allow_missing=False,
- ignore_extra=False, restore_prefix=''):
+ ignore_extra=False, restore_prefix='', cast_dtype=False):
"""Load parameters from file.
Parameters
@@ -907,6 +910,9 @@ class ParameterDict(object):
present in this ParameterDict.
restore_prefix : str, default ''
prepend prefix to names of stored parameters before loading.
+ cast_dtype : bool, default False
+ Cast the data type of the NDArray loaded from the checkpoint to the dtype
+ provided by the Parameter if any.
"""
if restore_prefix:
for name in self.keys():
@@ -932,4 +938,4 @@ class ParameterDict(object):
"Please make sure source and target networks have the same prefix."%(
name[lprefix:], filename, _brief_print_list(self._params.keys()))
continue
- self[name]._load_init(arg_dict[name], ctx)
+ self[name]._load_init(arg_dict[name], ctx, cast_dtype=cast_dtype)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 08e4c52..557e817 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -2742,6 +2742,15 @@ def test_np_shape_parameters():
foo.initialize()
print(foo(z).shape)
+@with_seed()
+def test_gluon_param_load():
+ net = mx.gluon.nn.Dense(10, in_units=10)
+ net.initialize()
+ net.save_parameters('test_gluon_param_load.params')
+ net.cast('float16')
+ net.load_parameters('test_gluon_param_load.params', cast_dtype=True)
+ mx.nd.waitall()
+
if __name__ == '__main__':
import nose
nose.runmodule()