You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by lx...@apache.org on 2017/07/07 15:58:21 UTC

[08/50] [abbrv] incubator-mxnet-test git commit: Fixed broken SpatialTransformerNetwork example (#5798)

Fixed broken SpatialTransformerNetwork example (#5798)

* re-added option to include a SpatialTransformerNetwork layer in the lenet/mnist example, and a command-line argument in train_mnist to activate it

* Update lenet.py


Project: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/commit/17620fe5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/tree/17620fe5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/diff/17620fe5

Branch: refs/heads/master
Commit: 17620fe513742014d2cebd4594bc309105808567
Parents: ff96822
Author: Pepe Mandioca <fa...@users.noreply.github.com>
Authored: Tue Jun 27 14:51:32 2017 -0300
Committer: Mu Li <mu...@cs.cmu.edu>
Committed: Tue Jun 27 10:51:32 2017 -0700

----------------------------------------------------------------------
 example/image-classification/symbols/lenet.py | 18 +++++++++++++++++-
 example/image-classification/train_mnist.py   |  5 ++++-
 2 files changed, 21 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/17620fe5/example/image-classification/symbols/lenet.py
----------------------------------------------------------------------
diff --git a/example/image-classification/symbols/lenet.py b/example/image-classification/symbols/lenet.py
index f6cfd68..6df0299 100644
--- a/example/image-classification/symbols/lenet.py
+++ b/example/image-classification/symbols/lenet.py
@@ -5,9 +5,25 @@ Proceedings of the IEEE (1998)
 """
 import mxnet as mx
 
+def get_loc(data, attr={'lr_mult':'0.01'}):
+    """
+    the localisation network in lenet-stn, it will increase acc about more than 1%,
+    when num-epoch >=15
+    """
+    loc = mx.symbol.Convolution(data=data, num_filter=30, kernel=(5, 5), stride=(2,2))
+    loc = mx.symbol.Activation(data = loc, act_type='relu')
+    loc = mx.symbol.Pooling(data=loc, kernel=(2, 2), stride=(2, 2), pool_type='max')
+    loc = mx.symbol.Convolution(data=loc, num_filter=60, kernel=(3, 3), stride=(1,1), pad=(1, 1))
+    loc = mx.symbol.Activation(data = loc, act_type='relu')
+    loc = mx.symbol.Pooling(data=loc, global_pool=True, kernel=(2, 2), pool_type='avg')
+    loc = mx.symbol.Flatten(data=loc)
+    loc = mx.symbol.FullyConnected(data=loc, num_hidden=6, name="stn_loc", attr=attr)
+    return loc
+
+
 def get_symbol(num_classes=10, add_stn=False, **kwargs):
     data = mx.symbol.Variable('data')
-    if(add_stn):
+    if add_stn:
         data = mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),
                                          transform_type="affine", sampler_type="bilinear")
     # first conv

http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/17620fe5/example/image-classification/train_mnist.py
----------------------------------------------------------------------
diff --git a/example/image-classification/train_mnist.py b/example/image-classification/train_mnist.py
index 61162e6..31ecbfb 100644
--- a/example/image-classification/train_mnist.py
+++ b/example/image-classification/train_mnist.py
@@ -53,6 +53,9 @@ if __name__ == '__main__':
                         help='the number of classes')
     parser.add_argument('--num-examples', type=int, default=60000,
                         help='the number of training examples')
+    
+    parser.add_argument('--add_stn',  action="store_true", default=False, help='Add Spatial Transformer Network Layer (lenet only)')
+    
     fit.add_fit_args(parser)
     parser.set_defaults(
         # network
@@ -63,7 +66,7 @@ if __name__ == '__main__':
         disp_batches   = 100,
         num_epochs     = 20,
         lr             = .05,
-        lr_step_epochs = '10',
+        lr_step_epochs = '10'
     )
     args = parser.parse_args()