Posted to commits@mxnet.apache.org by an...@apache.org on 2019/03/22 21:49:22 UTC

[incubator-mxnet] branch master updated: Add examples of running MXNet with Horovod (#14286)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 056fce4  Add examples of running MXNet with Horovod (#14286)
056fce4 is described below

commit 056fce47eb340af3fa6afa56f471b2c515f2b002
Author: Lin Yuan <ap...@gmail.com>
AuthorDate: Fri Mar 22 14:48:58 2019 -0700

    Add examples of running MXNet with Horovod (#14286)
    
    * Add examples for MXNet with Horovod
    
    * update readme
    
    * update examples
    
    * update README
    
    * update mnist_module example
    
    * Update README
    
    * update README
    
    * update README
    
    * update README
---
 example/distributed_training-horovod/README.md     | 201 +++++++++
 .../distributed_training-horovod/gluon_mnist.py    | 186 +++++++++
 .../distributed_training-horovod/module_mnist.py   | 162 ++++++++
 .../resnet50_imagenet.py                           | 453 +++++++++++++++++++++
 4 files changed, 1002 insertions(+)

diff --git a/example/distributed_training-horovod/README.md b/example/distributed_training-horovod/README.md
new file mode 100644
index 0000000..c477604
--- /dev/null
+++ b/example/distributed_training-horovod/README.md
@@ -0,0 +1,201 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Distributed Training using MXNet with Horovod 
+[Horovod](https://github.com/horovod/horovod) is a distributed training framework that demonstrates 
+excellent scaling efficiency for dense models running on a large number of nodes. It currently 
+supports mainstream deep learning frameworks such as MXNet, TensorFlow, Keras, and PyTorch. 
+It was created at Uber and is currently hosted by the [LF Deep Learning Foundation](https://lfdl.io) (LF DL). 
+
+MXNet is supported starting with the Horovod 0.16.0 [release](https://eng.uber.com/horovod-pyspark-apache-mxnet-support/).
+
+## What's New?
+Compared with the standard distributed training script in MXNet, which uses a parameter server to 
+distribute and aggregate parameters, Horovod uses ring-allreduce and/or tree-based allreduce 
+algorithms to exchange parameters between workers. There is no dedicated server, and the amount of 
+data each worker communicates does not depend on the number of workers. As a result, Horovod scales 
+well when there are many workers and network bandwidth is the bottleneck.
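+
+The short sketch below (not part of the example scripts in this commit) illustrates the allreduce
+primitive this relies on: every worker contributes a tensor and gets back the average, with no
+parameter server involved. It assumes Horovod 0.16.0+ with MXNet support and that the script is
+launched through `mpirun`; the file name is hypothetical.
+
+```python
+# Minimal allreduce sketch (hypothetical file name allreduce_demo.py);
+# launch with e.g.: mpirun -np 4 python allreduce_demo.py
+import mxnet as mx
+import horovod.mxnet as hvd
+
+hvd.init()
+
+# Each worker starts with a different tensor: all entries equal to its rank.
+local = mx.nd.ones((2, 2)) * hvd.rank()
+
+# After allreduce, every worker holds the same averaged tensor;
+# with 4 workers each entry is 1.5 (the mean of ranks 0, 1, 2, 3).
+averaged = hvd.allreduce(local, average=True)
+print('rank %d: %s' % (hvd.rank(), averaged.asnumpy()))
+```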
+
+# Install
+## Install MXNet
+```bash
+$ pip install mxnet
+```
+**Note**: There is a [known issue](https://github.com/horovod/horovod/issues/884) when running Horovod with MXNet on a Linux system with GCC version 5.X and above. We recommend building MXNet from source following this [guide](https://mxnet.incubator.apache.org/install/build_from_source.html) as a workaround for now. Also, the mxnet-mkl package in the 1.4.0 release does not support Horovod.
+
+## Install Horovod
+```bash
+$ pip install horovod
+```
+
+This basic installation is good for laptops and for getting to know Horovod.
+If you're installing Horovod on a server with GPUs, read the [Horovod on GPU](https://github.com/horovod/horovod/blob/master/docs/gpus.md) page.
+If you want to use Docker, read the [Horovod in Docker](https://github.com/horovod/horovod/blob/master/docs/docker.md) page.
+
+## Install MPI
+MPI is required to run distributed training with Horovod. Install [Open MPI](https://www.open-mpi.org/) or another MPI implementation.
+Steps to install Open MPI are listed [here](https://www.open-mpi.org/faq/?category=building#easy-build).
+
+**Note**: Open MPI 3.1.3 has an issue that may cause hangs.  It is recommended
+to downgrade to Open MPI 3.1.2 or upgrade to Open MPI 4.0.0.
+
+# Usage
+
+To run MXNet with Horovod, make the following additions to your training script:
+
+1. Run `hvd.init()`.
+
+2. Pin the context to a device using `hvd.local_rank()`.
+    Typically, each Horovod worker is associated with one process. The local rank is an ID that is
+    unique among the processes of the Horovod job running on the same node.
+
+3. Scale the learning rate by the number of workers. The effective batch size in synchronous
+    distributed training is scaled by the number of workers, and a proportionally larger learning
+    rate compensates for the increased batch size. For example, with 8 workers and a per-worker
+    batch size of 64, the effective batch size is 512, so the single-worker learning rate is
+    multiplied by 8.
+
+4. Wrap the optimizer in `hvd.DistributedOptimizer`. The distributed optimizer delegates gradient computation
+    to the original optimizer, averages gradients using *allreduce* or *allgather*, and then applies those averaged
+    gradients.
+
+5. Add `hvd.broadcast_parameters` to broadcast initial variable states from rank 0 to all other processes.
+    This is necessary to ensure consistent initialization of all workers when training is started with random weights or
+    restored from a checkpoint. 
+
+# Example
+
+Here we provide the building blocks for training a model using MXNet with Horovod.
+The full examples are [MNIST with Gluon](gluon_mnist.py), [MNIST with Module](module_mnist.py), and [ImageNet](resnet50_imagenet.py).
+
+## Gluon API
+```python
+from mxnet import autograd, gluon
+import mxnet as mx
+import horovod.mxnet as hvd
+
+# Initialize Horovod
+hvd.init()
+
+# Set context to current process 
+context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
+
+num_workers = hvd.size()
+
+# Build model
+model = ...
+model.hybridize()
+
+# Define hyper parameters
+optimizer_params = ...
+
+# Add Horovod Distributed Optimizer
+opt = mx.optimizer.create('sgd', **optimizer_params)
+opt = hvd.DistributedOptimizer(opt)
+
+# Initialize parameters
+model.initialize(initializer, ctx=context)
+
+# Fetch and broadcast parameters
+params = model.collect_params()
+if params is not None:
+    hvd.broadcast_parameters(params, root_rank=0)
+
+# Create trainer and loss function
+trainer = gluon.Trainer(params, opt, kvstore=None)
+loss_fn = ...
+
+# Train model
+for epoch in range(num_epoch):
+    train_data.reset()
+    for nbatch, batch in enumerate(train_data, start=1):
+        data = batch.data[0].as_in_context(context)
+        label = batch.label[0].as_in_context(context)
+        with autograd.record():
+            output = model(data.astype(dtype, copy=False))
+            loss = loss_fn(output, label)
+        loss.backward()
+        trainer.step(batch_size)
+```
+
+## Module API
+```python
+import mxnet as mx
+import horovod.mxnet as hvd
+
+# Initialize Horovod
+hvd.init()
+
+# Set context to current process
+context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
+num_workers = hvd.size()
+
+# Build model
+model = ...
+
+# Define hyper parameters
+optimizer_params = ...
+
+# Add Horovod Distributed Optimizer
+opt = mx.optimizer.create('sgd', **optimizer_params)
+opt = hvd.DistributedOptimizer(opt)
+
+# Initialize parameters
+initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
+                             magnitude=2)
+model.bind(data_shapes=train_data.provide_data,
+           label_shapes=train_data.provide_label)
+model.init_params(initializer)
+
+# Fetch and broadcast parameters
+(arg_params, aux_params) = model.get_params()
+if arg_params:
+    hvd.broadcast_parameters(arg_params, root_rank=0)
+if aux_params:
+    hvd.broadcast_parameters(aux_params, root_rank=0)
+model.set_params(arg_params=arg_params, aux_params=aux_params)
+
+# Train model
+model.fit(train_data,
+          kvstore=None,
+          optimizer=opt,
+          num_epoch=num_epoch)
+```
+
+
+# Running Horovod
+
+The example commands below show how to run distributed training. See the 
+[Running Horovod](https://github.com/horovod/horovod/blob/master/docs/running.md)
+page for more instructions, including RoCE/InfiniBand tweaks and tips for dealing with hangs.
+
+1. To run on a machine with 4 CPUs:
+
+```bash
+$ mpirun -np 4 \
+    -H localhost:4 \
+    -bind-to none -map-by slot \
+    python train.py
+```
+
+2. To run on 2 machines with 4 GPUs each:
+
+```bash
+$ mpirun -np 8 \
+    -H server1:4,server2:4 \
+    -bind-to none -map-by slot \
+    -x NCCL_DEBUG=INFO \
+    -mca pml ob1 -mca btl ^openib \
+    python train.py
+```
\ No newline at end of file
diff --git a/example/distributed_training-horovod/gluon_mnist.py b/example/distributed_training-horovod/gluon_mnist.py
new file mode 100644
index 0000000..7e4be58
--- /dev/null
+++ b/example/distributed_training-horovod/gluon_mnist.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import logging
+import os
+import zipfile
+import time
+
+import mxnet as mx
+import horovod.mxnet as hvd
+from mxnet import autograd, gluon, nd
+from mxnet.test_utils import download
+
+# Training settings
+parser = argparse.ArgumentParser(description='MXNet MNIST Example')
+
+parser.add_argument('--batch-size', type=int, default=64,
+                    help='training batch size (default: 64)')
+parser.add_argument('--dtype', type=str, default='float32',
+                    help='training data type (default: float32)')
+parser.add_argument('--epochs', type=int, default=5,
+                    help='number of training epochs (default: 5)')
+parser.add_argument('--lr', type=float, default=0.01,
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--momentum', type=float, default=0.9,
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--use-gpu', action='store_true', default=False,
+                    help='run training on GPU (default: False)')
+args = parser.parse_args()
+
+logging.basicConfig(level=logging.INFO)
+logging.info(args)
+
+
+# Function to get mnist iterator given a rank
+def get_mnist_iterator(rank):
+    data_dir = "data-%d" % rank
+    if not os.path.isdir(data_dir):
+        os.makedirs(data_dir)
+    zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
+                             dirname=data_dir)
+    with zipfile.ZipFile(zip_file_path) as zf:
+        zf.extractall(data_dir)
+
+    input_shape = (1, 28, 28)
+    batch_size = args.batch_size
+
+    train_iter = mx.io.MNISTIter(
+        image="%s/train-images-idx3-ubyte" % data_dir,
+        label="%s/train-labels-idx1-ubyte" % data_dir,
+        input_shape=input_shape,
+        batch_size=batch_size,
+        shuffle=True,
+        flat=False,
+        num_parts=hvd.size(),
+        part_index=hvd.rank()
+    )
+
+    val_iter = mx.io.MNISTIter(
+        image="%s/t10k-images-idx3-ubyte" % data_dir,
+        label="%s/t10k-labels-idx1-ubyte" % data_dir,
+        input_shape=input_shape,
+        batch_size=batch_size,
+        flat=False,
+    )
+
+    return train_iter, val_iter
+
+
+# Function to define neural network
+def conv_nets():
+    net = gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
+        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
+        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
+        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
+        net.add(gluon.nn.Flatten())
+        net.add(gluon.nn.Dense(512, activation="relu"))
+        net.add(gluon.nn.Dense(10))
+    return net
+
+
+# Function to evaluate accuracy for a model
+def evaluate(model, data_iter, context):
+    data_iter.reset()
+    metric = mx.metric.Accuracy()
+    for _, batch in enumerate(data_iter):
+        data = batch.data[0].as_in_context(context)
+        label = batch.label[0].as_in_context(context)
+        output = model(data.astype(args.dtype, copy=False))
+        metric.update([label], [output])
+
+    return metric.get()
+
+
+# Initialize Horovod
+hvd.init()
+
+# Horovod: pin context to local rank
+context = mx.gpu(hvd.local_rank()) if args.use_gpu else mx.cpu(hvd.local_rank())
+num_workers = hvd.size()
+
+# Load training and validation data
+train_data, val_data = get_mnist_iterator(hvd.rank())
+
+# Build model
+model = conv_nets()
+model.cast(args.dtype)
+model.hybridize()
+
+# Define hyper parameters
+optimizer_params = {'momentum': args.momentum,
+                    'learning_rate': args.lr * hvd.size(),
+                    'rescale_grad': 1.0 / args.batch_size}
+
+# Add Horovod Distributed Optimizer
+opt = mx.optimizer.create('sgd', **optimizer_params)
+opt = hvd.DistributedOptimizer(opt)
+
+# Initialize parameters
+initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
+                             magnitude=2)
+model.initialize(initializer, ctx=context)
+
+# Fetch and broadcast parameters
+params = model.collect_params()
+if params is not None:
+    hvd.broadcast_parameters(params, root_rank=0)
+
+# Create trainer, loss function and train metric
+trainer = gluon.Trainer(params, opt, kvstore=None)
+loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
+metric = mx.metric.Accuracy()
+
+# Train model
+for epoch in range(args.epochs):
+    tic = time.time()
+    train_data.reset()
+    metric.reset()
+    for nbatch, batch in enumerate(train_data, start=1):
+        data = batch.data[0].as_in_context(context)
+        label = batch.label[0].as_in_context(context)
+        with autograd.record():
+            output = model(data.astype(args.dtype, copy=False))
+            loss = loss_fn(output, label)
+        loss.backward()
+        trainer.step(args.batch_size)
+        metric.update([label], [output])
+
+        if nbatch % 100 == 0:
+            name, acc = metric.get()
+            logging.info('[Epoch %d Batch %d] Training: %s=%f' %
+                         (epoch, nbatch, name, acc))
+
+    if hvd.rank() == 0:
+        elapsed = time.time() - tic
+        speed = nbatch * args.batch_size * hvd.size() / elapsed
+        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
+                     epoch, speed, elapsed)
+
+    # Evaluate model accuracy
+    _, train_acc = metric.get()
+    name, val_acc = evaluate(model, val_data, context)
+    if hvd.rank() == 0:
+        logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
+                     train_acc, name, val_acc)
+
+    if hvd.rank() == 0 and epoch == args.epochs - 1:
+        assert val_acc > 0.96, ("Achieved accuracy (%f) is lower than "
+                                "expected (0.96)" % val_acc)
diff --git a/example/distributed_training-horovod/module_mnist.py b/example/distributed_training-horovod/module_mnist.py
new file mode 100644
index 0000000..5c02aae
--- /dev/null
+++ b/example/distributed_training-horovod/module_mnist.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import logging
+import os
+import zipfile
+
+import horovod.mxnet as hvd
+import mxnet as mx
+from mxnet.test_utils import download
+
+# Training settings
+parser = argparse.ArgumentParser(description='MXNet MNIST Example')
+parser.add_argument('--batch-size', type=int, default=64,
+                    help='training batch size (default: 64)')
+parser.add_argument('--dtype', type=str, default='float32',
+                    help='training data type (default: float32)')
+parser.add_argument('--epochs', type=int, default=5,
+                    help='number of training epochs (default: 5)')
+parser.add_argument('--lr', type=float, default=0.05,
+                    help='learning rate (default: 0.05)')
+parser.add_argument('--momentum', type=float, default=0.5,
+                    help='SGD momentum (default: 0.5)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA training (default: False)')
+args = parser.parse_args()
+
+if not args.no_cuda:
+    # Disable CUDA if there are no GPUs.
+    if not mx.test_utils.list_gpus():
+        args.no_cuda = True
+
+logging.basicConfig(level=logging.INFO)
+logging.info(args)
+
+
+# Function to get mnist iterator given a rank
+def get_mnist_iterator(rank):
+    data_dir = "data-%d" % rank
+    if not os.path.isdir(data_dir):
+        os.makedirs(data_dir)
+    zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
+                             dirname=data_dir)
+    with zipfile.ZipFile(zip_file_path) as zf:
+        zf.extractall(data_dir)
+
+    input_shape = (1, 28, 28)
+    batch_size = args.batch_size
+
+    train_iter = mx.io.MNISTIter(
+        image="%s/train-images-idx3-ubyte" % data_dir,
+        label="%s/train-labels-idx1-ubyte" % data_dir,
+        input_shape=input_shape,
+        batch_size=batch_size,
+        shuffle=True,
+        flat=False,
+        num_parts=hvd.size(),
+        part_index=hvd.rank()
+    )
+
+    val_iter = mx.io.MNISTIter(
+        image="%s/t10k-images-idx3-ubyte" % data_dir,
+        label="%s/t10k-labels-idx1-ubyte" % data_dir,
+        input_shape=input_shape,
+        batch_size=batch_size,
+        flat=False,
+        num_parts=hvd.size(),
+        part_index=hvd.rank()
+    )
+
+    return train_iter, val_iter
+
+# Step 1: initialize Horovod
+hvd.init()
+
+# Horovod: pin context to process
+context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
+
+# Step 2: load data
+train_iter, val_iter = get_mnist_iterator(hvd.rank())
+
+
+# Step 3: define network
+def conv_net():
+    # placeholder for data
+    data = mx.sym.var('data')
+    # first conv layer
+    conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=10)
+    relu1 = mx.sym.Activation(data=conv1, act_type='relu')
+    pool1 = mx.sym.Pooling(data=relu1, pool_type='max', kernel=(2, 2),
+                           stride=(2, 2))
+    # second conv layer
+    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=20)
+    relu2 = mx.sym.Activation(data=conv2, act_type='relu')
+    pool2 = mx.sym.Pooling(data=relu2, pool_type='max', kernel=(2, 2),
+                           stride=(2, 2))
+    # first fully connected layer
+    flatten = mx.sym.flatten(data=pool2)
+    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=50)
+    relu3 = mx.sym.Activation(data=fc1, act_type='relu')
+    # second fully connected layer
+    fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=10)
+    # softmax loss
+    loss = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
+    return loss
+
+
+# Step 4: fit the model
+net = conv_net()
+model = mx.mod.Module(symbol=net, context=context)
+optimizer_params = {'learning_rate': args.lr * hvd.size(),
+                    'rescale_grad': 1.0 / args.batch_size}
+opt = mx.optimizer.create('sgd', **optimizer_params)
+
+# Horovod: wrap optimizer with DistributedOptimizer
+opt = hvd.DistributedOptimizer(opt)
+
+initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
+                             magnitude=2)
+model.bind(data_shapes=train_iter.provide_data,
+           label_shapes=train_iter.provide_label)
+model.init_params(initializer)
+
+# Horovod: fetch and broadcast parameters
+(arg_params, aux_params) = model.get_params()
+if arg_params is not None:
+    hvd.broadcast_parameters(arg_params, root_rank=0)
+if aux_params is not None:
+    hvd.broadcast_parameters(aux_params, root_rank=0)
+model.set_params(arg_params=arg_params, aux_params=aux_params)
+
+model.fit(train_iter,  # train data
+          kvstore=None,  # no kvstore
+          eval_data=val_iter,  # validation data
+          optimizer=opt,  # use SGD to train
+          eval_metric='acc',  # report accuracy during training
+          batch_end_callback=mx.callback.Speedometer(args.batch_size),
+          num_epoch=args.epochs)  # train for args.epochs passes over the data
+
+# Step 5: evaluate model accuracy
+acc = mx.metric.Accuracy()
+model.score(val_iter, acc)
+
+if hvd.rank() == 0:
+    print(acc)
+    assert acc.get()[1] > 0.96, ("Achieved accuracy (%f) is lower than "
+                                 "expected (0.96)" % acc.get()[1])
diff --git a/example/distributed_training-horovod/resnet50_imagenet.py b/example/distributed_training-horovod/resnet50_imagenet.py
new file mode 100644
index 0000000..9b99340
--- /dev/null
+++ b/example/distributed_training-horovod/resnet50_imagenet.py
@@ -0,0 +1,453 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import logging
+import math
+import os
+import time
+
+from gluoncv.model_zoo import get_model
+import horovod.mxnet as hvd
+import mxnet as mx
+import numpy as np
+from mxnet import autograd, gluon, lr_scheduler
+from mxnet.io import DataBatch, DataIter
+
+
+# Training settings
+parser = argparse.ArgumentParser(description='MXNet ImageNet Example',
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--use-rec', action='store_true', default=False,
+                    help='use image record iter for data input (default: False)')
+parser.add_argument('--data-nthreads', type=int, default=2,
+                    help='number of threads for data decoding (default: 2)')
+parser.add_argument('--rec-train', type=str, default='',
+                    help='the training data')
+parser.add_argument('--rec-train-idx', type=str, default='',
+                    help='the index of training data')
+parser.add_argument('--rec-val', type=str, default='',
+                    help='the validation data')
+parser.add_argument('--rec-val-idx', type=str, default='',
+                    help='the index of validation data')
+parser.add_argument('--batch-size', type=int, default=128,
+                    help='training batch size per device (default: 128)')
+parser.add_argument('--dtype', type=str, default='float32',
+                    help='data type for training (default: float32)')
+parser.add_argument('--num-epochs', type=int, default=90,
+                    help='number of training epochs (default: 90)')
+parser.add_argument('--lr', type=float, default=0.05,
+                    help='learning rate for a single GPU (default: 0.05)')
+parser.add_argument('--momentum', type=float, default=0.9,
+                    help='momentum value for optimizer (default: 0.9)')
+parser.add_argument('--wd', type=float, default=0.0001,
+                    help='weight decay rate (default: 0.0001)')
+parser.add_argument('--lr-mode', type=str, default='poly',
+                    help='learning rate scheduler mode. Options are step, \
+                    poly and cosine (default: poly)')
+parser.add_argument('--lr-decay', type=float, default=0.1,
+                    help='decay rate of learning rate (default: 0.1)')
+parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
+                    help='epochs at which the learning rate decays (default: 40,60)')
+parser.add_argument('--warmup-lr', type=float, default=0.0,
+                    help='starting warmup learning rate (default: 0.0)')
+parser.add_argument('--warmup-epochs', type=int, default=10,
+                    help='number of warmup epochs (default: 10)')
+parser.add_argument('--last-gamma', action='store_true', default=False,
+                    help='whether to init gamma of the last BN layer in \
+                    each bottleneck to 0 (default: False)')
+parser.add_argument('--model', type=str, default='resnet50_v1',
+                    help='type of model to use; see the GluonCV model zoo for options')
+parser.add_argument('--mode', type=str, default='module',
+                    help='mode in which to train the model. options are \
+                    module, gluon (default: module)')
+parser.add_argument('--use-pretrained', action='store_true', default=False,
+                    help='load pretrained model weights (default: False)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA training (default: False)')
+parser.add_argument('--eval-epoch', action='store_true', default=False,
+                    help='evaluate validation accuracy after each epoch \
+                    when training in module mode (default: False)')
+parser.add_argument('--eval-frequency', type=int, default=0,
+                    help='frequency of evaluating validation accuracy \
+                    when training with gluon mode (default: 0)')
+parser.add_argument('--log-interval', type=int, default=0,
+                    help='number of batches to wait before logging (default: 0)')
+parser.add_argument('--save-frequency', type=int, default=0,
+                    help='frequency of model saving (default: 0)')
+
+
+args = parser.parse_args()
+
+logging.basicConfig(level=logging.INFO)
+logging.info(args)
+
+# Horovod: initialize Horovod
+hvd.init()
+num_workers = hvd.size()
+rank = hvd.rank()
+local_rank = hvd.local_rank()
+
+num_classes = 1000
+num_training_samples = 1281167
+batch_size = args.batch_size
+epoch_size = \
+    int(math.ceil(int(num_training_samples // num_workers) / batch_size))
+
+if args.lr_mode == 'step':
+    lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
+    steps = [epoch_size * x for x in lr_decay_epoch]
+    lr_sched = lr_scheduler.MultiFactorScheduler(
+        step=steps,
+        factor=args.lr_decay,
+        base_lr=(args.lr * num_workers),
+        warmup_steps=(args.warmup_epochs * epoch_size),
+        warmup_begin_lr=args.warmup_lr
+    )
+elif args.lr_mode == 'poly':
+    lr_sched = lr_scheduler.PolyScheduler(
+        args.num_epochs * epoch_size,
+        base_lr=(args.lr * num_workers),
+        pwr=2,
+        warmup_steps=(args.warmup_epochs * epoch_size),
+        warmup_begin_lr=args.warmup_lr
+    )
+elif args.lr_mode == 'cosine':
+    lr_sched = lr_scheduler.CosineScheduler(
+        args.num_epochs * epoch_size,
+        base_lr=(args.lr * num_workers),
+        warmup_steps=(args.warmup_epochs * epoch_size),
+        warmup_begin_lr=args.warmup_lr
+    )
+else:
+    raise ValueError('Invalid lr mode')
+
+# Function for reading data from record file
+# For more details about data loading in MXNet, please refer to
+# https://mxnet.incubator.apache.org/tutorials/basic/data.html?highlight=imagerecorditer
+def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size,
+                 data_nthreads):
+    rec_train = os.path.expanduser(rec_train)
+    rec_train_idx = os.path.expanduser(rec_train_idx)
+    rec_val = os.path.expanduser(rec_val)
+    rec_val_idx = os.path.expanduser(rec_val_idx)
+    jitter_param = 0.4
+    lighting_param = 0.1
+    mean_rgb = [123.68, 116.779, 103.939]
+
+    def batch_fn(batch, ctx):
+        data = batch.data[0].as_in_context(ctx)
+        label = batch.label[0].as_in_context(ctx)
+        return data, label
+
+    train_data = mx.io.ImageRecordIter(
+        path_imgrec=rec_train,
+        path_imgidx=rec_train_idx,
+        preprocess_threads=data_nthreads,
+        shuffle=True,
+        batch_size=batch_size,
+        label_width=1,
+        data_shape=(3, 224, 224),
+        mean_r=mean_rgb[0],
+        mean_g=mean_rgb[1],
+        mean_b=mean_rgb[2],
+        rand_mirror=True,
+        rand_crop=False,
+        random_resized_crop=True,
+        max_aspect_ratio=4. / 3.,
+        min_aspect_ratio=3. / 4.,
+        max_random_area=1,
+        min_random_area=0.08,
+        verbose=False,
+        brightness=jitter_param,
+        saturation=jitter_param,
+        contrast=jitter_param,
+        pca_noise=lighting_param,
+        num_parts=num_workers,
+        part_index=rank,
+        device_id=local_rank
+    )
+    # Each worker uses the full validation data to make results easy to monitor
+    val_data = mx.io.ImageRecordIter(
+        path_imgrec=rec_val,
+        path_imgidx=rec_val_idx,
+        preprocess_threads=data_nthreads,
+        shuffle=False,
+        batch_size=batch_size,
+        resize=256,
+        label_width=1,
+        rand_crop=False,
+        rand_mirror=False,
+        data_shape=(3, 224, 224),
+        mean_r=mean_rgb[0],
+        mean_g=mean_rgb[1],
+        mean_b=mean_rgb[2],
+        device_id=local_rank
+    )
+
+    return train_data, val_data, batch_fn
+
+# Create data iterator for synthetic data
+class SyntheticDataIter(DataIter):
+    def __init__(self, num_classes, data_shape, max_iter, dtype, ctx):
+        self.batch_size = data_shape[0]
+        self.cur_iter = 0
+        self.max_iter = max_iter
+        self.dtype = dtype
+        label = np.random.randint(0, num_classes, [self.batch_size, ])
+        data = np.random.uniform(-1, 1, data_shape)
+        self.data = mx.nd.array(data, dtype=self.dtype,
+                                ctx=ctx)
+        self.label = mx.nd.array(label, dtype=self.dtype,
+                                 ctx=ctx)
+
+    def __iter__(self):
+        return self
+
+    @property
+    def provide_data(self):
+        return [mx.io.DataDesc('data', self.data.shape, self.dtype)]
+
+    @property
+    def provide_label(self):
+        return [mx.io.DataDesc('softmax_label',
+                               (self.batch_size,), self.dtype)]
+
+    def next(self):
+        self.cur_iter += 1
+        if self.cur_iter <= self.max_iter:
+            return DataBatch(data=(self.data,),
+                             label=(self.label,),
+                             pad=0,
+                             index=None,
+                             provide_data=self.provide_data,
+                             provide_label=self.provide_label)
+        else:
+            raise StopIteration
+
+    def __next__(self):
+        return self.next()
+
+    def reset(self):
+        self.cur_iter = 0
+
+# Horovod: pin GPU to local rank
+context = mx.cpu(local_rank) if args.no_cuda else mx.gpu(local_rank)
+
+if args.use_rec:
+    # Fetch training and validation data if present
+    train_data, val_data, batch_fn = get_data_rec(args.rec_train,
+                                                  args.rec_train_idx,
+                                                  args.rec_val,
+                                                  args.rec_val_idx,
+                                                  batch_size,
+                                                  args.data_nthreads)
+else:
+    # Otherwise use synthetic data
+    image_shape = (3, 224, 224)
+    data_shape = (batch_size,) + image_shape
+    train_data = SyntheticDataIter(num_classes, data_shape, epoch_size,
+                                   np.float32, context)
+    val_data = None
+
+
+# Get model from GluonCV model zoo
+# https://gluon-cv.mxnet.io/model_zoo/index.html
+kwargs = {'ctx': context,
+          'pretrained': args.use_pretrained,
+          'classes': num_classes}
+if args.last_gamma:
+    kwargs['last_gamma'] = True
+net = get_model(args.model, **kwargs)
+net.cast(args.dtype)
+
+# Create initializer
+initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
+                             magnitude=2)
+
+# Create optimizer
+optimizer_params = {'wd': args.wd,
+                    'momentum': args.momentum,
+                    'rescale_grad': 1.0 / batch_size,
+                    'lr_scheduler': lr_sched}
+if args.dtype == 'float16':
+    optimizer_params['multi_precision'] = True
+opt = mx.optimizer.create('sgd', **optimizer_params)
+
+# Horovod: wrap optimizer with DistributedOptimizer
+opt = hvd.DistributedOptimizer(opt)
+
+
+def train_gluon():
+    def evaluate(epoch):
+        if not args.use_rec:
+            return
+
+        val_data.reset()
+        acc_top1 = mx.metric.Accuracy()
+        acc_top5 = mx.metric.TopKAccuracy(5)
+        for _, batch in enumerate(val_data):
+            data, label = batch_fn(batch, context)
+            output = net(data.astype(args.dtype, copy=False))
+            acc_top1.update([label], [output])
+            acc_top5.update([label], [output])
+
+        top1_name, top1_acc = acc_top1.get()
+        top5_name, top5_acc = acc_top5.get()
+        logging.info('Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f',
+                     epoch, rank, top1_name, top1_acc, top5_name, top5_acc)
+
+    # Hybridize and initialize model
+    net.hybridize()
+    net.initialize(initializer, ctx=context)
+
+    # Horovod: fetch and broadcast parameters
+    params = net.collect_params()
+    if params is not None:
+        hvd.broadcast_parameters(params, root_rank=0)
+
+    # Create trainer, loss function and train metric
+    trainer = gluon.Trainer(params, opt, kvstore=None)
+    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
+    metric = mx.metric.Accuracy()
+
+    # Train model
+    for epoch in range(args.num_epochs):
+        tic = time.time()
+        if args.use_rec:
+            train_data.reset()
+        metric.reset()
+
+        btic = time.time()
+        for nbatch, batch in enumerate(train_data, start=1):
+            data, label = batch_fn(batch, context)
+            with autograd.record():
+                output = net(data.astype(args.dtype, copy=False))
+                loss = loss_fn(output, label)
+            loss.backward()
+            trainer.step(batch_size)
+
+            metric.update([label], [output])
+            if args.log_interval and nbatch % args.log_interval == 0:
+                name, acc = metric.get()
+                logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f',
+                             epoch, rank, nbatch, name, acc, trainer.learning_rate)
+                if rank == 0:
+                    batch_speed = num_workers * batch_size * args.log_interval / (time.time() - btic)
+                    logging.info('Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec',
+                                 epoch, nbatch, batch_speed)
+                btic = time.time()
+
+        # Report metrics
+        elapsed = time.time() - tic
+        _, acc = metric.get()
+        logging.info('Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-accuracy=%f',
+                     epoch, rank, nbatch, elapsed, acc)
+        if rank == 0:
+            epoch_speed = num_workers * batch_size * nbatch / elapsed
+            logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch, epoch_speed)
+
+        # Evaluate performance
+        if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0:
+            evaluate(epoch)
+
+        # Save model
+        if args.save_frequency and (epoch + 1) % args.save_frequency == 0:
+            net.export('%s-%d' % (args.model, rank), epoch=epoch)
+
+    # Evaluate performance at the end of training
+    evaluate(epoch)
+
+
+def train_module():
+    # Create input symbol
+    data = mx.sym.var('data')
+    if args.dtype == 'float16':
+        data = mx.sym.Cast(data=data, dtype=np.float16)
+        net.cast(np.float16)
+
+    # Create output symbol
+    out = net(data)
+    if args.dtype == 'float16':
+        out = mx.sym.Cast(data=out, dtype=np.float32)
+    softmax = mx.sym.SoftmaxOutput(out, name='softmax')
+
+    # Create model
+    mod = mx.mod.Module(softmax, context=context)
+
+    # Initialize parameters
+    if args.use_pretrained:
+        arg_params = {}
+        for x in net.collect_params().values():
+            x.reset_ctx(mx.cpu())
+            arg_params[x.name] = x.data()
+    else:
+        arg_params = None
+    aux_params = None
+    mod.bind(data_shapes=train_data.provide_data,
+             label_shapes=train_data.provide_label)
+    mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params)
+
+    # Horovod: fetch and broadcast parameters
+    (arg_params, aux_params) = mod.get_params()
+    if arg_params is not None:
+        hvd.broadcast_parameters(arg_params, root_rank=0)
+    if aux_params is not None:
+        hvd.broadcast_parameters(aux_params, root_rank=0)
+    mod.set_params(arg_params=arg_params, aux_params=aux_params)
+
+    # Setup validation data and callback during training
+    eval_data = None
+    if args.eval_epoch:
+        eval_data = val_data
+    batch_callback = None
+    if args.log_interval > 0 and rank == 0:
+        batch_callback = mx.callback.Speedometer(batch_size * num_workers,
+                                                 args.log_interval)
+
+    epoch_callback = None
+    if args.save_frequency > 0:
+        epoch_callback = mx.callback.do_checkpoint(
+            '%s-%d' % (args.model, rank),
+            period=args.save_frequency)
+
+    # Train model
+    mod.fit(train_data,
+            eval_data=eval_data,
+            num_epoch=args.num_epochs,
+            kvstore=None,
+            batch_end_callback=batch_callback,
+            epoch_end_callback=epoch_callback,
+            optimizer=opt)
+
+    # Evaluate performance if not using synthetic data
+    if args.use_rec:
+        acc_top1 = mx.metric.Accuracy()
+        acc_top5 = mx.metric.TopKAccuracy(5)
+        res = mod.score(val_data, [acc_top1, acc_top5])
+        for name, val in res:
+            logging.info('Epoch[%d] Rank[%d] Validation-%s=%f',
+                         args.num_epochs - 1, rank, name, val)
+
+
+if __name__ == '__main__':
+    if args.mode == 'module':
+        train_module()
+    elif args.mode == 'gluon':
+        train_gluon()
+    else:
+        raise ValueError('Invalid training mode.')