You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/09/17 17:03:43 UTC

[incubator-mxnet] branch master updated: SymbolBlock.imports ignore_extra & allow_missing (#19157)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 30ae04a  SymbolBlock.imports ignore_extra & allow_missing (#19157)
30ae04a is described below

commit 30ae04a54ceee244f2e1eba48e9ee867954cd0fe
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Thu Sep 17 10:00:27 2020 -0700

    SymbolBlock.imports ignore_extra & allow_missing (#19157)
---
 python/mxnet/gluon/block.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 8fd7dd3..d430aee 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1516,7 +1516,8 @@ class SymbolBlock(HybridBlock):
     >>> print(feat_model(x))
     """
     @staticmethod
-    def imports(symbol_file, input_names, param_file=None, ctx=None):
+    def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False,
+                ignore_extra=False):
         """Import model previously saved by `gluon.HybridBlock.export`
         as a `gluon.SymbolBlock` for use in Gluon.
 
@@ -1530,6 +1531,11 @@ class SymbolBlock(HybridBlock):
             Path to parameter file.
         ctx : Context, default None
             The context to initialize `gluon.SymbolBlock` on.
+        allow_missing : bool, default False
+            Whether to silently skip loading parameters not represented in the file.
+        ignore_extra : bool, default False
+            Whether to silently ignore parameters from the file that are not
+            present in this Block.
 
         Returns
         -------
@@ -1562,7 +1568,7 @@ class SymbolBlock(HybridBlock):
             inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names]
         ret = SymbolBlock(sym, inputs)
         if param_file is not None:
-            ret.load_parameters(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
+            ret.load_parameters(param_file, ctx, allow_missing, ignore_extra, True, 'saved')
         return ret
 
     def __repr__(self):