You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by lx...@apache.org on 2017/07/07 15:58:21 UTC
[08/50] [abbrv] incubator-mxnet-test git commit: Fixed broken
SpatialTransformerNetwork example (#5798)
Fixed broken SpatialTransformerNetwork example (#5798)
* Re-added the option to include a SpatialTransformerNetwork layer in the LeNet/MNIST example, and a command-line argument in train_mnist to activate it
* Update lenet.py
Project: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/commit/17620fe5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/tree/17620fe5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/diff/17620fe5
Branch: refs/heads/master
Commit: 17620fe513742014d2cebd4594bc309105808567
Parents: ff96822
Author: Pepe Mandioca <fa...@users.noreply.github.com>
Authored: Tue Jun 27 14:51:32 2017 -0300
Committer: Mu Li <mu...@cs.cmu.edu>
Committed: Tue Jun 27 10:51:32 2017 -0700
----------------------------------------------------------------------
example/image-classification/symbols/lenet.py | 18 +++++++++++++++++-
example/image-classification/train_mnist.py | 5 ++++-
2 files changed, 21 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/17620fe5/example/image-classification/symbols/lenet.py
----------------------------------------------------------------------
diff --git a/example/image-classification/symbols/lenet.py b/example/image-classification/symbols/lenet.py
index f6cfd68..6df0299 100644
--- a/example/image-classification/symbols/lenet.py
+++ b/example/image-classification/symbols/lenet.py
@@ -5,9 +5,25 @@ Proceedings of the IEEE (1998)
"""
import mxnet as mx
def get_loc(data, attr=None):
    """Build the localisation network used by the LeNet + Spatial Transformer variant.

    Layer stack: conv(30 filters, 5x5, stride 2) -> relu -> max-pool(2x2,
    stride 2) -> conv(60 filters, 3x3, stride 1, pad 1) -> relu ->
    global average pool -> flatten -> fully-connected(6, name "stn_loc").
    The 6 outputs are what ``get_symbol`` feeds to ``SpatialTransformer`` as
    the ``loc`` input with ``transform_type="affine"``.

    The original author reports that adding this network improves accuracy
    by more than 1% when training for 15 or more epochs.

    Parameters
    ----------
    data : mx.symbol.Symbol
        Input symbol to be localised.
    attr : dict or None, optional
        Attributes attached to the final ``stn_loc`` FullyConnected layer.
        When omitted (or ``None``), ``{'lr_mult': '0.01'}`` is used, i.e. a
        learning-rate multiplier of 0.01 for that layer. A new dict is built
        on every call so the default is never shared between calls.

    Returns
    -------
    mx.symbol.Symbol
        The output of the ``stn_loc`` FullyConnected layer (num_hidden=6).
    """
    # Build the default per call instead of using a dict literal as the
    # default argument, which would be one object shared by every call.
    if attr is None:
        attr = {'lr_mult': '0.01'}
    loc = mx.symbol.Convolution(data=data, num_filter=30, kernel=(5, 5), stride=(2, 2))
    loc = mx.symbol.Activation(data=loc, act_type='relu')
    loc = mx.symbol.Pooling(data=loc, kernel=(2, 2), stride=(2, 2), pool_type='max')
    loc = mx.symbol.Convolution(data=loc, num_filter=60, kernel=(3, 3), stride=(1, 1), pad=(1, 1))
    loc = mx.symbol.Activation(data=loc, act_type='relu')
    # With global_pool=True, MXNet pools over the whole feature map and
    # ignores `kernel`; it is passed only to satisfy the operator signature.
    loc = mx.symbol.Pooling(data=loc, global_pool=True, kernel=(2, 2), pool_type='avg')
    loc = mx.symbol.Flatten(data=loc)
    loc = mx.symbol.FullyConnected(data=loc, num_hidden=6, name="stn_loc", attr=attr)
    return loc
+
+
def get_symbol(num_classes=10, add_stn=False, **kwargs):
data = mx.symbol.Variable('data')
- if(add_stn):
+ if add_stn:
data = mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),
transform_type="affine", sampler_type="bilinear")
# first conv
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/17620fe5/example/image-classification/train_mnist.py
----------------------------------------------------------------------
diff --git a/example/image-classification/train_mnist.py b/example/image-classification/train_mnist.py
index 61162e6..31ecbfb 100644
--- a/example/image-classification/train_mnist.py
+++ b/example/image-classification/train_mnist.py
@@ -53,6 +53,9 @@ if __name__ == '__main__':
help='the number of classes')
parser.add_argument('--num-examples', type=int, default=60000,
help='the number of training examples')
+
+ parser.add_argument('--add_stn', action="store_true", default=False, help='Add Spatial Transformer Network Layer (lenet only)')
+
fit.add_fit_args(parser)
parser.set_defaults(
# network
@@ -63,7 +66,7 @@ if __name__ == '__main__':
disp_batches = 100,
num_epochs = 20,
lr = .05,
- lr_step_epochs = '10',
+ lr_step_epochs = '10'
)
args = parser.parse_args()