You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/06/07 05:21:05 UTC

[incubator-mxnet] branch master updated: add cast_dtype option to load parameters (#15168)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 77da1e2  add cast_dtype option to load parameters (#15168)
77da1e2 is described below

commit 77da1e27a4077359c963e3c8bf57128c3bc3cb34
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Thu Jun 6 22:20:24 2019 -0700

    add cast_dtype option to load parameters (#15168)
---
 python/mxnet/gluon/block.py         |  9 ++++++---
 python/mxnet/gluon/parameter.py     | 14 ++++++++++----
 tests/python/unittest/test_gluon.py |  9 +++++++++
 3 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 20f0a32..e78d11c 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -354,7 +354,7 @@ class Block(object):
                               'save_parameters may resolve this error.'%e.message)
 
     def load_parameters(self, filename, ctx=None, allow_missing=False,
-                        ignore_extra=False):
+                        ignore_extra=False, cast_dtype=False):
         """Load parameters from file previously saved by `save_parameters`.
 
         Parameters
@@ -368,6 +368,9 @@ class Block(object):
         ignore_extra : bool, default False
             Whether to silently ignore parameters from the file that are not
             present in this Block.
+        cast_dtype : bool, default False
+            Cast the data type of the NDArray loaded from the checkpoint to the dtype
+            provided by the Parameter if any.
 
         References
         ----------
@@ -383,7 +386,7 @@ class Block(object):
             # legacy loading
             del loaded
             self.collect_params().load(
-                filename, ctx, allow_missing, ignore_extra, self.prefix)
+                filename, ctx, allow_missing, ignore_extra, self.prefix, cast_dtype=cast_dtype)
             return
 
         if not allow_missing:
@@ -399,7 +402,7 @@ class Block(object):
                     "which contains parameters %s. Set ignore_extra=True to ignore. "%(
                         name, filename, _brief_print_list(self._params.keys())))
             if name in params:
-                params[name]._load_init(loaded[name], ctx)
+                params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype)
 
     def load_params(self, filename, ctx=None, allow_missing=False,
                     ignore_extra=False):
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 83edbaf..19f3883 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -241,7 +241,7 @@ class Parameter(object):
         self._trainer._row_sparse_pull(self, results, row_id)
         return results
 
-    def _load_init(self, data, ctx):
+    def _load_init(self, data, ctx, cast_dtype=False):
         """(Re)initializes by loading from data."""
         if self.shape:
             for self_dim, data_dim in zip(self.shape, data.shape):
@@ -251,9 +251,12 @@ class Parameter(object):
                         self.name, str(self.shape), str(data.shape))
             self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
         if self.dtype:
+            if cast_dtype and np.dtype(self.dtype).type != data.dtype:
+                data = data.astype(self.dtype, copy=False)
             assert np.dtype(self.dtype).type == data.dtype, \
                 "Failed loading Parameter '%s' from saved params: " \
-                "dtype incompatible expected %s vs saved %s"%(
+                "dtype incompatible expected %s vs saved %s. " \
+                "Set cast_dtype=True to cast the dtype of saved params."%(
                     self.name, str(self.dtype), str(data.dtype))
         if self._stype != data.stype:
             data = data.tostype(self._stype)
@@ -891,7 +894,7 @@ class ParameterDict(object):
         ndarray.save(filename, arg_dict)
 
     def load(self, filename, ctx=None, allow_missing=False,
-             ignore_extra=False, restore_prefix=''):
+             ignore_extra=False, restore_prefix='', cast_dtype=False):
         """Load parameters from file.
 
         Parameters
@@ -907,6 +910,9 @@ class ParameterDict(object):
             present in this ParameterDict.
         restore_prefix : str, default ''
             prepend prefix to names of stored parameters before loading.
+        cast_dtype : bool, default False
+            Cast the data type of the NDArray loaded from the checkpoint to the dtype
+            provided by the Parameter if any.
         """
         if restore_prefix:
             for name in self.keys():
@@ -932,4 +938,4 @@ class ParameterDict(object):
                     "Please make sure source and target networks have the same prefix."%(
                         name[lprefix:], filename, _brief_print_list(self._params.keys()))
                 continue
-            self[name]._load_init(arg_dict[name], ctx)
+            self[name]._load_init(arg_dict[name], ctx, cast_dtype=cast_dtype)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 08e4c52..557e817 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -2742,6 +2742,15 @@ def test_np_shape_parameters():
         foo.initialize()
         print(foo(z).shape)
 
+@with_seed()
+def test_gluon_param_load():
+    net = mx.gluon.nn.Dense(10, in_units=10)
+    net.initialize()
+    net.save_parameters('test_gluon_param_load.params')
+    net.cast('float16')
+    net.load_parameters('test_gluon_param_load.params', cast_dtype=True)
+    mx.nd.waitall()
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()