File path: dev-support/examples/mnist-pytorch/DDP/
@@ -0,0 +1,192 @@
+from __future__ import print_function
+from submarine import ModelsClient
+import argparse
+import os
+from tensorboardX import SummaryWriter
+from torchvision import datasets, transforms
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+# Process topology taken from the environment. When the variables are absent
+# the script behaves as a single process: world size 1, rank 0.
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+rank = int(os.getenv('RANK', 0))
+print('WORLD={} , RANK={}'.format(WORLD_SIZE, rank))
+class Net(nn.Module):
+    """Two-convolution, two-linear-layer classifier emitting log-probabilities.
+
+    The layer sizes imply single-channel 28x28 inputs. Each of the two
+    unpadded 5x5 convolutions is followed by 2x2 max-pooling, shrinking
+    28 -> 24 -> 12 -> 8 -> 4. That yields the 4*4*50 features fed to ``fc1``.
+    The output has 10 columns of log-softmax scores.
+    """
+    def __init__(self):
+        super(Net, self).__init__()
+        # Layers are created in a fixed order so seeded weight init is unchanged.
+        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
+        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
+        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
+        self.fc2 = nn.Linear(in_features=500, out_features=10)
+    def forward(self, x):
+        # conv -> ReLU -> 2x2 max-pool, applied once per convolutional stage.
+        for stage in (self.conv1, self.conv2):
+            x = F.max_pool2d(F.relu(stage(x)), 2, 2)
+        hidden = F.relu(self.fc1(x.view(-1, 4 * 4 * 50)))
+        return F.log_softmax(self.fc2(hidden), dim=1)
+def train(args, model, device, train_loader, optimizer, epoch, writer, periscope):
+    """Run one training epoch over ``train_loader`` and periodically log the loss.
+
+    Args:
+        args: parsed CLI namespace; only ``args.log_interval`` is read here.
+        model: module whose output is fed to ``F.nll_loss`` (log-probabilities).
+        device: ``torch.device`` each batch is moved to before the forward pass.
+        train_loader: iterable of ``(data, target)`` batches supporting ``len()``
+            and exposing ``.dataset`` (and optionally ``.sampler``).
+        optimizer: optimizer stepping the model's parameters.
+        epoch: current epoch number, used for logging and for the sampler epoch.
+        writer: object with ``add_scalar(tag, value, step)``.
+        periscope: object with ``log_metric(name, value, step)``.
+    """
+    model.train()
+    # A DistributedSampler derives its shuffle order from its epoch. Without
+    # set_epoch every epoch would see the same order. Samplers that lack
+    # set_epoch (or loaders without a sampler) are left alone.
+    sampler = getattr(train_loader, 'sampler', None)
+    if hasattr(sampler, 'set_epoch'):
+        sampler.set_epoch(epoch)
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            loss_value = loss.item()
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss_value))
+            # Global step, so points from successive epochs do not overlap.
+            niter = epoch * len(train_loader) + batch_idx
+            writer.add_scalar('loss', loss_value, niter)
+            periscope.log_metric('loss', loss_value, niter)
+def test(args, model, device, test_loader, writer, epoch, periscope):
+    """Evaluate ``model`` on ``test_loader`` and log the results for ``epoch``.
+
+    Accuracy is printed as ``accuracy=<value>`` and logged to both ``writer``
+    and ``periscope``. The mean per-sample NLL loss is logged to ``writer``
+    under the ``test_loss`` tag.
+
+    Args:
+        args: parsed CLI namespace (not read here; kept for call-site compatibility).
+        model: module whose output is fed to ``F.nll_loss`` (log-probabilities).
+        device: ``torch.device`` each batch is moved to before the forward pass.
+        test_loader: iterable of ``(data, target)`` batches exposing ``.dataset``.
+        writer: object with ``add_scalar(tag, value, step)``.
+        epoch: step value attached to the logged metrics.
+        periscope: object with ``log_metric(name, value, step)``.
+    """
+    model.eval()
+    test_loss = 0.0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            # Sum per batch so dividing by the dataset size gives a per-sample mean.
+            test_loss += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+    num_samples = len(test_loader.dataset)
+    test_loss /= num_samples
+    accuracy = float(correct) / num_samples
+    print('\naccuracy={:.4f}\n'.format(accuracy))
+    writer.add_scalar('accuracy', accuracy, epoch)
+    writer.add_scalar('test_loss', test_loss, epoch)
+    periscope.log_metric('accuracy', accuracy, epoch)
+# This helper is named ``test``; stop pytest from collecting it as a test case
+# if the module is ever imported by a test suite.
+test.__test__ = False
+def should_distribute():
+    """Report whether a process group should be set up for this run.
+
+    True only when torch.distributed is available in this build and the
+    module-level ``WORLD_SIZE`` is greater than one.
+    """
+    if not dist.is_available():
+        return False
+    return WORLD_SIZE > 1
+def is_distributed():
+    """Report whether a torch.distributed process group is currently initialized."""
+    return dist.is_initialized() if dist.is_available() else False
+if __name__ == '__main__':
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs', type=int, default=5, metavar='N',
+                        help='number of epochs to train (default: 5)')
+    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                        help='learning rate (default: 0.01)')
+    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
+                        help='SGD momentum (default: 0.5)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+                        help='how many batches to wait before logging training status')
+    parser.add_argument('--save-model', action='store_true', default=False,
+                        help='For Saving the current Model')
+    parser.add_argument('--dir', default='logs', metavar='L',
+                        help='directory where summary logs are stored')
+    if dist.is_available():
+        parser.add_argument('--backend', type=str, help='Distributed backend',
+                            choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
+                            default=dist.Backend.GLOO)
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    if use_cuda:
+        print('Using CUDA')
+    else :
+        print('Not Using CUDA')
+    writer = SummaryWriter(args.dir)
+    torch.manual_seed(args.seed)
+    device = torch.device("cuda" if use_cuda else "cpu")
+    if should_distribute():
+        print('Using distributed PyTorch with {} backend'.format(args.backend))
+        dist.init_process_group(
+            backend=args.backend,
+            world_size=WORLD_SIZE,
+            rank=rank)
+    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
+    train_dataset = datasets.FashionMNIST('../data', train=True, download=True,
+                       transform=transforms.Compose([
+                           transforms.ToTensor(),
+                           transforms.Normalize((0.1307,), (0.3081,))
+                       ]))
+    train_sampler =
+        train_dataset,
+        num_replicas = WORLD_SIZE,
+        rank=rank
+    )
+    train_loader =
+        dataset = train_dataset,
+        batch_size = args.batch_size, 
+        shuffle = False,
+        sampler = train_sampler,
+        **kwargs)
+    test_loader =
+        datasets.FashionMNIST('../data', train=False, transform=transforms.Compose([
+                           transforms.ToTensor(),
+                           transforms.Normalize((0.1307,), (0.3081,))
+                       ])),
+        batch_size=args.test_batch_size, shuffle=False, **kwargs)
+    model = Net().to(device)
+    if is_distributed():
+        Distributor = nn.parallel.DistributedDataParallel if use_cuda \
+            else nn.parallel.DistributedDataParallelCPU
+        model = Distributor(model)
+    optimizer = optim.SGD(model.parameters(),, momentum=args.momentum)
+    periscope = ModelsClient()
+    with periscope.start() as run:
+        periscope.log_param("learning_rate",
+        periscope.log_param("batch_size", args.batch_size)
+        for epoch in range(1, args.epochs + 1):
+        # for epoch in range(1, 6):

