You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by ch...@apache.org on 2021/05/14 12:32:16 UTC
[singa] branch dev updated: add distributed training for cifar-10
using resnet
This is an automated email from the ASF dual-hosted git repository.
chrishkchris pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new b2456f2 add distributed training for cifar-10 using resnet
new 7c4bc91 Merge pull request #843 from lzjpaul/cifar-distributed-cnn
b2456f2 is described below
commit b2456f2c63a2a498d57dc42ebeb304b66631d54c
Author: zhaojing <zh...@comp.nus.edu.sg>
AuthorDate: Fri May 14 19:38:06 2021 +0800
add distributed training for cifar-10 using resnet
---
examples/cifar_distributed_cnn/README.md | 46 ++++++++++++++
examples/cifar_distributed_cnn/train_mpi.py | 94 +++++++++++++++++++++++++++++
2 files changed, 140 insertions(+)
diff --git a/examples/cifar_distributed_cnn/README.md b/examples/cifar_distributed_cnn/README.md
new file mode 100644
index 0000000..4af7916
--- /dev/null
+++ b/examples/cifar_distributed_cnn/README.md
@@ -0,0 +1,46 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Image Classification using Convolutional Neural Networks
+
+Examples inside this folder show how to train CNN models using
+SINGA for image classification.
+
+* `data` includes the scripts for preprocessing image datasets.
+ Currently, MNIST, CIFAR10 and CIFAR100 are included.
+
+* `model` includes the CNN model construction code: each model is
+ defined as a subclass of `Module` that wraps its neural network
+ operations. The computational graph is then enabled to optimize
+ memory usage and efficiency.
+
+* `autograd` includes the code to train CNN models by calling the
+ [neural network operations](../../python/singa/autograd.py) imperatively.
+ The computational graph is not created.
+
+* `train_cnn.py` is the training script, which controls the training flow by
+ performing backpropagation and SGD updates.
+
+* `train_multiprocess.py` is the script for distributed training on a single
+ node with multiple GPUs; it uses Python's multiprocessing module and NCCL.
+
+* `train_mpi.py` is the script for distributed training (among multiple nodes)
+ using MPI and NCCL for communication.
+
+* `benchmark.py` tests the training throughput using `ResNet50` as the workload.
diff --git a/examples/cifar_distributed_cnn/train_mpi.py b/examples/cifar_distributed_cnn/train_mpi.py
new file mode 100644
index 0000000..dc9151b
--- /dev/null
+++ b/examples/cifar_distributed_cnn/train_mpi.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from singa import singa_wrap as singa
+from singa import opt
+# import opt
+from singa import tensor
+import argparse
+import train_cnn
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)  # hide DeprecationWarning messages for the whole run
+
+singa_dtype = {"float32": tensor.float32}  # NOTE(review): not referenced in this script, and it has no "float16" entry although -p accepts float16; confirm whether it is needed here
+
def build_parser():
    """Build the command-line parser for MPI-based distributed training.

    Returns:
        argparse.ArgumentParser: a parser whose ``parse_args()`` result
        carries ``model``, ``data``, ``precision``, ``max_epoch``,
        ``batch_size``, ``lr``, ``dist_option``, ``spars``, ``graph`` and
        ``verbosity``.
    """
    parser = argparse.ArgumentParser(
        description='Training using the autograd and graph.')
    # nargs='?' makes the positional arguments optional so the declared
    # defaults actually take effect; without it argparse treats them as
    # required and never uses ``default``.
    parser.add_argument('model',
                        nargs='?',
                        choices=['cnn', 'resnet', 'xceptionnet', 'mlp'],
                        default='cnn')
    parser.add_argument('data',
                        nargs='?',
                        choices=['mnist', 'cifar10', 'cifar100'],
                        default='mnist')
    parser.add_argument('-p',
                        choices=['float32', 'float16'],
                        default='float32',
                        dest='precision')
    parser.add_argument('-m',
                        '--max-epoch',
                        default=10,
                        type=int,
                        help='maximum epochs',
                        dest='max_epoch')
    parser.add_argument('-b',
                        '--batch-size',
                        default=32,
                        type=int,
                        help='batch size',
                        dest='batch_size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=0.005,
                        type=float,
                        help='initial learning rate',
                        dest='lr')
    # partialUpdate currently supports graph=False only.
    parser.add_argument('-d',
                        '--dist-option',
                        default='plain',
                        choices=['plain', 'half', 'partialUpdate',
                                 'sparseTopK', 'sparseThreshold'],
                        help='distributed training options',
                        dest='dist_option')
    parser.add_argument('-s',
                        '--sparsification',
                        default=0.05,
                        type=float,
                        help='the sparsity parameter used for '
                        'sparsification, between 0 and 1',
                        dest='spars')
    # store_false: graph mode is on (True) unless -g is given. The default
    # must be the bool True, not the string 'True'.
    parser.add_argument('-g',
                        '--disable-graph',
                        default=True,
                        action='store_false',
                        help='disable graph',
                        dest='graph')
    parser.add_argument('-v',
                        '--log-verbosity',
                        default=0,
                        type=int,
                        help='logging verbosity',
                        dest='verbosity')
    return parser


if __name__ == '__main__':
    # Parse the run configuration (model, dataset, epochs, distributed
    # options, ...) for this training process.
    args = build_parser().parse_args()

    # Wrap a plain SGD optimizer in DistOpt; the wrapper supplies the
    # global rank, world size and local rank passed to the trainer below.
    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
    sgd = opt.DistOpt(sgd)

    train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank,
                  args.max_epoch, args.batch_size, args.model, args.data,
                  sgd, args.graph, args.verbosity, args.dist_option,
                  args.spars, args.precision)