You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2022/08/24 02:58:31 UTC
[singa] branch dev updated: update resnet cifar 10
This is an automated email from the ASF dual-hosted git repository.
zhaojing pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new 873fe05f update resnet cifar 10
new 767dd5df Merge pull request #978 from zlheui/update-resnet-cifar-10
873fe05f is described below
commit 873fe05f12044b76720a54736069d12db2e6b6bc
Author: zhulei <zl...@gmail.com>
AuthorDate: Wed Aug 24 10:32:09 2022 +0800
update resnet cifar 10
---
examples/cnn/autograd/resnet_cifar10.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/examples/cnn/autograd/resnet_cifar10.py b/examples/cnn/autograd/resnet_cifar10.py
index 3c6876f6..75417369 100644
--- a/examples/cnn/autograd/resnet_cifar10.py
+++ b/examples/cnn/autograd/resnet_cifar10.py
@@ -199,7 +199,7 @@ def train_cifar10(DIST=False,
idx = np.arange(train_x.shape[0], dtype=np.int32)
if DIST:
- #Sychronize the initial parameters
+ # Synchronize the initial parameters
autograd.training = True
x = np.random.randn(batch_size, 3, IMG_SIZE,
IMG_SIZE).astype(np.float32)
@@ -220,7 +220,7 @@ def train_cifar10(DIST=False,
if ((DIST == False) or (sgd.global_rank == 0)):
print('Starting Epoch %d:' % (epoch))
- #Training phase
+ # Training phase
autograd.training = True
train_correct = np.zeros(shape=[1], dtype=np.float32)
test_correct = np.zeros(shape=[1], dtype=np.float32)
@@ -262,7 +262,7 @@ def train_cifar10(DIST=False,
for p in param:
synchronize(p, sgd)
- #Evaulation phase
+ # Evaluation phase
autograd.training = False
for b in range(num_test_batch):
x = test_x[b * batch_size:(b + 1) * batch_size]