You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2022/08/24 02:58:31 UTC

[singa] branch dev updated: update resnet cifar 10

This is an automated email from the ASF dual-hosted git repository.

zhaojing pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new 873fe05f update resnet cifar 10
     new 767dd5df Merge pull request #978 from zlheui/update-resnet-cifar-10
873fe05f is described below

commit 873fe05f12044b76720a54736069d12db2e6b6bc
Author: zhulei <zl...@gmail.com>
AuthorDate: Wed Aug 24 10:32:09 2022 +0800

    update resnet cifar 10
---
 examples/cnn/autograd/resnet_cifar10.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/cnn/autograd/resnet_cifar10.py b/examples/cnn/autograd/resnet_cifar10.py
index 3c6876f6..75417369 100644
--- a/examples/cnn/autograd/resnet_cifar10.py
+++ b/examples/cnn/autograd/resnet_cifar10.py
@@ -199,7 +199,7 @@ def train_cifar10(DIST=False,
     idx = np.arange(train_x.shape[0], dtype=np.int32)
 
     if DIST:
-        #Sychronize the initial parameters
+        # Sychronize the initial parameters
         autograd.training = True
         x = np.random.randn(batch_size, 3, IMG_SIZE,
                             IMG_SIZE).astype(np.float32)
@@ -220,7 +220,7 @@ def train_cifar10(DIST=False,
         if ((DIST == False) or (sgd.global_rank == 0)):
             print('Starting Epoch %d:' % (epoch))
 
-        #Training phase
+        # Training phase
         autograd.training = True
         train_correct = np.zeros(shape=[1], dtype=np.float32)
         test_correct = np.zeros(shape=[1], dtype=np.float32)
@@ -262,7 +262,7 @@ def train_cifar10(DIST=False,
             for p in param:
                 synchronize(p, sgd)
 
-        #Evaulation phase
+        # Evaulation phase
         autograd.training = False
         for b in range(num_test_batch):
             x = test_x[b * batch_size:(b + 1) * batch_size]