You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/21 21:19:37 UTC

[GitHub] hetong007 closed pull request #12101: Add gamma initialization and se module for gluon resnet model

hetong007 closed pull request #12101: Add gamma initialization and se module for gluon resnet model
URL: https://github.com/apache/incubator-mxnet/pull/12101
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/model_zoo/vision/__init__.py b/python/mxnet/gluon/model_zoo/vision/__init__.py
index 7d33ce409b2..2ae53f8b76a 100644
--- a/python/mxnet/gluon/model_zoo/vision/__init__.py
+++ b/python/mxnet/gluon/model_zoo/vision/__init__.py
@@ -119,6 +119,16 @@ def get_model(name, **kwargs):
               'resnet50_v2': resnet50_v2,
               'resnet101_v2': resnet101_v2,
               'resnet152_v2': resnet152_v2,
+              'se_resnet18_v1': se_resnet18_v1,
+              'se_resnet34_v1': se_resnet34_v1,
+              'se_resnet50_v1': se_resnet50_v1,
+              'se_resnet101_v1': se_resnet101_v1,
+              'se_resnet152_v1': se_resnet152_v1,
+              'se_resnet18_v2': se_resnet18_v2,
+              'se_resnet34_v2': se_resnet34_v2,
+              'se_resnet50_v2': se_resnet50_v2,
+              'se_resnet101_v2': se_resnet101_v2,
+              'se_resnet152_v2': se_resnet152_v2,
               'vgg11': vgg11,
               'vgg13': vgg13,
               'vgg16': vgg16,
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py
index 48390decb11..47d11011298 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -25,6 +25,10 @@
            'BottleneckV1', 'BottleneckV2',
            'resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1',
            'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2',
+           'se_resnet18_v1', 'se_resnet34_v1', 'se_resnet50_v1',
+           'se_resnet101_v1', 'se_resnet152_v1',
+           'se_resnet18_v2', 'se_resnet34_v2', 'se_resnet50_v2',
+           'se_resnet101_v2', 'se_resnet152_v2',
            'get_resnet']
 
 import os
@@ -56,15 +60,33 @@ class BasicBlockV1(HybridBlock):
         Whether to downsample the input.
     in_channels : int, default 0
         Number of input channels. Default is 0, to infer from the graph.
+    last_gamma : bool, default False
+        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
+    use_se : bool, default False
+        Whether to use Squeeze-and-Excitation module
     """
-    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
+    def __init__(self, channels, stride, downsample=False, in_channels=0,
+                 last_gamma=False, use_se=False, **kwargs):
         super(BasicBlockV1, self).__init__(**kwargs)
         self.body = nn.HybridSequential(prefix='')
         self.body.add(_conv3x3(channels, stride, in_channels))
         self.body.add(nn.BatchNorm())
         self.body.add(nn.Activation('relu'))
         self.body.add(_conv3x3(channels, 1, channels))
-        self.body.add(nn.BatchNorm())
+        if not last_gamma:
+            self.body.add(nn.BatchNorm())
+        else:
+            self.body.add(nn.BatchNorm(gamma_initializer='zeros'))
+
+        if use_se:
+            self.se = nn.HybridSequential(prefix='')
+            self.se.add(nn.Dense(channels // 4, use_bias=False))
+            self.se.add(nn.Activation('relu'))
+            self.se.add(nn.Dense(channels * 4, use_bias=False))
+            self.se.add(nn.Activation('sigmoid'))
+        else:
+            self.se = None
+
         if downsample:
             self.downsample = nn.HybridSequential(prefix='')
             self.downsample.add(nn.Conv2D(channels, kernel_size=1, strides=stride,
@@ -78,6 +100,11 @@ def hybrid_forward(self, F, x):
 
         x = self.body(x)
 
+        if self.se:
+            w = F.contrib.AdaptiveAvgPooling2D(x, output_size=1)
+            w = self.se(w)
+            x = F.broadcast_mul(x, w.expand_dims(axis=2).expand_dims(axis=2))
+
         if self.downsample:
             residual = self.downsample(residual)
 
@@ -101,8 +128,13 @@ class BottleneckV1(HybridBlock):
         Whether to downsample the input.
     in_channels : int, default 0
         Number of input channels. Default is 0, to infer from the graph.
+    last_gamma : bool, default False
+        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
+    use_se : bool, default False
+        Whether to use Squeeze-and-Excitation module
     """
