You are viewing a plain-text version of this content; the canonical link to the original was not preserved in this copy.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2017/12/16 05:00:12 UTC

[incubator-mxnet] branch master updated: Modify Caffe example to use module interface instead of the deprecated model interface. (#9095)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 24c7231  Modify Caffe example to use module interface instead of the deprecated model interface. (#9095)
24c7231 is described below

commit 24c7231f57d967c4b18a9f37df18b1a7d53e45e5
Author: Indhu Bharathi <in...@gmail.com>
AuthorDate: Fri Dec 15 21:00:06 2017 -0800

    Modify Caffe example to use module interface instead of the deprecated model interface. (#9095)
---
 example/caffe/train_model.py | 24 ++++++++----------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/example/caffe/train_model.py b/example/caffe/train_model.py
index 2eadd86..4290e71 100644
--- a/example/caffe/train_model.py
+++ b/example/caffe/train_model.py
@@ -85,15 +85,8 @@ def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
             args.gpus is None or len(args.gpus.split(',')) is 1):
         kv = None
 
-    model = mx.model.FeedForward(
-        ctx                = devs,
-        symbol             = network,
-        num_epoch          = args.num_epochs,
-        learning_rate      = args.lr,
-        momentum           = 0.9,
-        wd                 = 0.00001,
-        initializer        = mx.init.Xavier(factor_type="in", magnitude=2.34),
-        **model_args)
+
+    mod = mx.mod.Module(network, context=devs)
 
     if eval_metrics is None:
         eval_metrics = ['accuracy']
@@ -108,10 +101,9 @@ def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
         batch_end_callback = []
     batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))
 
-    model.fit(
-       X                  = train,
-       eval_data          = val,
-       eval_metric        = eval_metrics,
-       kvstore            = kv,
-       batch_end_callback = batch_end_callback,
-       epoch_end_callback = checkpoint)
+    mod.fit(train_data=train, eval_metric=eval_metrics, eval_data=val, optimizer='sgd',
+        optimizer_params={'learning_rate':args.lr, 'momentum': 0.9, 'wd': 0.00001},
+        num_epoch=args.num_epochs, batch_end_callback=batch_end_callback,
+        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
+        kvstore=kv, epoch_end_callback=checkpoint, **model_args)
+

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.