You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/12 18:44:02 UTC

[incubator-mxnet] branch master updated: support regex of collect_params() (#9348)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4600070  support regex of collect_params() (#9348)
4600070 is described below

commit 4600070cd35bf4f1f3b93f4ce349c8e34e3610ae
Author: Wei Wu <to...@users.noreply.github.com>
AuthorDate: Sat Jan 13 02:43:50 2018 +0800

    support regex of collect_params() (#9348)
    
    * support regex of collect_params()
    
    * fix pylint
    
    * change default value && make select a single regex
    
    * fix if select is None, then will not do any regex matching
    
    * update regex compile && add test
    
    * Update block.py
    
    * support regex of collect_params()
    
    * fix pylint
    
    * change default value && make select a single regex
    
    * fix if select is None, then will not do any regex matching
    
    * update regex compile && add test
---
 python/mxnet/gluon/block.py         | 35 ++++++++++++++++++++++++++++++-----
 tests/python/unittest/test_gluon.py | 13 ++++++++++++-
 2 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 0d49def..fd75e4b 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -22,6 +22,7 @@ __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
 import copy
 import warnings
+import re
 
 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
@@ -227,13 +228,38 @@ class Block(object):
         children's parameters)."""
         return self._params
 
-    def collect_params(self):
+    def collect_params(self, select=None):
         """Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its
-        children's Parameters."""
+        children's Parameters (default), or only those whose names match the given regular
+        expression ``select``.
+
+        For example, collect the specified parameter in ['conv1_weight', 'conv1_bias', 'fc_weight',
+        'fc_bias']::
+
+            model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
+
+        or collect all parameters whose names contain 'weight' or 'bias'; this can be done
+        using regular expressions::
+
+            model.collect_params('.*weight|.*bias')
+
+        Parameters
+        ----------
+        select : str, optional
+            Regular expression matched against parameter names with ``re.match``.
+
+        Returns
+        -------
+        The selected :py:class:`ParameterDict`
+        """
         ret = ParameterDict(self._params.prefix)
-        ret.update(self.params)
+        if not select:
+            ret.update(self.params)
+        else:
+            pattern = re.compile(select)
+            ret.update({name:value for name, value in self.params.items() if pattern.match(name)})
         for cld in self._children:
-            ret.update(cld.collect_params())
+            ret.update(cld.collect_params(select=select))
         return ret
 
     def save_params(self, filename):
@@ -261,7 +287,6 @@ class Block(object):
         self.collect_params().load(filename, ctx, allow_missing, ignore_extra,
                                    self.prefix)
 
-
     def register_child(self, block):
         """Registers block as a child of self. :py:class:`Block` s assigned to self as
         attributes will be registered automatically."""
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index f2d001a..57bf5c9 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -86,7 +86,18 @@ def test_parameter_str():
     assert 'numpy.float32' in lines[1]
     assert lines[2] == ')'
 
-
+def test_collect_paramters():
+    net = nn.HybridSequential(prefix="test_")
+    with net.name_scope():
+        net.add(nn.Conv2D(10, 3))
+        net.add(nn.Dense(10, activation='relu'))
+    assert set(net.collect_params().keys()) == \
+        set(['test_conv0_weight', 'test_conv0_bias','test_dense0_weight','test_dense0_bias'])
+    assert set(net.collect_params('.*weight').keys()) == \
+        set(['test_conv0_weight', 'test_dense0_weight'])
+    assert set(net.collect_params('test_conv0_bias|test_dense0_bias').keys()) == \
+        set(['test_conv0_bias', 'test_dense0_bias'])
+        
 def test_basic():
     model = nn.Sequential()
     model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].