You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/12 18:44:02 UTC
[incubator-mxnet] branch master updated: support regex of
collect_params() (#9348)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4600070 support regex of collect_params() (#9348)
4600070 is described below
commit 4600070cd35bf4f1f3b93f4ce349c8e34e3610ae
Author: Wei Wu <to...@users.noreply.github.com>
AuthorDate: Sat Jan 13 02:43:50 2018 +0800
support regex of collect_params() (#9348)
* support regex of collect_params()
* fix pylint
* change default value && make `select` a single regex
* fix: if `select` is None, do not perform any regex matching
* update regex compile && add test
* Update block.py
* support regex of collect_params()
* fix pylint
* change default value && make `select` a single regex
* fix: if `select` is None, do not perform any regex matching
* update regex compile && add test
---
python/mxnet/gluon/block.py | 35 ++++++++++++++++++++++++++++++-----
tests/python/unittest/test_gluon.py | 13 ++++++++++++-
2 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 0d49def..fd75e4b 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -22,6 +22,7 @@ __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
import copy
import warnings
+import re
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
@@ -227,13 +228,38 @@ class Block(object):
children's parameters)."""
return self._params
- def collect_params(self):
+ def collect_params(self, select=None):
"""Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its
- children's Parameters."""
+ children's Parameters(default), also can returns the select :py:class:`ParameterDict`
+ which match some given regular expressions.
+
+ For example, collect the specified parameter in ['conv1_weight', 'conv1_bias', 'fc_weight',
+ 'fc_bias']::
+
+ model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
+
+ or collect all paramters which their name ends with 'weight' or 'bias', this can be done
+ using regular expressions::
+
+ model.collect_params('.*weight|.*bias')
+
+ Parameters
+ ----------
+ select : str
+ regular expressions
+
+ Returns
+ -------
+ The selected :py:class:`ParameterDict`
+ """
ret = ParameterDict(self._params.prefix)
- ret.update(self.params)
+ if not select:
+ ret.update(self.params)
+ else:
+ pattern = re.compile(select)
+ ret.update({name:value for name, value in self.params.items() if pattern.match(name)})
for cld in self._children:
- ret.update(cld.collect_params())
+ ret.update(cld.collect_params(select=select))
return ret
def save_params(self, filename):
@@ -261,7 +287,6 @@ class Block(object):
self.collect_params().load(filename, ctx, allow_missing, ignore_extra,
self.prefix)
-
def register_child(self, block):
"""Registers block as a child of self. :py:class:`Block` s assigned to self as
attributes will be registered automatically."""
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index f2d001a..57bf5c9 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -86,7 +86,18 @@ def test_parameter_str():
assert 'numpy.float32' in lines[1]
assert lines[2] == ')'
-
+def test_collect_paramters():
+ net = nn.HybridSequential(prefix="test_")
+ with net.name_scope():
+ net.add(nn.Conv2D(10, 3))
+ net.add(nn.Dense(10, activation='relu'))
+ assert set(net.collect_params().keys()) == \
+ set(['test_conv0_weight', 'test_conv0_bias','test_dense0_weight','test_dense0_bias'])
+ assert set(net.collect_params('.*weight').keys()) == \
+ set(['test_conv0_weight', 'test_dense0_weight'])
+ assert set(net.collect_params('test_conv0_bias|test_dense0_bias').keys()) == \
+ set(['test_conv0_bias', 'test_dense0_bias'])
+
def test_basic():
model = nn.Sequential()
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.