Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/12 22:05:46 UTC

[GitHub] piiswrong closed pull request #8070: modify shell file of classification for gluon

URL: https://github.com/apache/incubator-mxnet/pull/8070
 
This is a pull request merged from a forked repository. As GitHub hides
the original diff of a foreign (fork) pull request once it is merged, the
diff is reproduced below for the sake of provenance:

diff --git a/example/gluon/image_classification.py b/example/gluon/image_classification.py
index 8481afb50c..34915bd2b9 100644
--- a/example/gluon/image_classification.py
+++ b/example/gluon/image_classification.py
@@ -64,6 +64,18 @@
 parser.add_argument('--kvstore', type=str, default='device',
                     help='kvstore to use for trainer/module.')
 parser.add_argument('--log-interval', type=int, default=50, help='Number of batches to wait before logging.')
+parser.add_argument('--lr-factor', type=float, default=0.1,
+                    help='the ratio by which to reduce the lr at each step')
+parser.add_argument('--lr-step-epochs', type=str, default='30,60',
+                    help='the epochs at which to reduce the lr, e.g. 30,60')
+parser.add_argument('--num-examples', type=int, default=1281167,
+                    help='the number of training examples')
+parser.add_argument('--load-epoch', type=int,
+                    help='load the model saved at the given epoch under --model-prefix')
+parser.add_argument('--model-prefix', type=str, default='model/',
+                    help='model prefix')
+parser.add_argument('--optimizer', type=str, default='sgd',
+                    help='the optimizer type')
 opt = parser.parse_args()
 
 logging.info(opt)
@@ -107,6 +119,45 @@
     else:
         train_data, val_data = dummy_iterator(batch_size, (3, 224, 224))
 
+kv = mx.kvstore.create(opt.kvstore)
+def _get_lr_scheduler():
+    if 'lr_factor' not in opt or opt.lr_factor >= 1:
+        return (opt.lr, None)
+    epoch_size = int(opt.num_examples / opt.batch_size / opt.num_gpus)
+    if 'dist' in opt.kvstore:
+        epoch_size /= kv.num_workers
+    begin_epoch = opt.load_epoch if opt.load_epoch else 0
+    step_epochs = [int(l) for l in opt.lr_step_epochs.split(',')]
+    lr = opt.lr
+    for s in step_epochs:
+        if begin_epoch >= s:
+            lr *= opt.lr_factor
+    if lr != opt.lr:
+        logging.info('Adjust learning rate to %e for epoch %d' % (lr, begin_epoch))
+    steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
+    return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=opt.lr_factor))
+
+def _load_model(rank=0):
+    if 'load_epoch' not in opt or opt.load_epoch is None:
+        return (None, None, None)
+    assert opt.model_prefix is not None
+    model_prefix = opt.model_prefix
+    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
+        model_prefix += "-%d" % (rank)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(
+        model_prefix, opt.load_epoch)
+    logging.info('Loaded model %s_%04d.params', model_prefix, opt.load_epoch)
+    return (sym, arg_params, aux_params)
+
+def _save_model(rank=0):
+    if opt.model_prefix is None:
+        return None
+    dst_dir = os.path.dirname(opt.model_prefix)
+    if dst_dir and not os.path.isdir(dst_dir):
+        os.mkdir(dst_dir)
+    return mx.callback.do_checkpoint(opt.model_prefix if rank == 0 else "%s-%d" % (
+        opt.model_prefix, rank))
+
 def test(ctx):
     metric = mx.metric.Accuracy()
     val_data.reset()
@@ -172,16 +223,47 @@ def train(epochs, ctx):
         out = net(data)
         softmax = mx.sym.SoftmaxOutput(out, name='softmax')
         mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()])
+
+        # load model
+        if 'arg_params' in kwargs and 'aux_params' in kwargs:
+            arg_params = kwargs['arg_params']
+            aux_params = kwargs['aux_params']
+        else:
+            sym, arg_params, aux_params = _load_model(kv.rank)
+            if sym is not None:
+                assert sym.tojson() == softmax.tojson()
+
+        # save model
+        checkpoint = _save_model(kv.rank)
+
+        # learning rate
+        lr, lr_scheduler = _get_lr_scheduler()
+        optimizer_params = {
+            'learning_rate': lr,
+            'wd': opt.wd,
+            'lr_scheduler': lr_scheduler}
+        # Add 'multi_precision' parameter only for SGD optimizer
+        if opt.optimizer == 'sgd':
+            optimizer_params['multi_precision'] = True
+
+        # Only a limited number of optimizers have 'momentum' property
+        has_momentum = {'sgd', 'dcasgd', 'nag'}
+        if opt.optimizer in has_momentum:
+            optimizer_params['momentum'] = opt.momentum
+
         mod.fit(train_data,
+                begin_epoch=opt.load_epoch if opt.load_epoch else 0,
                 eval_data = val_data,
                 num_epoch=opt.epochs,
-                kvstore=opt.kvstore,
-                batch_end_callback = mx.callback.Speedometer(batch_size, max(1, opt.log_interval)),
-                epoch_end_callback = mx.callback.do_checkpoint('image-classifier-%s'% opt.model),
+                kvstore=kv,
+                batch_end_callback = [mx.callback.Speedometer(batch_size, max(1, opt.log_interval))],
+                epoch_end_callback = checkpoint,
                 optimizer = 'sgd',
-                optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum},
+                arg_params = arg_params,
+                aux_params = aux_params,
+                optimizer_params = optimizer_params,
                 initializer = mx.init.Xavier(magnitude=2))
-        mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, epochs))
+        mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
     else:
         if opt.mode == 'hybrid':
             net.hybridize()
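
For anyone tracing the _get_lr_scheduler logic above: here is a minimal
standalone sketch of how the epoch boundaries in --lr-step-epochs become
batch-level steps for mx.lr_scheduler.MultiFactorScheduler. The batch size
and GPU count below are illustrative assumptions, not values from the pull
request:

    # Sketch of the step computation in _get_lr_scheduler.
    # batch_size and num_gpus are assumed values for illustration.
    num_examples = 1281167    # the new --num-examples default (ImageNet-1k train set)
    batch_size = 128          # assumed
    num_gpus = 4              # assumed
    begin_epoch = 0           # no --load-epoch given

    epoch_size = int(num_examples / batch_size / num_gpus)  # batches per epoch
    step_epochs = [int(e) for e in '30,60'.split(',')]      # --lr-step-epochs default
    steps = [epoch_size * (e - begin_epoch) for e in step_epochs if e > begin_epoch]
    print(steps)  # [75060, 150120]: lr is multiplied by --lr-factor at these batches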


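On the checkpointing side: _save_model wraps mx.callback.do_checkpoint,
which writes <prefix>-symbol.json plus <prefix>-NNNN.params (epoch number
zero-padded to four digits) at the end of every epoch, and _load_model
reads them back through mx.model.load_checkpoint. A hypothetical resume
run, with an assumed prefix of model/resnet and assuming --mode symbolic
selects the Module code path shown above, would look roughly like:

    # Hypothetical resume after epoch 30; all flag values are assumptions.
    # Expects model/resnet-symbol.json and model/resnet-0030.params on disk.
    python image_classification.py --mode symbolic \
        --model-prefix model/resnet --load-epoch 30 \
        --lr-factor 0.1 --lr-step-epochs 30,60

Since _get_lr_scheduler uses --load-epoch as begin_epoch, this resume
starts with the lr already multiplied by --lr-factor once (the logged
"Adjust learning rate" message) and schedules only the remaining step at
epoch 60.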
 
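Worth noting when reading the fit call above: the PR adds an --optimizer
flag plus multi_precision/momentum handling keyed off it, but fit still
receives optimizer = 'sgd' on an unchanged context line, so the flag only
steers how optimizer_params is assembled. A full training invocation with
the new flags might look like the sketch below (model name, batch size,
and epoch count are illustrative assumptions; --model, --lr, and
--kvstore already existed in the script):

    # Hypothetical end-to-end invocation; values are illustrative assumptions.
    python example/gluon/image_classification.py \
        --model resnet50_v1 --mode symbolic \
        --batch-size 128 --epochs 90 --lr 0.1 \
        --lr-factor 0.1 --lr-step-epochs 30,60 \
        --num-examples 1281167 --model-prefix model/resnet \
        --optimizer sgd --kvstore device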

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services