Posted to commits@singa.apache.org by zh...@apache.org on 2021/05/16 03:21:51 UTC

[singa] branch dev updated: add multiprocess distributed training for cifar-10

This is an automated email from the ASF dual-hosted git repository.

zhaojing pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new 024cd8e  add multiprocess distributed training for cifar-10
     new 4d7eccc  Merge pull request #844 from naili-xing/cifar-distributed-cnn-multiprocess
024cd8e is described below

commit 024cd8e10814f4733072f71c43674e83b58d2e8b
Author: nailixing <xi...@gmail.com>
AuthorDate: Sat May 15 23:05:04 2021 +0800

    add multiprocess distributed training for cifar-10
---
 .../cifar_distributed_cnn/train_multiprocess.py    | 115 +++++++++++++++++++++
 1 file changed, 115 insertions(+)

diff --git a/examples/cifar_distributed_cnn/train_multiprocess.py b/examples/cifar_distributed_cnn/train_multiprocess.py
new file mode 100644
index 0000000..5a09adc
--- /dev/null
+++ b/examples/cifar_distributed_cnn/train_multiprocess.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from singa import singa_wrap as singa
+from singa import opt
+from singa import tensor
+import argparse
+import train_cnn
+import multiprocessing
+
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
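+# Each worker process builds its own local SGD optimizer, wraps it in
+# opt.DistOpt (which synchronizes gradients across processes over NCCL),
+# and then hands off to the shared training loop in train_cnn.py.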
+def run(args, local_rank, world_size, nccl_id):
+    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
+    sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank, world_size=world_size)
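+    # DistOpt assigns global_rank, local_rank and world_size to the wrapped optimizer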
+    train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
+                  args.batch_size, args.model, args.data, sgd, args.graph,
+                  args.verbosity, args.dist_option, args.spars, args.precision)
+
+
+if __name__ == '__main__':
+    # Use argparse to get the command config: max_epoch, model, data, etc.,
+    # for multiprocess distributed training
+    parser = argparse.ArgumentParser(
+        description='Multiprocess distributed training using autograd and graph.')
+    parser.add_argument('model',
+                        nargs='?',
+                        choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
+                        default='cnn')
+    parser.add_argument('data',
+                        nargs='?',
+                        choices=['cifar10', 'cifar100', 'mnist'],
+                        default='mnist')
+    parser.add_argument('-p',
+                        choices=['float32', 'float16'],
+                        default='float32',
+                        dest='precision')
+    parser.add_argument('-m',
+                        '--max-epoch',
+                        default=10,
+                        type=int,
+                        help='maximum epochs',
+                        dest='max_epoch')
+    parser.add_argument('-b',
+                        '--batch-size',
+                        default=32,
+                        type=int,
+                        help='batch size',
+                        dest='batch_size')
+    parser.add_argument('-l',
+                        '--learning-rate',
+                        default=0.005,
+                        type=float,
+                        help='initial learning rate',
+                        dest='lr')
+    parser.add_argument('-w',
+                        '--world-size',
+                        default=2,
+                        type=int,
+                        help='number of gpus to be used',
+                        dest='world_size')
+    parser.add_argument('-d',
+                        '--dist-option',
+                        default='plain',
+                        choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'],
+                        help='distributed training options',
+                        dest='dist_option')  # currently partialUpdate supports graph=False only
+    parser.add_argument('-s',
+                        '--sparsification',
+                        default=0.05,
+                        type=float,
+                        help='the sparsity parameter used for sparsification, between 0 and 1',
+                        dest='spars')
+    parser.add_argument('-g',
+                        '--disable-graph',
+                        default=True,
+                        action='store_false',
+                        help='disable the computational graph',
+                        dest='graph')
+    parser.add_argument('-v',
+                        '--log-verbosity',
+                        default=0,
+                        type=int,
+                        help='logging verbosity',
+                        dest='verbosity')
+
+    args = parser.parse_args()
+
+    # Generate an NCCL ID to be shared by all processes for collective communication
+    nccl_id = singa.NcclIdHolder()
+
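+    # Launch one worker process per GPU; every worker receives the same
+    # NCCL ID so that their communicators join the same collective group.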
+    process = []
+    for local_rank in range(0, args.world_size):
+        process.append(
+            multiprocessing.Process(target=run,
+                                    args=(args, local_rank, args.world_size, nccl_id)))
+
+    for p in process:
+        p.start()
+
+    # Block until every worker finishes before the launcher exits
+    for p in process:
+        p.join()
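
For reference, a minimal launch sketch (the flags come from the parser above; the two-GPU count and the CNN/CIFAR-10 choice are illustrative assumptions about your machine):

    # assumes a SINGA build with NCCL support and at least two visible GPUs
    python train_multiprocess.py cnn cifar10 --world-size 2 --max-epoch 10 --batch-size 32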