You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/01/15 06:41:19 UTC
[incubator-mxnet] branch master updated: move concurrent/identity
blocks to contrib (#9427)
This is an automated email from the ASF dual-hosted git repository.
zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6dcb0eb move concurrent/identity blocks to contrib (#9427)
6dcb0eb is described below
commit 6dcb0ebb39558d83df4d9fd0338a46f775b94bd6
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Sun Jan 14 22:41:16 2018 -0800
move concurrent/identity blocks to contrib (#9427)
---
python/mxnet/gluon/contrib/__init__.py | 2 +
python/mxnet/gluon/contrib/{ => nn}/__init__.py | 9 ++-
.../nn/basic_layers.py} | 74 +++++++++++++++-------
python/mxnet/gluon/model_zoo/vision/densenet.py | 4 +-
python/mxnet/gluon/model_zoo/vision/inception.py | 16 ++---
python/mxnet/gluon/model_zoo/vision/squeezenet.py | 4 +-
tests/python/unittest/test_gluon_contrib.py | 35 ++++++++++
tests/python/unittest/test_gluon_model_zoo.py | 29 +--------
8 files changed, 109 insertions(+), 64 deletions(-)
diff --git a/python/mxnet/gluon/contrib/__init__.py b/python/mxnet/gluon/contrib/__init__.py
index 3f8b64b..e06438b 100644
--- a/python/mxnet/gluon/contrib/__init__.py
+++ b/python/mxnet/gluon/contrib/__init__.py
@@ -18,4 +18,6 @@
# coding: utf-8
"""Contrib neural network module."""
+from . import nn
+
from . import rnn
diff --git a/python/mxnet/gluon/contrib/__init__.py b/python/mxnet/gluon/contrib/nn/__init__.py
similarity index 82%
copy from python/mxnet/gluon/contrib/__init__.py
copy to python/mxnet/gluon/contrib/nn/__init__.py
index 3f8b64b..62440cd 100644
--- a/python/mxnet/gluon/contrib/__init__.py
+++ b/python/mxnet/gluon/contrib/nn/__init__.py
@@ -16,6 +16,11 @@
# under the License.
# coding: utf-8
-"""Contrib neural network module."""
+# pylint: disable=wildcard-import
+"""Contrib neural network module."""
-from . import rnn
+from . import basic_layers
+
+from .basic_layers import *
+
+__all__ = basic_layers.__all__
diff --git a/python/mxnet/gluon/model_zoo/custom_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
similarity index 53%
rename from python/mxnet/gluon/model_zoo/custom_layers.py
rename to python/mxnet/gluon/contrib/nn/basic_layers.py
index 8c481b3..8870888 100644
--- a/python/mxnet/gluon/model_zoo/custom_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,52 +18,82 @@
# coding: utf-8
# pylint: disable= arguments-differ
"""Custom neural network layers in model_zoo."""
-__all__ = ['HybridConcurrent', 'Identity']
+__all__ = ['Concurrent', 'HybridConcurrent', 'Identity']
-from ..block import Block, HybridBlock
-from ..utils import _indent
+from .... import nd
+from ...block import HybridBlock
+from ...nn import Sequential, HybridSequential
-class HybridConcurrent(HybridBlock):
+class Concurrent(Sequential):
+ """Lays `Block`s concurrently.
+
+ This block feeds its input to all children blocks, and
+ produces the output by concatenating all the children blocks' outputs
+ on the specified axis.
+
+ Example::
+
+ net = Concurrent()
+ # use net's name_scope to give children blocks appropriate names.
+ with net.name_scope():
+ net.add(nn.Dense(10, activation='relu'))
+ net.add(nn.Dense(20))
+ net.add(Identity())
+
+ Parameters
+ ----------
+ axis : int, default -1
+ The axis on which to concatenate the outputs.
+ """
+ def __init__(self, axis=-1, prefix=None, params=None):
+ super(Concurrent, self).__init__(prefix=prefix, params=params)
+ self.axis = axis
+
+ def forward(self, x):
+ out = []
+ for block in self._children:
+ out.append(block(x))
+ out = nd.concat(*out, dim=self.axis)
+ return out
+
+
+class HybridConcurrent(HybridSequential):
"""Lays `HybridBlock`s concurrently.
+ This block feeds its input to all children blocks, and
+ produces the output by concatenating all the children blocks' outputs
+ on the specified axis.
+
Example::
net = HybridConcurrent()
- # use net's name_scope to give child Blocks appropriate names.
+ # use net's name_scope to give children blocks appropriate names.
with net.name_scope():
net.add(nn.Dense(10, activation='relu'))
net.add(nn.Dense(20))
net.add(Identity())
+
+ Parameters
+ ----------
+ axis : int, default -1
+ The axis on which to concatenate the outputs.
"""
- def __init__(self, concat_dim, prefix=None, params=None):
+ def __init__(self, axis=-1, prefix=None, params=None):
super(HybridConcurrent, self).__init__(prefix=prefix, params=params)
- self.concat_dim = concat_dim
-
- def add(self, block):
- """Adds block on top of the stack."""
- self.register_child(block)
+ self.axis = axis
def hybrid_forward(self, F, x):
out = []
for block in self._children:
out.append(block(x))
- out = F.concat(*out, dim=self.concat_dim)
+ out = F.concat(*out, dim=self.axis)
return out
- def __repr__(self):
- s = '{name}(\n{modstr}\n)'
- modstr = '\n'.join([' ({key}): {block}'.format(key=key,
- block=_indent(block.__repr__(), 2))
- for key, block in enumerate(self._children)
- if isinstance(block, Block)])
- return s.format(name=self.__class__.__name__,
- modstr=modstr)
-
class Identity(HybridBlock):
"""Block that passes through the input directly.
- This layer is often used in conjunction with HybridConcurrent
+ This block can be used in conjunction with HybridConcurrent
block for residual connection.
Example::
diff --git a/python/mxnet/gluon/model_zoo/vision/densenet.py b/python/mxnet/gluon/model_zoo/vision/densenet.py
index 37a91e6..8353367 100644
--- a/python/mxnet/gluon/model_zoo/vision/densenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/densenet.py
@@ -25,7 +25,7 @@ import os
from ....context import cpu
from ...block import HybridBlock
from ... import nn
-from ..custom_layers import HybridConcurrent, Identity
+from ...contrib.nn import HybridConcurrent, Identity
# Helpers
def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index):
@@ -46,7 +46,7 @@ def _make_dense_layer(growth_rate, bn_size, dropout):
if dropout:
new_features.add(nn.Dropout(dropout))
- out = HybridConcurrent(concat_dim=1, prefix='')
+ out = HybridConcurrent(axis=1, prefix='')
out.add(Identity())
out.add(new_features)
diff --git a/python/mxnet/gluon/model_zoo/vision/inception.py b/python/mxnet/gluon/model_zoo/vision/inception.py
index 42f0d3d..6d75050 100644
--- a/python/mxnet/gluon/model_zoo/vision/inception.py
+++ b/python/mxnet/gluon/model_zoo/vision/inception.py
@@ -25,7 +25,7 @@ import os
from ....context import cpu
from ...block import HybridBlock
from ... import nn
-from ..custom_layers import HybridConcurrent
+from ...contrib.nn import HybridConcurrent
# Helpers
def _make_basic_conv(**kwargs):
@@ -51,7 +51,7 @@ def _make_branch(use_pool, *conv_settings):
return out
def _make_A(pool_features, prefix):
- out = HybridConcurrent(concat_dim=1, prefix=prefix)
+ out = HybridConcurrent(axis=1, prefix=prefix)
with out.name_scope():
out.add(_make_branch(None,
(64, 1, None, None)))
@@ -67,7 +67,7 @@ def _make_A(pool_features, prefix):
return out
def _make_B(prefix):
- out = HybridConcurrent(concat_dim=1, prefix=prefix)
+ out = HybridConcurrent(axis=1, prefix=prefix)
with out.name_scope():
out.add(_make_branch(None,
(384, 3, 2, None)))
@@ -79,7 +79,7 @@ def _make_B(prefix):
return out
def _make_C(channels_7x7, prefix):
- out = HybridConcurrent(concat_dim=1, prefix=prefix)
+ out = HybridConcurrent(axis=1, prefix=prefix)
with out.name_scope():
out.add(_make_branch(None,
(192, 1, None, None)))
@@ -98,7 +98,7 @@ def _make_C(channels_7x7, prefix):
return out
def _make_D(prefix):
- out = HybridConcurrent(concat_dim=1, prefix=prefix)
+ out = HybridConcurrent(axis=1, prefix=prefix)
with out.name_scope():
out.add(_make_branch(None,
(192, 1, None, None),
@@ -112,7 +112,7 @@ def _make_D(prefix):
return out
def _make_E(prefix):
- out = HybridConcurrent(concat_dim=1, prefix=prefix)
+ out = HybridConcurrent(axis=1, prefix=prefix)
with out.name_scope():
out.add(_make_branch(None,
(320, 1, None, None)))
@@ -121,7 +121,7 @@ def _make_E(prefix):
out.add(branch_3x3)
branch_3x3.add(_make_branch(None,
(384, 1, None, None)))
- branch_3x3_split = HybridConcurrent(concat_dim=1, prefix='')
+ branch_3x3_split = HybridConcurrent(axis=1, prefix='')
branch_3x3_split.add(_make_branch(None,
(384, (1, 3), None, (0, 1))))
branch_3x3_split.add(_make_branch(None,
@@ -133,7 +133,7 @@ def _make_E(prefix):
branch_3x3dbl.add(_make_branch(None,
(448, 1, None, None),
(384, 3, None, 1)))
- branch_3x3dbl_split = HybridConcurrent(concat_dim=1, prefix='')
+ branch_3x3dbl_split = HybridConcurrent(axis=1, prefix='')
branch_3x3dbl.add(branch_3x3dbl_split)
branch_3x3dbl_split.add(_make_branch(None,
(384, (1, 3), None, (0, 1))))
diff --git a/python/mxnet/gluon/model_zoo/vision/squeezenet.py b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
index 7eff102..09f62a5 100644
--- a/python/mxnet/gluon/model_zoo/vision/squeezenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
@@ -25,14 +25,14 @@ import os
from ....context import cpu
from ...block import HybridBlock
from ... import nn
-from ..custom_layers import HybridConcurrent
+from ...contrib.nn import HybridConcurrent
# Helpers
def _make_fire(squeeze_channels, expand1x1_channels, expand3x3_channels):
out = nn.HybridSequential(prefix='')
out.add(_make_fire_conv(squeeze_channels, 1))
- paths = HybridConcurrent(concat_dim=1, prefix='')
+ paths = HybridConcurrent(axis=1, prefix='')
paths.add(_make_fire_conv(expand1x1_channels, 1))
paths.add(_make_fire_conv(expand3x3_channels, 3, 1))
out.add(paths)
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 07b8956..1a188c3 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -18,6 +18,8 @@
from __future__ import print_function
import mxnet as mx
from mxnet.gluon import contrib
+from mxnet.gluon import nn
+from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity
from mxnet.test_utils import almost_equal
import numpy as np
from numpy.testing import assert_allclose
@@ -138,6 +140,39 @@ def test_vardrop():
check_vardrop(0.5, 0, 0.5)
+def test_concurrent():
+ model = HybridConcurrent(axis=1)
+ model.add(nn.Dense(128, activation='tanh', in_units=10))
+ model.add(nn.Dense(64, activation='tanh', in_units=10))
+ model.add(nn.Dense(32, in_units=10))
+ model2 = Concurrent(axis=1)
+ model2.add(nn.Dense(128, activation='tanh', in_units=10))
+ model2.add(nn.Dense(64, activation='tanh', in_units=10))
+ model2.add(nn.Dense(32, in_units=10))
+
+ # symbol
+ x = mx.sym.var('data')
+ y = model(x)
+ assert len(y.list_arguments()) == 7
+
+ # ndarray
+ model.initialize(mx.init.Xavier(magnitude=2.24))
+ model2.initialize(mx.init.Xavier(magnitude=2.24))
+ x = model(mx.nd.zeros((32, 10)))
+ x2 = model2(mx.nd.zeros((32, 10)))
+ assert x.shape == (32, 224)
+ assert x2.shape == (32, 224)
+ x.wait_to_read()
+ x2.wait_to_read()
+
+
+def test_identity():
+ model = Identity()
+ x = mx.nd.random.uniform(shape=(128, 33, 64))
+ mx.test_utils.assert_almost_equal(model(x).asnumpy(),
+ x.asnumpy())
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py
index 39d3b19..022f758 100644
--- a/tests/python/unittest/test_gluon_model_zoo.py
+++ b/tests/python/unittest/test_gluon_model_zoo.py
@@ -17,39 +17,12 @@
from __future__ import print_function
import mxnet as mx
-from mxnet.gluon import nn
-from mxnet.gluon.model_zoo.custom_layers import HybridConcurrent, Identity
from mxnet.gluon.model_zoo.vision import get_model
import sys
def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
-def test_concurrent():
- model = HybridConcurrent(concat_dim=1)
- model.add(nn.Dense(128, activation='tanh', in_units=10))
- model.add(nn.Dense(64, activation='tanh', in_units=10))
- model.add(nn.Dense(32, in_units=10))
-
- # symbol
- x = mx.sym.var('data')
- y = model(x)
- assert len(y.list_arguments()) == 7
-
- # ndarray
- model.collect_params().initialize(mx.init.Xavier(magnitude=2.24))
- x = model(mx.nd.zeros((32, 10)))
- assert x.shape == (32, 224)
- x.wait_to_read()
-
-
-def test_identity():
- model = Identity()
- x = mx.nd.random.uniform(shape=(128, 33, 64))
- mx.test_utils.assert_almost_equal(model(x).asnumpy(),
- x.asnumpy())
-
-
def test_models():
all_models = ['resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1',
'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2',
@@ -62,7 +35,7 @@ def test_models():
pretrained_to_test = set(['squeezenet1.1'])
for model_name in all_models:
- test_pretrain = model_name in pretrained_to_test
+ test_pretrain = model_name in pretrained_to_test
model = get_model(model_name, pretrained=test_pretrain, root='model/')
data_shape = (2, 3, 224, 224) if 'inception' not in model_name else (2, 3, 299, 299)
eprint('testing forward for %s'%model_name)
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].