You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2017/12/16 05:00:12 UTC
[incubator-mxnet] branch master updated: Modify Caffe example to use module interface instead of the deprecated model interface. (#9095)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 24c7231 Modify Caffe example to use module interface instead of the deprecated model interface. (#9095)
24c7231 is described below
commit 24c7231f57d967c4b18a9f37df18b1a7d53e45e5
Author: Indhu Bharathi <in...@gmail.com>
AuthorDate: Fri Dec 15 21:00:06 2017 -0800
Modify Caffe example to use module interface instead of the deprecated model interface. (#9095)
---
example/caffe/train_model.py | 24 ++++++++----------------
1 file changed, 8 insertions(+), 16 deletions(-)
diff --git a/example/caffe/train_model.py b/example/caffe/train_model.py
index 2eadd86..4290e71 100644
--- a/example/caffe/train_model.py
+++ b/example/caffe/train_model.py
@@ -85,15 +85,8 @@ def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
args.gpus is None or len(args.gpus.split(',')) is 1):
kv = None
- model = mx.model.FeedForward(
- ctx = devs,
- symbol = network,
- num_epoch = args.num_epochs,
- learning_rate = args.lr,
- momentum = 0.9,
- wd = 0.00001,
- initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
- **model_args)
+
+ mod = mx.mod.Module(network, context=devs)
if eval_metrics is None:
eval_metrics = ['accuracy']
@@ -108,10 +101,9 @@ def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
batch_end_callback = []
batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))
- model.fit(
- X = train,
- eval_data = val,
- eval_metric = eval_metrics,
- kvstore = kv,
- batch_end_callback = batch_end_callback,
- epoch_end_callback = checkpoint)
+ mod.fit(train_data=train, eval_metric=eval_metrics, eval_data=val, optimizer='sgd',
+ optimizer_params={'learning_rate':args.lr, 'momentum': 0.9, 'wd': 0.00001},
+ num_epoch=args.num_epochs, batch_end_callback=batch_end_callback,
+ initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
+ kvstore=kv, epoch_end_callback=checkpoint, **model_args)
+
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.