You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/09/17 17:03:43 UTC
[incubator-mxnet] branch master updated: SymbolBlock.imports
ignore_extra & allow_missing (#19157)
This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 30ae04a SymbolBlock.imports ignore_extra & allow_missing (#19157)
30ae04a is described below
commit 30ae04a54ceee244f2e1eba48e9ee867954cd0fe
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Thu Sep 17 10:00:27 2020 -0700
SymbolBlock.imports ignore_extra & allow_missing (#19157)
---
python/mxnet/gluon/block.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 8fd7dd3..d430aee 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1516,7 +1516,8 @@ class SymbolBlock(HybridBlock):
>>> print(feat_model(x))
"""
@staticmethod
- def imports(symbol_file, input_names, param_file=None, ctx=None):
+ def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False,
+ ignore_extra=False):
"""Import model previously saved by `gluon.HybridBlock.export`
as a `gluon.SymbolBlock` for use in Gluon.
@@ -1530,6 +1531,11 @@ class SymbolBlock(HybridBlock):
Path to parameter file.
ctx : Context, default None
The context to initialize `gluon.SymbolBlock` on.
+ allow_missing : bool, default False
+ Whether to silently skip loading parameters not represented in the file.
+ ignore_extra : bool, default False
+ Whether to silently ignore parameters from the file that are not
+ present in this Block.
Returns
-------
@@ -1562,7 +1568,7 @@ class SymbolBlock(HybridBlock):
inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names]
ret = SymbolBlock(sym, inputs)
if param_file is not None:
- ret.load_parameters(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
+ ret.load_parameters(param_file, ctx, allow_missing, ignore_extra, True, 'saved')
return ret
def __repr__(self):