Posted to commits@singa.apache.org by wa...@apache.org on 2018/11/20 09:16:33 UTC
[3/3] incubator-singa git commit: use tqdm for progress bar in cifar10 example
use tqdm for progress bar in cifar10 example
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e29ef5ff
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e29ef5ff
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e29ef5ff
Branch: refs/heads/master
Commit: e29ef5ff87f25838e3104fc326ec3a43950e72b6
Parents: e08af42
Author: Wang Wei <wa...@gmail.com>
Authored: Tue Nov 20 17:16:14 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Tue Nov 20 17:16:14 2018 +0800
----------------------------------------------------------------------
examples/cifar10/train.py | 33 +++++++++++++++++----------------
1 file changed, 17 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e29ef5ff/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index b2ab4af..7657026 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -30,6 +30,7 @@ except ImportError:
 import numpy as np
 import os
 import argparse
+from tqdm import trange
 
 from singa import utils
 from singa import optimizer
@@ -145,21 +146,21 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
     for epoch in range(max_epoch):
         np.random.shuffle(idx)
         loss, acc = 0.0, 0.0
-        print('Epoch %d' % epoch)
-        for b in range(num_train_batch):
-            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
-            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
-            tx.copy_from_numpy(x)
-            ty.copy_from_numpy(y)
-            grads, (l, a) = net.train(tx, ty)
-            loss += l
-            acc += a
-            for (s, p, g) in zip(net.param_names(), net.param_values(), grads):
-                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s), b)
-            # update progress bar
-            utils.update_progress(b * 1.0 / num_train_batch,
-                                  'training loss = %f, accuracy = %f' % (l, a))
-        info = '\ntraining loss = %f, training accuracy = %f, lr = %f' \
+        with trange(num_train_batch) as t:
+            t.set_description('Epoch={}'.format(epoch))
+            for b in t:
+                x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
+                y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
+                tx.copy_from_numpy(x)
+                ty.copy_from_numpy(y)
+                grads, (l, a) = net.train(tx, ty)
+                loss += l
+                acc += a
+                for (s, p, g) in zip(net.param_names(), net.param_values(), grads):
+                    opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s), b)
+                t.set_postfix(loss=l, accuracy=a)
+
+        info = 'Training loss = %f, training accuracy = %f, lr = %f' \
             % ((loss / num_train_batch), (acc / num_train_batch), get_lr(epoch))
         print(info)
@@ -173,7 +174,7 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
             loss += l
             acc += a
-        print('test loss = %f, test accuracy = %f' %
+        print('Test loss = %f, test accuracy = %f' %
               ((loss / num_test_batch), (acc / num_test_batch)))
     net.save('model', 20)  # save model params into checkpoint file
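
For reference, below is a minimal standalone sketch of the tqdm pattern this commit
adopts (trange as a context manager, set_description for the epoch label, set_postfix
for per-batch metrics). It assumes tqdm is installed (pip install tqdm); the random
values only stand in for the loss/accuracy that net.train(tx, ty) returns and are not
part of the SINGA example.

    # Sketch only: trange() wraps range() with a progress bar;
    # set_description() labels the bar and set_postfix() appends
    # per-iteration metrics to its right-hand side.
    import random
    from tqdm import trange

    num_epoch, num_batch = 2, 50
    for epoch in range(num_epoch):
        with trange(num_batch) as t:
            t.set_description('Epoch={}'.format(epoch))
            for b in t:
                # stand-ins for the values returned by net.train(tx, ty)
                l, a = random.random(), random.random()
                t.set_postfix(loss=l, accuracy=a)

Unlike utils.update_progress, which required passing the completed fraction on every
batch, trange tracks the iteration count itself and only the displayed metrics need
to be updated.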