You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by zh...@apache.org on 2022/08/25 05:25:34 UTC
[singa] branch dev updated: Update inline comments for iterative training
This is an automated email from the ASF dual-hosted git repository.
zhaojing pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new 35b55aca Update inline comments for iterative training
new 51b5c760 Merge pull request #984 from NLGithubWP/update-demos
35b55aca is described below
commit 35b55acabccf53175a44f8c7c2831891d48f275a
Author: NLGithubWP <xi...@gmail.com>
AuthorDate: Thu Aug 25 12:55:51 2022 +0800
Update inline comments for iterative training
---
examples/demos/Classification/BloodMnist/ClassDemo.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/examples/demos/Classification/BloodMnist/ClassDemo.py b/examples/demos/Classification/BloodMnist/ClassDemo.py
index adccac4f..a6872f8c 100644
--- a/examples/demos/Classification/BloodMnist/ClassDemo.py
+++ b/examples/demos/Classification/BloodMnist/ClassDemo.py
@@ -206,7 +206,7 @@ model = CNNModel(num_classes=num_class)
criterion = layer.SoftMaxCrossEntropy()
optimizer_ft = opt.Adam(lr=1e-3)
-# start training
+# Start training
dev = device.create_cpu_device()
dev.SetRandSeed(0)
np.random.seed(0)
@@ -234,10 +234,10 @@ for epoch in range(max_epoch):
test_correct = np.zeros(shape=[1], dtype=np.float32)
train_loss = np.zeros(shape=[1], dtype=np.float32)
- # training part
+ # Training part
model.train()
for b in tqdm(range(num_train_batch)):
- # extract batch from image list
+ # Extract batch from image list
x, y = train_dataset.batchgenerator(idx[b * batch_size:(b + 1) * batch_size],
batch_size=batch_size, data_size=(3, model.input_size, model.input_size))
x = x.astype(np_dtype['float32'])
@@ -252,7 +252,7 @@ for epoch in range(max_epoch):
(train_loss, train_correct /
(num_train_batch * batch_size)))
- # validation part
+ # Validation part
model.eval()
for b in tqdm(range(num_val_batch)):
x, y = train_dataset.batchgenerator(idx[b * batch_size:(b + 1) * batch_size],