You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@singa.apache.org by wa...@apache.org on 2018/11/20 09:16:33 UTC

[3/3] incubator-singa git commit: use tqdm for progress bar in cifar10 example

use tqdm for progress bar in cifar10 example


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e29ef5ff
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e29ef5ff
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e29ef5ff

Branch: refs/heads/master
Commit: e29ef5ff87f25838e3104fc326ec3a43950e72b6
Parents: e08af42
Author: Wang Wei <wa...@gmail.com>
Authored: Tue Nov 20 17:16:14 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Tue Nov 20 17:16:14 2018 +0800

----------------------------------------------------------------------
 examples/cifar10/train.py | 33 +++++++++++++++++----------------
 1 file changed, 17 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e29ef5ff/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index b2ab4af..7657026 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -30,6 +30,7 @@ except ImportError:
 import numpy as np
 import os
 import argparse
+from tqdm import trange
 
 from singa import utils
 from singa import optimizer
@@ -145,21 +146,21 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
     for epoch in range(max_epoch):
         np.random.shuffle(idx)
         loss, acc = 0.0, 0.0
-        print('Epoch %d' % epoch)
-        for b in range(num_train_batch):
-            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
-            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
-            tx.copy_from_numpy(x)
-            ty.copy_from_numpy(y)
-            grads, (l, a) = net.train(tx, ty)
-            loss += l
-            acc += a
-            for (s, p, g) in zip(net.param_names(), net.param_values(), grads):
-                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s), b)
-            # update progress bar
-            utils.update_progress(b * 1.0 / num_train_batch,
-                                  'training loss = %f, accuracy = %f' % (l, a))
-        info = '\ntraining loss = %f, training accuracy = %f, lr = %f' \
+        with trange(num_train_batch) as t:
+            t.set_description('Epoch={}'.format(epoch))
+            for b in t:
+                x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
+                y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
+                tx.copy_from_numpy(x)
+                ty.copy_from_numpy(y)
+                grads, (l, a) = net.train(tx, ty)
+                loss += l
+                acc += a
+                for (s, p, g) in zip(net.param_names(), net.param_values(), grads):
+                    opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s), b)
+                t.set_postfix(loss=l, accuracy=a)
+
+        info = 'Training loss = %f, training accuracy = %f, lr = %f' \
             % ((loss / num_train_batch), (acc / num_train_batch), get_lr(epoch))
         print(info)
 
@@ -173,7 +174,7 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
             loss += l
             acc += a
 
-        print('test loss = %f, test accuracy = %f' %
+        print('Test loss = %f, test accuracy = %f' %
               ((loss / num_test_batch), (acc / num_test_batch)))
     net.save('model', 20)  # save model params into checkpoint file