-    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
+    def __init__(self, channels, stride, downsample=False, in_channels=0,
+                 last_gamma=False, use_se=False, **kwargs):
         super(BottleneckV1, self).__init__(**kwargs)
         self.body = nn.HybridSequential(prefix='')
         self.body.add(nn.Conv2D(channels//4, kernel_size=1, strides=stride))
@@ -112,7 +144,21 @@ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
         self.body.add(nn.BatchNorm())
         self.body.add(nn.Activation('relu'))
         self.body.add(nn.Conv2D(channels, kernel_size=1, strides=1))
-        self.body.add(nn.BatchNorm())
+
+        if use_se:
+            self.se = nn.HybridSequential(prefix='')
+            self.se.add(nn.Dense(channels // 4, use_bias=False))
+            self.se.add(nn.Activation('relu'))
+            self.se.add(nn.Dense(channels * 4, use_bias=False))
+            self.se.add(nn.Activation('sigmoid'))
+        else:
+            self.se = None
+
+        if not last_gamma:
+            self.body.add(nn.BatchNorm())
+        else:
+            self.body.add(nn.BatchNorm(gamma_initializer='zeros'))
+
         if downsample:
             self.downsample = nn.HybridSequential(prefix='')
             self.downsample.add(nn.Conv2D(channels, kernel_size=1, strides=stride,
@@ -126,6 +172,11 @@ def hybrid_forward(self, F, x):
 
         x = self.body(x)
 
+        if self.se:
+            w = F.contrib.AdaptiveAvgPooling2D(x, output_size=1)
+            w = self.se(w)
+            x = F.broadcast_mul(x, w.expand_dims(axis=2).expand_dims(axis=2))
+
         if self.downsample:
             residual = self.downsample(residual)
 
@@ -149,13 +200,31 @@ class BasicBlockV2(HybridBlock):
         Whether to downsample the input.
     in_channels : int, default 0
         Number of input channels. Default is 0, to infer from the graph.
+    last_gamma : bool, default False
+        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
+    use_se : bool, default False
+        Whether to use Squeeze-and-Excitation module
     """
-    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
+    def __init__(self, channels, stride, downsample=False, in_channels=0,
+                 last_gamma=False, use_se=False, **kwargs):
         super(BasicBlockV2, self).__init__(**kwargs)
         self.bn1 = nn.BatchNorm()
         self.conv1 = _conv3x3(channels, stride, in_channels)
-        self.bn2 = nn.BatchNorm()
+        if not last_gamma:
+            self.bn2 = nn.BatchNorm()
+        else:
+            self.bn2 = nn.BatchNorm(gamma_initializer='zeros')
         self.conv2 = _conv3x3(channels, 1, channels)
+
+        if use_se:
+            self.se = nn.HybridSequential(prefix='')
+            self.se.add(nn.Dense(channels // 4, use_bias=False))
+            self.se.add(nn.Activation('relu'))
+            self.se.add(nn.Dense(channels * 4, use_bias=False))
+            self.se.add(nn.Activation('sigmoid'))
+        else:
+            self.se = None
+
         if downsample:
             self.downsample = nn.Conv2D(channels, 1, stride, use_bias=False,
                                         in_channels=in_channels)
@@ -174,6 +243,11 @@ def hybrid_forward(self, F, x):
         x = F.Activation(x, act_type='relu')
         x = self.conv2(x)
 
+        if self.se:
+            w = F.contrib.AdaptiveAvgPooling2D(x, output_size=1)
+            w = self.se(w)
+            x = F.broadcast_mul(x, w.expand_dims(axis=2).expand_dims(axis=2))
+
         return x + residual
 
 
@@ -193,15 +267,33 @@ class BottleneckV2(HybridBlock):
         Whether to downsample the input.
     in_channels : int, default 0
         Number of input channels. Default is 0, to infer from the graph.
+    last_gamma : bool, default False
+        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
+    use_se : bool, default False
+        Whether to use Squeeze-and-Excitation module
     """
-    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
+    def __init__(self, channels, stride, downsample=False, in_channels=0,
+                 last_gamma=False, use_se=False, **kwargs):
         super(BottleneckV2, self).__init__(**kwargs)
         self.bn1 = nn.BatchNorm()
         self.conv1 = nn.Conv2D(channels//4, kernel_size=1, strides=1, use_bias=False)
         self.bn2 = nn.BatchNorm()
         self.conv2 = _conv3x3(channels//4, stride, channels//4)
-        self.bn3 = nn.BatchNorm()
+        if not last_gamma:
+            self.bn3 = nn.BatchNorm()
+        else:
+            self.bn3 = nn.BatchNorm(gamma_initializer='zeros')
         self.conv3 = nn.Conv2D(channels, kernel_size=1, strides=1, use_bias=False)
+
+        if use_se:
+            self.se = nn.HybridSequential(prefix='')
+            self.se.add(nn.Dense(channels // 4, use_bias=False))
+            self.se.add(nn.Activation('relu'))
+            self.se.add(nn.Dense(channels * 4, use_bias=False))
+            self.se.add(nn.Activation('sigmoid'))
+        else:
+            self.se = None
+
         if downsample:
             self.downsample = nn.Conv2D(channels, 1, stride, use_bias=False,
                                         in_channels=in_channels)
@@ -224,6 +316,11 @@ def hybrid_forward(self, F, x):
         x = F.Activation(x, act_type='relu')
         x = self.conv3(x)
 
+        if self.se:
+            w = F.contrib.AdaptiveAvgPooling2D(x, output_size=1)
+            w = self.se(w)
+            x = F.broadcast_mul(x, w.expand_dims(axis=2).expand_dims(axis=2))
+
         return x + residual
 
 
@@ -245,8 +342,13 @@ class ResNetV1(HybridBlock):
         Number of classification classes.
     thumbnail : bool, default False
         Enable thumbnail.
+    last_gamma : bool, default False
+        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
+    use_se : bool, default False
+        Whether to use Squeeze-and-Excitation module
     """
-    def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwargs):
+    def __init__(self, block, layers, channels, classes=1000, thumbnail=False,
+                 last_gamma=False, use_se=False, **kwargs):
         super(ResNetV1, self).__init__(**kwargs)
         assert len(layers) == len(channels) - 1
         with self.name_scope():
@@ -262,18 +364,21 @@ def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwa
             for i, num_layer in enumerate(layers):
                 stride = 1 if i == 0 else 2
                 self.features.add(self._make_layer(block, num_layer, channels[i+1],
-                                                   stride, i+1, in_channels=channels[i]))
+                                                   stride, i+1, in_channels=channels[i],
+                                                   last_gamma=last_gamma, use_se=use_se))
             self.features.add(nn.GlobalAvgPool2D())
 
             self.output = nn.Dense(classes, in_units=channels[-1])
 
-    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0):
+    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
+                    last_gamma=False, use_se=False):
         layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
         with layer.name_scope():
             layer.add(block(channels, stride, channels != in_channels, in_channels=in_channels,
-                            prefix=''))
+                            last_gamma=last_gamma, use_se=use_se, prefix=''))
             for _ in range(layers-1):
-                layer.add(block(channels, 1, False, in_channels=channels, prefix=''))
+                layer.add(block(channels, 1, False, in_channels=channels,
+                                last_gamma=last_gamma, use_se=use_se, prefix=''))
         return layer
 
     def hybrid_forward(self, F, x):
@@ -300,8 +405,13 @@ class ResNetV2(HybridBlock):
         Number of classification classes.
     thumbnail : bool, default False
         Enable thumbnail.
+    last_gamma : bool, default False
+        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
+    use_se : bool, default False
+        Whether to use Squeeze-and-Excitation module
     """
-    def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwargs):
+    def __init__(self, block, layers, channels, classes=1000, thumbnail=False,
+                 last_gamma=False, use_se=False, **kwargs):
         super(ResNetV2, self).__init__(**kwargs)
         assert len(layers) == len(channels) - 1
         with self.name_scope():
@@ -319,7 +429,8 @@ def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwa
             for i, num_layer in enumerate(layers):
                 stride = 1 if i == 0 else 2
                 self.features.add(self._make_layer(block, num_layer, channels[i+1],
-                                                   stride, i+1, in_channels=in_channels))
+                                                   stride, i+1, in_channels=in_channels,
+                                                   last_gamma=last_gamma, use_se=use_se))
                 in_channels = channels[i+1]
             self.features.add(nn.BatchNorm())
             self.features.add(nn.Activation('relu'))
@@ -328,13 +439,15 @@ def __init__(self, block, layers, channels, classes=1000, thumbnail=False, **kwa
 
             self.output = nn.Dense(classes, in_units=in_channels)
 
-    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0):
+    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
+                    last_gamma=False, use_se=False):
         layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
         with layer.name_scope():
             layer.add(block(channels, stride, channels != in_channels, in_channels=in_channels,
-                            prefix=''))
+                            last_gamma=last_gamma, use_se=use_se, prefix=''))
             for _ in range(layers-1):
-                layer.add(block(channels, 1, False, in_channels=channels, prefix=''))
+                layer.add(block(channels, 1, False, in_channels=channels,
+                                last_gamma=last_gamma, use_se=use_se, prefix=''))
         return layer
 
     def hybrid_forward(self, F, x):
@@ -357,7 +470,7 @@ def hybrid_forward(self, F, x):
 
 # Constructor
 def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
-               root=os.path.join(base.data_dir(), 'models'), **kwargs):
+               root=os.path.join(base.data_dir(), 'models'), use_se=False, **kwargs):
     r"""ResNet V1 model from `"Deep Residual Learning for Image Recognition"
     <http://arxiv.org/abs/1512.03385>`_ paper.
     ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -375,6 +488,8 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
         The context in which to load the pretrained weights.
     root : str, default $MXNET_HOME/models
         Location for keeping the model parameters.
+    use_se : bool, default False
+        Whether to use Squeeze-and-Excitation module
     """
     assert num_layers in resnet_spec, \
         "Invalid number of layers: %d. Options are %s"%(
@@ -387,8 +502,12 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
     net = resnet_class(block_class, layers, channels, **kwargs)
     if pretrained:
         from ..model_store import get_model_file
-        net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version),
-                                           root=root), ctx=ctx)
+        if not use_se:
+            net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version),
+                                               root=root), ctx=ctx)
+        else:
+            net.load_parameters(get_model_file('se_resnet%d_v%d'%(num_layers, version),
+                                               root=root), ctx=ctx)
     return net
 
 def resnet18_v1(**kwargs):
