You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/15 05:31:10 UTC
[incubator-mxnet] branch master updated: fix symbolblock save_params (#10748)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0fb57ff fix symbolblock save_params (#10748)
0fb57ff is described below
commit 0fb57ff31ef5caa32edf973213bde8a8faba85e5
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Mon May 14 22:31:05 2018 -0700
fix symbolblock save_params (#10748)
* fix symbolblock save_params
* fix
---
python/mxnet/gluon/block.py | 14 ++++++++++++++
tests/python/unittest/test_gluon.py | 27 +++++++++++++++++++++++++++
2 files changed, 41 insertions(+)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 7e41272..4779484 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -649,6 +649,18 @@ class HybridBlock(Block):
# pylint: disable= invalid-name
raise NotImplementedError
+def _common_prefix(names):
+ """Get the common prefix for all names"""
+ if not names:
+ return ''
+ prefix = names[0]
+ for name in names:
+ i = 0
+ while i < len(prefix) and i < len(name) and prefix[i] == name[i]:
+ i += 1
+ prefix = prefix[:i]
+ return prefix
+
class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
@@ -710,6 +722,8 @@ class SymbolBlock(HybridBlock):
self.params.get(i, grad_req='null', allow_deferred_init=True)
self._cached_graph = syms, out
+ len_prefix = len(_common_prefix(list(self._params.keys())))
+ self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}
def forward(self, x, *args):
if isinstance(x, NDArray):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index b054aa6..fb73e53 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -986,6 +986,33 @@ def test_save_load():
net.load_params('test.params')
def test_symbol_block_save_load():
    """Round-trip save_params/load_params through a net that contains a SymbolBlock.

    The SymbolBlock wraps three internal outputs of a model-zoo ResNet-18 and
    reuses its parameters. A second, independently constructed ``Net`` loads
    the saved file; presumably it receives a different auto-generated name
    prefix than the first (TODO confirm), so this exercises prefix-independent
    parameter naming. The reloaded net must reproduce the original outputs.
    """
    import os
    import shutil
    import tempfile

    class Net(gluon.HybridBlock):
        def __init__(self):
            super(Net, self).__init__()
            with self.name_scope():
                backbone = gluon.model_zoo.vision.resnet18_v1()
                data = mx.sym.var('data')
                featnames = ['stage1_activation0', 'stage2_activation0', 'stage3_activation0']
                out_names = ['_'.join([backbone.name, featname, 'output']) for featname in featnames]
                internals = backbone(data).get_internals()
                outs = [internals[out_name] for out_name in out_names]
                self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params())
                self.body = nn.Conv2D(3, 1)

        def hybrid_forward(self, F, x):
            x = self.body(x)
            return self.backbone(x)

    x = mx.nd.random.normal(shape=(1, 3, 32, 32))

    net1 = Net()
    net1.initialize(mx.init.Normal())
    net1.hybridize()
    out1 = net1(x)

    # Use a private temp dir instead of the CWD so the test leaves no stray
    # file behind and cannot collide with other tests writing 'test.params'.
    tmpdir = tempfile.mkdtemp()
    try:
        filename = os.path.join(tmpdir, 'test.params')
        net1.save_params(filename)

        net2 = Net()
        net2.load_params(filename, ctx=mx.cpu())
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

    # Hybridize too so both nets run the same cached graph; the loaded
    # parameters must then reproduce net1's outputs (one per feature name).
    net2.hybridize()
    out2 = net2(x)
    assert len(out1) == len(out2)
    for o1, o2 in zip(out1, out2):
        mx.test_utils.assert_almost_equal(o1.asnumpy(), o2.asnumpy(), rtol=1e-5, atol=1e-6)
+
def test_hybrid_multi_context():
net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.