You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/15 05:31:10 UTC

[incubator-mxnet] branch master updated: fix symbolblock save_params (#10748)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fb57ff  fix symbolblock save_params (#10748)
0fb57ff is described below

commit 0fb57ff31ef5caa32edf973213bde8a8faba85e5
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Mon May 14 22:31:05 2018 -0700

    fix symbolblock save_params (#10748)
    
    * fix symbolblock save_params
    
    * fix
---
 python/mxnet/gluon/block.py         | 14 ++++++++++++++
 tests/python/unittest/test_gluon.py | 27 +++++++++++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 7e41272..4779484 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -649,6 +649,18 @@ class HybridBlock(Block):
         # pylint: disable= invalid-name
         raise NotImplementedError
 
+def _common_prefix(names):
+    """Return the longest character-level prefix common to all names ('' if names is empty)."""
+    if not names:
+        return ''
+    prefix = names[0]  # a single name is its own (whole) prefix
+    for name in names:
+        i = 0
+        while i < len(prefix) and i < len(name) and prefix[i] == name[i]:  # count agreeing leading chars
+            i += 1
+        prefix = prefix[:i]  # char-level: may end mid-token, e.g. 'conv' for 'conv0_w'/'conv1_w'
+    return prefix
+
 
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
@@ -710,6 +722,8 @@ class SymbolBlock(HybridBlock):
                 self.params.get(i, grad_req='null', allow_deferred_init=True)
 
         self._cached_graph = syms, out
+        len_prefix = len(_common_prefix(list(self._params.keys())))
+        self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}
 
     def forward(self, x, *args):
         if isinstance(x, NDArray):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index b054aa6..fb73e53 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -986,6 +986,33 @@ def test_save_load():
 
     net.load_params('test.params')
 
+def test_symbol_block_save_load():  # save/load round trip of a HybridBlock that wraps a SymbolBlock
+    class Net(gluon.HybridBlock):
+        def __init__(self):
+            super(Net, self).__init__()
+            with self.name_scope():
+                backbone = gluon.model_zoo.vision.resnet18_v1()
+                data = mx.sym.var('data')
+                featnames = ['stage1_activation0', 'stage2_activation0', 'stage3_activation0']
+                out_names = ['_'.join([backbone.name, featname, 'output']) for featname in featnames]
+                internals = backbone(data).get_internals()
+                outs = [internals[out_name] for out_name in out_names]
+                self.backbone = gluon.SymbolBlock(outs, data, params=backbone.collect_params())
+                self.body = nn.Conv2D(3, 1)
+
+        def hybrid_forward(self, F, x):
+            x = self.body(x)
+            return self.backbone(x)
+
+    net1 = Net()
+    net1.initialize(mx.init.Normal())
+    net1.hybridize()
+    net1(mx.nd.random.normal(shape=(1, 3, 32, 32)))  # forward once before saving; params may use deferred init
+    net1.save_params('./test.params')  # NOTE(review): writes to the CWD and is never removed
+
+    net2 = Net()
+    net2.load_params('./test.params', ctx=mx.cpu())  # no asserts: the test passes if loading does not raise
+
 
 def test_hybrid_multi_context():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18)

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.