@@ -404,7 +523,7 @@ def resnet18_v1(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(1, 18, **kwargs)
+    return get_resnet(1, 18, use_se=False, **kwargs)
 
 def resnet34_v1(**kwargs):
     r"""ResNet-34 V1 model from `"Deep Residual Learning for Image Recognition"
@@ -419,7 +538,7 @@ def resnet34_v1(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(1, 34, **kwargs)
+    return get_resnet(1, 34, use_se=False, **kwargs)
 
 def resnet50_v1(**kwargs):
     r"""ResNet-50 V1 model from `"Deep Residual Learning for Image Recognition"
@@ -434,7 +553,7 @@ def resnet50_v1(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(1, 50, **kwargs)
+    return get_resnet(1, 50, use_se=False, **kwargs)
 
 def resnet101_v1(**kwargs):
     r"""ResNet-101 V1 model from `"Deep Residual Learning for Image Recognition"
@@ -449,7 +568,7 @@ def resnet101_v1(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(1, 101, **kwargs)
+    return get_resnet(1, 101, use_se=False, **kwargs)
 
 def resnet152_v1(**kwargs):
     r"""ResNet-152 V1 model from `"Deep Residual Learning for Image Recognition"
@@ -464,7 +583,7 @@ def resnet152_v1(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(1, 152, **kwargs)
+    return get_resnet(1, 152, use_se=False, **kwargs)
 
 def resnet18_v2(**kwargs):
     r"""ResNet-18 V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -479,7 +598,7 @@ def resnet18_v2(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(2, 18, **kwargs)
+    return get_resnet(2, 18, use_se=False, **kwargs)
 
 def resnet34_v2(**kwargs):
     r"""ResNet-34 V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -494,7 +613,7 @@ def resnet34_v2(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(2, 34, **kwargs)
+    return get_resnet(2, 34, use_se=False, **kwargs)
 
 def resnet50_v2(**kwargs):
     r"""ResNet-50 V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -509,7 +628,7 @@ def resnet50_v2(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(2, 50, **kwargs)
+    return get_resnet(2, 50, use_se=False, **kwargs)
 
 def resnet101_v2(**kwargs):
     r"""ResNet-101 V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -524,7 +643,7 @@ def resnet101_v2(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(2, 101, **kwargs)
+    return get_resnet(2, 101, use_se=False, **kwargs)
 
 def resnet152_v2(**kwargs):
     r"""ResNet-152 V2 model from `"Identity Mappings in Deep Residual Networks"
@@ -539,4 +658,155 @@ def resnet152_v2(**kwargs):
     root : str, default '$MXNET_HOME/models'
         Location for keeping the model parameters.
     """
-    return get_resnet(2, 152, **kwargs)
+    return get_resnet(2, 152, use_se=False, **kwargs)
+
+# SE-ResNet
+def se_resnet18_v1(**kwargs):
+    r"""SE-ResNet-18 V1 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(1, 18, use_se=True, **kwargs)
+
+def se_resnet34_v1(**kwargs):
+    r"""SE-ResNet-34 V1 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(1, 34, use_se=True, **kwargs)
+
+def se_resnet50_v1(**kwargs):
+    r"""SE-ResNet-50 V1 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(1, 50, use_se=True, **kwargs)
+
+def se_resnet101_v1(**kwargs):
+    r"""SE-ResNet-101 V1 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(1, 101, use_se=True, **kwargs)
+
+def se_resnet152_v1(**kwargs):
+    r"""SE-ResNet-152 V1 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(1, 152, use_se=True, **kwargs)
+
+def se_resnet18_v2(**kwargs):
+    r"""SE-ResNet-18 V2 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(2, 18, use_se=True, **kwargs)
+
+def se_resnet34_v2(**kwargs):
+    r"""SE-ResNet-34 V2 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(2, 34, use_se=True, **kwargs)
+
+def se_resnet50_v2(**kwargs):
+    r"""SE-ResNet-50 V2 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(2, 50, use_se=True, **kwargs)
+
+def se_resnet101_v2(**kwargs):
+    r"""SE-ResNet-101 V2 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(2, 101, use_se=True, **kwargs)
+
+def se_resnet152_v2(**kwargs):
+    r"""SE-ResNet-152 V2 model from `"Squeeze-and-Excitation Networks"
+    <https://arxiv.org/abs/1709.01507>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool, default False
+        Whether to load the pretrained weights for model.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default '$MXNET_HOME/models'
+        Location for keeping the model parameters.
+    """
+    return get_resnet(2, 152, use_se=True, **kwargs)
diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py
index a64668451a2..668f274846e 100644
--- a/tests/python/unittest/test_gluon_model_zoo.py
+++ b/tests/python/unittest/test_gluon_model_zoo.py
@@ -30,6 +30,10 @@ def eprint(*args, **kwargs):
 def test_models():
     all_models = ['resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1',
                   'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2',
+                  'se_resnet18_v1', 'se_resnet34_v1', 'se_resnet50_v1',
+                  'se_resnet101_v1', 'se_resnet152_v1',
+                  'se_resnet18_v2', 'se_resnet34_v2', 'se_resnet50_v2',
+                  'se_resnet101_v2', 'se_resnet152_v2',
                   'vgg11', 'vgg13', 'vgg16', 'vgg19',
                   'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn',
                   'alexnet', 'inceptionv3',


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services