Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/10/21 18:39:32 UTC

[incubator-mxnet] branch master updated: Extending the DCGAN example implemented by gluon API to provide a more straight-forward evaluation on the generated image (#12790)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0137483  Extending the DCGAN example implemented by gluon API to provide a more straight-forward evaluation on the generated image (#12790)
0137483 is described below

commit 013748300a1182f30d80442c7f3b3164b49af25a
Author: pengxin99 <yc...@126.com>
AuthorDate: Mon Oct 22 02:39:18 2018 +0800

    Extending the DCGAN example implemented by gluon API to provide a more straight-forward evaluation on the generated image (#12790)
    
    * add inception_score to metric dcgan model
    
    * Update README.md
    
    * add two pic
    
    * updata readme
    
    * updata
    
    * Update README.md
    
    * add license
    
    * refine1
    
    * refine2
    
    * refine3
    
    * fix review comments
    
    * Update README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * Update example/gluon/DCGAN/README.md
    
    * modify sn_gan file links to DCGAN
    
    * update pic links to web-data
    
    * update the pic path of readme.md
    
    * rm folder pic/, and related links update to https://github.com/dmlc/web-data/mxnet/example/gluon/DCGAN/
    
    * Update README.md
---
 example/gluon/DCGAN/README.md          |  52 +++++
 example/gluon/DCGAN/__init__.py        |   0
 example/gluon/DCGAN/dcgan.py           | 340 +++++++++++++++++++++++++++++++++
 example/gluon/DCGAN/inception_score.py | 110 +++++++++++
 example/gluon/dcgan.py                 | 236 -----------------------
 example/gluon/sn_gan/data.py           |   2 +-
 example/gluon/sn_gan/model.py          |   2 +-
 example/gluon/sn_gan/train.py          |   2 +-
 example/gluon/sn_gan/utils.py          |   2 +-
 9 files changed, 506 insertions(+), 240 deletions(-)

diff --git a/example/gluon/DCGAN/README.md b/example/gluon/DCGAN/README.md
new file mode 100644
index 0000000..5aacd78
--- /dev/null
+++ b/example/gluon/DCGAN/README.md
@@ -0,0 +1,52 @@
+# DCGAN in MXNet
+
+A [Deep Convolutional Generative Adversarial Networks (DCGAN)](https://arxiv.org/abs/1511.06434) implementation using the Apache MXNet Gluon API.
+This implementation uses [inception_score](https://github.com/openai/improved-gan) to evaluate the model.
+
+You can use this reference implementation on the MNIST and CIFAR-10 datasets.
+
+
+#### Generated image output examples from the CIFAR-10 dataset
+![Generated image output examples from the CIFAR-10 dataset](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/fake_img_iter_13900.png)
+
+#### Generated image output examples from the MNIST dataset
+![Generated image output examples from the MNIST dataset](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/fake_img_iter_21700.png)
+
+#### inception_score on CPU and GPU (real images score around 3.3)
+CPU & GPU
+
+![inception score with CPU](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/inception_score_cifar10_cpu.png)
+![inception score with GPU](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/inception_score_cifar10.png)
+
+## Quick start
+Run the following command to list the options you can set:
+```bash
+python dcgan.py -h
+```
+    
+
+    optional arguments:
+      -h, --help            show this help message and exit
+      --dataset DATASET     dataset to use. options are cifar10 and mnist.
+      --batch-size BATCH_SIZE  input batch size, default is 64
+      --nz NZ               size of the latent z vector, default is 100
+      --ngf NGF             the channel of each generator filter layer, default is 64.
+      --ndf NDF             the channel of each discriminator filter layer, default is 64.
+      --nepoch NEPOCH       number of epochs to train for, default is 25.
+      --niter NITER         save generated images and inception_score per niter iters, default is 100.
+      --lr LR               learning rate, default=0.0002
+      --beta1 BETA1         beta1 for adam. default=0.5
+      --cuda                enables cuda
+      --netG NETG           path to netG (to continue training)
+      --netD NETD           path to netD (to continue training)
+      --outf OUTF           folder to output images and model checkpoints
+      --check-point CHECK_POINT
+                            save results at each epoch or not
+      --inception_score INCEPTION_SCORE
+                            To record the inception_score, default is True.
+
+
+Run the following command to train a DCGAN model with the default configuration on the CIFAR-10 dataset and record the `inception_score` metric:
+```bash
+python dcgan.py
+```
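A note on boolean options such as `--inception_score`, which take a value on the command line: argparse passes the raw string to the `type` callable, and `bool("False")` is `True` because any non-empty string is truthy. The following is a minimal sketch of one way to parse such values. The helper name `str2bool` and the function `build_parser` are hypothetical and are not part of this commit.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false spellings to a bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    """Build a parser with one boolean option, for illustration only."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--inception_score", type=str2bool, default=True)
    return parser


if __name__ == "__main__":
    # The string "False" is converted to the boolean False.
    print(build_parser().parse_args(["--inception_score", "False"]))
```

An alternative design is a pair of `store_true`/`store_false` flags, which avoids parsing strings altogether.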
diff --git a/example/gluon/DCGAN/__init__.py b/example/gluon/DCGAN/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/example/gluon/DCGAN/dcgan.py b/example/gluon/DCGAN/dcgan.py
new file mode 100644
index 0000000..970c35d
--- /dev/null
+++ b/example/gluon/DCGAN/dcgan.py
@@ -0,0 +1,340 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import matplotlib as mpl
+mpl.use('Agg')
+from matplotlib import pyplot as plt
+
+import argparse
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet import autograd
+import numpy as np
+import logging
+from datetime import datetime
+import os
+import time
+
+from inception_score import get_inception_score
+
+
+def fill_buf(buf, i, img, shape):
+    """
+    Copy one generated image into its tile of the output grid so that the grid can be saved as a single picture.
+    :param buf: the image grid (buffer) to fill
+    :param i: index of the image within the batch
+    :param img: one image produced by the generator
+    :param shape: the shape of each image
+    :return: None; buf is modified in place
+    """
+    n = buf.shape[0]//shape[1]
+    m = buf.shape[1]//shape[0]
+
+    sx = (i%m)*shape[0]
+    sy = (i//m)*shape[1]
+    buf[sy:sy+shape[1], sx:sx+shape[0], :] = img
+    return None
+
+
+def visual(title, X, name):
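+    """
+    Arrange a batch of images into a grid and save it as a picture.
+    :param title: title of the figure
+    :param X: images to visualize, a 4-D array of shape (num, channels, height, width)
+    :param name: file name of the saved picture
+    :return: None
+    """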
+    assert len(X.shape) == 4
+    X = X.transpose((0, 2, 3, 1))
+    X = np.clip((X - np.min(X))*(255.0/(np.max(X) - np.min(X))), 0, 255).astype(np.uint8)
+    n = np.ceil(np.sqrt(X.shape[0]))
+    buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), dtype=np.uint8)
+    for i, img in enumerate(X):
+        fill_buf(buff, i, img, X.shape[1:3])
+    buff = buff[:, :, ::-1]
+    plt.imshow(buff)
+    plt.title(title)
+    plt.savefig(name)
+
+
+# command-line arguments
+parser = argparse.ArgumentParser(description='Train a DCGAN model for image generation '
+                                             'and then evaluate the result with inception_score.')
+parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to use. options are cifar10 and mnist.')
+parser.add_argument('--batch-size', type=int, default=64, help='input batch size, default is 64')
+parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector, default is 100')
+parser.add_argument('--ngf', type=int, default=64, help='the channel of each generator filter layer, default is 64.')
+parser.add_argument('--ndf', type=int, default=64, help='the channel of each discriminator filter layer, default is 64.')
+parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for, default is 25.')
+parser.add_argument('--niter', type=int, default=100, help='save generated images and inception_score per niter iters, default is 100.')
+parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
+parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
+parser.add_argument('--cuda', action='store_true', help='enables cuda')
+parser.add_argument('--netG', default='', help="path to netG (to continue training)")
+parser.add_argument('--netD', default='', help="path to netD (to continue training)")
+parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints')
+parser.add_argument('--check-point', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help="save results at each epoch or not")
+parser.add_argument('--inception_score', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='To record the inception_score, default is True.')
+
+opt = parser.parse_args()
+print(opt)
+
+logging.basicConfig(level=logging.DEBUG)
+
+nz = int(opt.nz)
+ngf = int(opt.ngf)
+ndf = int(opt.ndf)
+niter = opt.niter
+nc = 3
+if opt.cuda:
+    ctx = mx.gpu(0)
+else:
+    ctx = mx.cpu()
+batch_size = opt.batch_size
+check_point = bool(opt.check_point)
+outf = opt.outf
+dataset = opt.dataset
+
+if not os.path.exists(outf):
+    os.makedirs(outf)
+
+
+def transformer(data, label):
+    # resize to 64x64
+    data = mx.image.imresize(data, 64, 64)
+    # transpose from (64, 64, 3) to (3, 64, 64)
+    data = mx.nd.transpose(data, (2, 0, 1))
+    # normalize to [-1, 1]
+    data = data.astype(np.float32)/128 - 1
+    # if image is greyscale, repeat 3 times to get RGB image.
+    if data.shape[0] == 1:
+        data = mx.nd.tile(data, (3, 1, 1))
+    return data, label
+
+
+# build data loaders that yield batch_size samples per batch
+def get_dataset(dataset):
+    # mnist
+    if dataset == "mnist":
+        train_data = gluon.data.DataLoader(
+            gluon.data.vision.MNIST('./data', train=True, transform=transformer),
+            batch_size, shuffle=True, last_batch='discard')
+
+        val_data = gluon.data.DataLoader(
+            gluon.data.vision.MNIST('./data', train=False, transform=transformer),
+            batch_size, shuffle=False)
+    # cifar10
+    elif dataset == "cifar10":
+        train_data = gluon.data.DataLoader(
+            gluon.data.vision.CIFAR10('./data', train=True, transform=transformer),
+            batch_size, shuffle=True, last_batch='discard')
+
+        val_data = gluon.data.DataLoader(
+            gluon.data.vision.CIFAR10('./data', train=False, transform=transformer),
+            batch_size, shuffle=False)
+
+    return train_data, val_data
+
+
+def get_netG():
+    # build the generator
+    netG = nn.Sequential()
+    with netG.name_scope():
+        # input is Z, going into a convolution
+        netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf*8) x 4 x 4
+        netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf*4) x 8 x 8
+        netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf*2) x 16 x 16
+        netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf) x 32 x 32
+        netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
+        netG.add(nn.Activation('tanh'))
+        # state size. (nc) x 64 x 64
+
+    return netG
+
+
+def get_netD():
+    # build the discriminator
+    netD = nn.Sequential()
+    with netD.name_scope():
+        # input is (nc) x 64 x 64
+        netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf) x 32 x 32
+        netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
+        netD.add(nn.BatchNorm())
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf*2) x 16 x 16
+        netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
+        netD.add(nn.BatchNorm())
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf*4) x 8 x 8
+        netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
+        netD.add(nn.BatchNorm())
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf*8) x 4 x 4
+        netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))
+        # state size. 2 x 1 x 1
+
+    return netD
+
+
+def get_configurations(netG, netD):
+    # loss
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+
+    # initialize the generator and the discriminator
+    netG.initialize(mx.init.Normal(0.02), ctx=ctx)
+    netD.initialize(mx.init.Normal(0.02), ctx=ctx)
+
+    # trainer for the generator and the discriminator
+    trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
+    trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
+
+    return loss, trainerG, trainerD
+
+
+def ins_save(inception_score):
+    # draw the inception_score curve
+    length = len(inception_score)
+    x = np.arange(0, length)
+    plt.figure(figsize=(8.0, 6.0))
+    plt.plot(x, inception_score)
+    plt.xlabel("iter/100")
+    plt.ylabel("inception_score")
+    plt.savefig("inception_score.png")
+
+
+# main function
+def main():
+    logging.info('Preparing the %s dataset and the networks', dataset)
+    # get the dataset and the network configuration
+    train_data, val_data = get_dataset(dataset)
+    netG = get_netG()
+    netD = get_netD()
+    loss, trainerG, trainerD = get_configurations(netG, netD)
+
+    # set labels
+    real_label = mx.nd.ones((opt.batch_size,), ctx=ctx)
+    fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx)
+
+    metric = mx.metric.Accuracy()
+    print('Training... ')
+    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
+
+    iter = 0
+
+    # records used to evaluate the network
+    loss_d = []
+    loss_g = []
+    inception_score = []
+
+    for epoch in range(opt.nepoch):
+        tic = time.time()
+        btic = time.time()
+        for data, _ in train_data:
+            ############################
+            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
+            ###########################
+            # train with real_t
+            data = data.as_in_context(ctx)
+            noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx)
+
+            with autograd.record():
+                output = netD(data)
+                # reshape output from (opt.batch_size, 2, 1, 1) to (opt.batch_size, 2)
+                output = output.reshape((opt.batch_size, 2))
+                errD_real = loss(output, real_label)
+
+            metric.update([real_label, ], [output, ])
+
+            with autograd.record():
+                fake = netG(noise)
+                output = netD(fake.detach())
+                output = output.reshape((opt.batch_size, 2))
+                errD_fake = loss(output, fake_label)
+                errD = errD_real + errD_fake
+
+            errD.backward()
+            metric.update([fake_label,], [output,])
+
+            trainerD.step(opt.batch_size)
+
+            ############################
+            # (2) Update G network: maximize log(D(G(z)))
+            ###########################
+            with autograd.record():
+                output = netD(fake)
+                output = output.reshape((-1, 2))
+                errG = loss(output, real_label)
+
+            errG.backward()
+
+            trainerG.step(opt.batch_size)
+
+            name, acc = metric.get()
+            logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
+                         % (mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iter, epoch))
+            if iter % niter == 0:
+                visual('gout', fake.asnumpy(), name=os.path.join(outf, 'fake_img_iter_%d.png' % iter))
+                visual('data', data.asnumpy(), name=os.path.join(outf, 'real_img_iter_%d.png' % iter))
+                # record the metric data
+                loss_d.append(errD)
+                loss_g.append(errG)
+                if opt.inception_score:
+                    score, _ = get_inception_score(fake)
+                    inception_score.append(score)
+
+            iter = iter + 1
+            btic = time.time()
+
+        name, acc = metric.get()
+        metric.reset()
+        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
+        logging.info('time: %f' % (time.time() - tic))
+
+        # save check_point
+        if check_point:
+            netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' %epoch))
+            netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
+
+    # save the final parameters
+    netG.save_parameters(os.path.join(outf, 'generator.params'))
+    netD.save_parameters(os.path.join(outf, 'discriminator.params'))
+
+    # plot the inception_score curve and save it as a picture
+    if opt.inception_score:
+        ins_save(inception_score)
+
+
+if __name__ == '__main__':
+    if opt.inception_score:
+        print("Using inception_score to evaluate this DCGAN model; the result is saved as a picture named \"inception_score.png\".")
+    main()
+
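The "state size" comments in `get_netG` and `get_netD` can be checked with the usual output-size formulas for transposed and ordinary convolutions (no dilation, no output padding). This is a minimal sketch in plain Python; the helper names `deconv_out`, `conv_out`, and `trace` are hypothetical, and the `(kernel, stride, pad)` tuples are copied from the layer definitions in this file.

```python
def deconv_out(size, kernel, stride, pad):
    """Spatial output size of a transposed convolution."""
    return (size - 1) * stride - 2 * pad + kernel


def conv_out(size, kernel, stride, pad):
    """Spatial output size of an ordinary convolution."""
    return (size + 2 * pad - kernel) // stride + 1


# (kernel, stride, pad) for each layer, in order
GEN_LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
DIS_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace(size, layers, step):
    """Return the spatial size before the first layer and after every layer."""
    sizes = [size]
    for kernel, stride, pad in layers:
        size = step(size, kernel, stride, pad)
        sizes.append(size)
    return sizes


if __name__ == "__main__":
    print(trace(1, GEN_LAYERS, deconv_out))   # generator: 1 -> 64
    print(trace(64, DIS_LAYERS, conv_out))    # discriminator: 64 -> 1
```

The generator grows a 1 x 1 latent input to 64 x 64, and the discriminator reduces a 64 x 64 image back to 1 x 1, which is why `transformer` resizes every input image to 64 x 64.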
diff --git a/example/gluon/DCGAN/inception_score.py b/example/gluon/DCGAN/inception_score.py
new file mode 100644
index 0000000..e23513f
--- /dev/null
+++ b/example/gluon/DCGAN/inception_score.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from mxnet.gluon.model_zoo import vision as models
+import mxnet as mx
+from mxnet import nd
+import numpy as np
+import math
+import sys
+
+import cv2
+
+
+inception_model = None
+
+
+def get_inception_score(images, splits=10):
+    """
+    Compute the inception_score of a batch of images.
+        The predictions are divided into 'splits' parts and the inception_score of each part is calculated separately;
+        the mean and std of these per-part scores are returned.
+    :param images: images (num x c x w x h) whose inception_score is to be calculated.
+    :param splits: number of parts to divide the predictions into
+    :return: mean and std of inception_score
+    """
+    assert (images.shape[1] == 3)
+
+    # load inception model
+    if inception_model is None:
+        _init_inception()
+
+    # resize images to the 299 x 299 input size used here for the Inception v3 model
+    if images.shape[2] != 299:
+        images = resize(images, 299, 299)
+
+    preds = []
+    bs = 4
+    n_batches = int(math.ceil(float(images.shape[0])/float(bs)))
+
+    # get the inception model's prediction for each image
+    for i in range(n_batches):
+        sys.stdout.write(".")
+        sys.stdout.flush()
+        inps = images[(i * bs):min((i + 1) * bs, len(images))]
+        # inps size. bs x 3 x 299 x 299
+        pred = nd.softmax(inception_model(inps))
+        # pred size. bs x 1000
+        preds.append(pred.asnumpy())
+
+    # list to array
+    preds = np.concatenate(preds, 0)
+    scores = []
+
+    # calculate the inception_score of each split
+    for i in range(splits):
+        # extract per split image pred
+        part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
+        kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
+        kl = np.mean(np.sum(kl, 1))
+        scores.append(np.exp(kl))
+
+    return np.mean(scores), np.std(scores)
+
+
+def _init_inception():
+    global inception_model
+    inception_model = models.inception_v3(pretrained=True)
+    print("Successfully loaded the pretrained inception_v3 model.")
+
+
+def resize(images, w, h):
+    nums = images.shape[0]
+    res = nd.random.uniform(0, 255, (nums, 3, w, h))
+    for i in range(nums):
+        img = images[i, :, :, :]
+        img = mx.nd.transpose(img, (1, 2, 0))
+        # Use 'cv2.resize()' instead of 'mx.image.imresize()' because the operator _cvimresize is not implemented for GPU.
+        # img = mx.image.imresize(img, w, h)
+        img = cv2.resize(img.asnumpy(), (h, w))
+        img = nd.array(img)
+        img = mx.nd.transpose(img, (2, 0, 1))
+        res[i, :, :, :] = img
+
+    return res
+
+
+if __name__ == '__main__':
+    if inception_model is None:
+        _init_inception()
+    # dummy data
+    images = nd.random.uniform(0, 255, (64, 3, 64, 64))
+    print(images.shape[0])
+    # resize(images,299,299)
+
+    score = get_inception_score(images)
+    print(score)
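For one split, the quantity computed in `get_inception_score` is exp(mean over images of KL(p(y|x) || p(y))), where p(y) is the mean prediction over the split. The following is a minimal plain-Python sketch of that formula, not the NumPy code from this commit. The function name `inception_score_of` is hypothetical, and skipping zero probabilities (treating 0 * log 0 as 0) is an assumption of this sketch.

```python
import math


def inception_score_of(preds):
    """Return exp(mean KL(p(y|x) || p(y))) for a list of probability rows."""
    n = len(preds)
    classes = len(preds[0])
    # marginal class distribution p(y): the mean of all prediction rows
    marginal = [sum(row[k] for row in preds) / n for k in range(classes)]
    total_kl = 0.0
    for row in preds:
        for p, q in zip(row, marginal):
            if p > 0.0:  # assumption: 0 * log(0) is treated as 0
                total_kl += p * math.log(p / q)
    return math.exp(total_kl / n)


if __name__ == "__main__":
    uniform = [[0.25, 0.25, 0.25, 0.25] for _ in range(8)]
    one_hot = [[1.0 if k == i else 0.0 for k in range(4)] for i in range(4)]
    print(inception_score_of(uniform))  # identical predictions: lowest score
    print(inception_score_of(one_hot))  # confident and diverse: highest score
```

The score ranges from 1, when every image gets the same prediction, up to the number of classes, when each prediction is confident and the classes are covered evenly.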
diff --git a/example/gluon/dcgan.py b/example/gluon/dcgan.py
deleted file mode 100644
index 8ac9c52..0000000
--- a/example/gluon/dcgan.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import matplotlib as mpl
-mpl.use('Agg')
-from matplotlib import pyplot as plt
-
-import argparse
-import mxnet as mx
-from mxnet import gluon
-from mxnet.gluon import nn
-from mxnet import autograd
-import numpy as np
-import logging
-from datetime import datetime
-import os
-import time
-
-def fill_buf(buf, i, img, shape):
-    n = buf.shape[0]//shape[1]
-    m = buf.shape[1]//shape[0]
-
-    sx = (i%m)*shape[0]
-    sy = (i//m)*shape[1]
-    buf[sy:sy+shape[1], sx:sx+shape[0], :] = img
-    return None
-
-def visual(title, X, name):
-    assert len(X.shape) == 4
-    X = X.transpose((0, 2, 3, 1))
-    X = np.clip((X - np.min(X))*(255.0/(np.max(X) - np.min(X))), 0, 255).astype(np.uint8)
-    n = np.ceil(np.sqrt(X.shape[0]))
-    buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), dtype=np.uint8)
-    for i, img in enumerate(X):
-        fill_buf(buff, i, img, X.shape[1:3])
-    buff = buff[:,:,::-1]
-    plt.imshow(buff)
-    plt.title(title)
-    plt.savefig(name)
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to use. options are cifar10 and imagenet.')
-parser.add_argument('--batch-size', type=int, default=64, help='input batch size')
-parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
-parser.add_argument('--ngf', type=int, default=64)
-parser.add_argument('--ndf', type=int, default=64)
-parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
-parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
-parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
-parser.add_argument('--cuda', action='store_true', help='enables cuda')
-parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
-parser.add_argument('--netG', default='', help="path to netG (to continue training)")
-parser.add_argument('--netD', default='', help="path to netD (to continue training)")
-parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints')
-parser.add_argument('--check-point', default=True, help="save results at each epoch or not")
-
-opt = parser.parse_args()
-print(opt)
-
-logging.basicConfig(level=logging.DEBUG)
-ngpu = int(opt.ngpu)
-nz = int(opt.nz)
-ngf = int(opt.ngf)
-ndf = int(opt.ndf)
-nc = 3
-if opt.cuda:
-    ctx = mx.gpu(0)
-else:
-    ctx = mx.cpu()
-check_point = bool(opt.check_point)
-outf = opt.outf
-
-if not os.path.exists(outf):
-    os.makedirs(outf)
-
-
-def transformer(data, label):
-    # resize to 64x64
-    data = mx.image.imresize(data, 64, 64)
-    # transpose from (64, 64, 3) to (3, 64, 64)
-    data = mx.nd.transpose(data, (2,0,1))
-    # normalize to [-1, 1]
-    data = data.astype(np.float32)/128 - 1
-    # if image is greyscale, repeat 3 times to get RGB image.
-    if data.shape[0] == 1:
-        data = mx.nd.tile(data, (3, 1, 1))
-    return data, label
-
-train_data = gluon.data.DataLoader(
-    gluon.data.vision.MNIST('./data', train=True, transform=transformer),
-    batch_size=opt.batch_size, shuffle=True, last_batch='discard')
-
-val_data = gluon.data.DataLoader(
-    gluon.data.vision.MNIST('./data', train=False, transform=transformer),
-    batch_size=opt.batch_size, shuffle=False)
-
-
-# build the generator
-netG = nn.Sequential()
-with netG.name_scope():
-    # input is Z, going into a convolution
-    netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 4 x 4
-    netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 8 x 8
-    netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 16 x 16
-    netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 32 x 32
-    netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
-    netG.add(nn.Activation('tanh'))
-    # state size. (nc) x 64 x 64
-
-# build the discriminator
-netD = nn.Sequential()
-with netD.name_scope():
-    # input is (nc) x 64 x 64
-    netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 32 x 32
-    netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
-    netD.add(nn.BatchNorm())
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 16 x 16
-    netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
-    netD.add(nn.BatchNorm())
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 8 x 8
-    netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
-    netD.add(nn.BatchNorm())
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 4 x 4
-    netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))
-
-# loss
-loss = gluon.loss.SoftmaxCrossEntropyLoss()
-
-# initialize the generator and the discriminator
-netG.initialize(mx.init.Normal(0.02), ctx=ctx)
-netD.initialize(mx.init.Normal(0.02), ctx=ctx)
-
-# trainer for the generator and the discriminator
-trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
-trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': opt.lr, 'beta1': opt.beta1})
-
-# ============printing==============
-real_label = mx.nd.ones((opt.batch_size,), ctx=ctx)
-fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx)
-
-metric = mx.metric.Accuracy()
-print('Training... ')
-stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')
-
-iter = 0
-for epoch in range(opt.nepoch):
-    tic = time.time()
-    btic = time.time()
-    for data, _ in train_data:
-        ############################
-        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
-        ###########################
-        # train with real_t
-        data = data.as_in_context(ctx)
-        noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 1), ctx=ctx)
-
-        with autograd.record():
-            output = netD(data)
-            output = output.reshape((opt.batch_size, 2))
-            errD_real = loss(output, real_label)
-            metric.update([real_label,], [output,])
-
-            fake = netG(noise)
-            output = netD(fake.detach())
-            output = output.reshape((opt.batch_size, 2))
-            errD_fake = loss(output, fake_label)
-            errD = errD_real + errD_fake
-            errD.backward()
-            metric.update([fake_label,], [output,])
-
-        trainerD.step(opt.batch_size)
-
-        ############################
-        # (2) Update G network: maximize log(D(G(z)))
-        ###########################
-        with autograd.record():
-            output = netD(fake)
-            output = output.reshape((-1, 2))
-            errG = loss(output, real_label)
-            errG.backward()
-
-        trainerG.step(opt.batch_size)
-
-        name, acc = metric.get()
-        # logging.info('speed: {} samples/s'.format(opt.batch_size / (time.time() - btic)))
-        logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d' %(mx.nd.mean(errD).asscalar(), mx.nd.mean(errG).asscalar(), acc, iter, epoch))
-        if iter % 1 == 0:
-            visual('gout', fake.asnumpy(), name=os.path.join(outf,'fake_img_iter_%d.png' %iter))
-            visual('data', data.asnumpy(), name=os.path.join(outf,'real_img_iter_%d.png' %iter))
-
-        iter = iter + 1
-        btic = time.time()
-
-    name, acc = metric.get()
-    metric.reset()
-    logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
-    logging.info('time: %f' % (time.time() - tic))
-
-    if check_point:
-        netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' %epoch))
-        netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
-
-netG.save_parameters(os.path.join(outf, 'generator.params'))
-netD.save_parameters(os.path.join(outf, 'discriminator.params'))
diff --git a/example/gluon/sn_gan/data.py b/example/gluon/sn_gan/data.py
index 333125d..7ed4c38 100644
--- a/example/gluon/sn_gan/data.py
+++ b/example/gluon/sn_gan/data.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 import numpy as np
 
diff --git a/example/gluon/sn_gan/model.py b/example/gluon/sn_gan/model.py
index 38f87eb..b714c75 100644
--- a/example/gluon/sn_gan/model.py
+++ b/example/gluon/sn_gan/model.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 import mxnet as mx
 from mxnet import nd
diff --git a/example/gluon/sn_gan/train.py b/example/gluon/sn_gan/train.py
index 1cba1f5..f4b9884 100644
--- a/example/gluon/sn_gan/train.py
+++ b/example/gluon/sn_gan/train.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 
 import os
diff --git a/example/gluon/sn_gan/utils.py b/example/gluon/sn_gan/utils.py
index d3f1b86..06c0230 100644
--- a/example/gluon/sn_gan/utils.py
+++ b/example/gluon/sn_gan/utils.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 import math