You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/14 19:28:38 UTC

[incubator-mxnet] branch master updated: add CapsNet example (#8787)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8623bab  add CapsNet example (#8787)
8623bab is described below

commit 8623bab8e2a495c69c8b83b12a7ee0ca35464632
Author: Soonhwan-Kwon <So...@users.noreply.github.com>
AuthorDate: Fri Dec 15 04:28:35 2017 +0900

    add CapsNet example (#8787)
    
    * add capsnet example's layer
    
    * add capsnet example
    
    * add recon_loss_weight option and tensorboard for plot
    
    * update readme to install tensorboard
    
    * fix print of loss scaled to 1/batchsize
---
 example/capsnet/README.md        |  66 ++++++++
 example/capsnet/capsulelayers.py | 106 ++++++++++++
 example/capsnet/capsulenet.py    | 348 +++++++++++++++++++++++++++++++++++++++
 example/capsnet/result.PNG       | Bin 0 -> 31313 bytes
 4 files changed, 520 insertions(+)

diff --git a/example/capsnet/README.md b/example/capsnet/README.md
new file mode 100644
index 0000000..49a6dd1
--- /dev/null
+++ b/example/capsnet/README.md
@@ -0,0 +1,66 @@
+**CapsNet-MXNet**
+=========================================
+
+This example is an MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):  
+Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
+- The current `best test error is 0.29%` and `average test error is 0.303%`
+- The `average test error on paper is 0.25%`  
+
+Log files for the error rate are uploaded to the [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).  
+* * *
+## **Usage**
+Install scipy with pip  
+```
+pip install scipy
+```
+Install tensorboard with pip
+```
+pip install tensorboard
+```
+
+On a single GPU
+```
+python capsulenet.py --devices gpu0
+```
+On multiple GPUs
+```
+python capsulenet.py --devices gpu0,gpu1
+```
+Full arguments  
+```
+python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
+```  
+
+* * *
+## **Prerequisites**
+
+MXNet version above (0.11.0)  
+scipy version above (0.19.0)
+
+***
+## **Results**  
+Training takes about 36 seconds per epoch (batch_size=100, 2 GTX 1080 GPUs)  
+
+CapsNet classification test error on MNIST  
+
+```
+python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
+```
+
+![](result.PNG)
+
+| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
+| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
+| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
+| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
+
+We achieved `the best test error rate=0.29%` and `average test error=0.303%`. This is the best accuracy and the fastest training time among the implementations compared (Keras and TensorFlow, as of 2017-11-23).
+The result on paper is `0.25% (average test error rate)`.
+
+| Implementation| test err(%) | ※train time/epoch | GPU  Used|
+| :---: | :---: | :---: |:---: |
+| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
+| tensorflow | 0.49 | ※ 10 min | Unknown(4GB Memory) |
+| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
diff --git a/example/capsnet/capsulelayers.py b/example/capsnet/capsulelayers.py
new file mode 100644
index 0000000..5ac4fad
--- /dev/null
+++ b/example/capsnet/capsulelayers.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+
+def squash(data, squash_axis, name=''):
+    """Apply the capsule "squash" non-linearity to `data` along `squash_axis`.
+
+    Every vector s lying along `squash_axis` is multiplied by
+    ||s||^2 / (1 + ||s||^2) / sqrt(||s||^2 + epsilon), which keeps its
+    direction and maps its length into [0, 1).
+
+    Parameters
+    ----------
+    data : mx.sym.Symbol
+        Input symbol.
+    squash_axis : int
+        Axis holding the vector components; it is summed with keepdims=True.
+    name : str
+        Suffix appended to the names of the operators created here.
+
+    Returns
+    -------
+    mx.sym.Symbol
+        `data` broadcast-multiplied by the per-vector scale.
+    """
+    # epsilon keeps the sqrt divisor non-zero when the squared norm is 0
+    epsilon = 1e-08
+    s_squared_norm = mx.sym.sum(data=mx.sym.square(data, name='square_'+name),
+                                axis=squash_axis, keepdims=True, name='s_squared_norm_'+name)
+    scale = s_squared_norm / (1 + s_squared_norm) / mx.sym.sqrt(data=(s_squared_norm+epsilon),
+                                                                name='s_squared_norm_sqrt_'+name)
+    squashed_net = mx.sym.broadcast_mul(scale, data, name='squashed_net_'+name)
+    return squashed_net
+
+
+def primary_caps(data, dim_vector, n_channels, kernel, strides, name=''):
+    """Build the primary capsule layer: convolution, reshape, then squash.
+
+    Parameters
+    ----------
+    data : mx.sym.Symbol
+        Input feature map symbol.
+    dim_vector : int
+        Length of each capsule vector.
+    n_channels : int
+        Number of capsule channels; the convolution uses
+        dim_vector * n_channels filters.
+    kernel : tuple of int
+        Convolution kernel size.
+    strides : tuple or list of int
+        Convolution stride.
+    name : str
+        Name given to the convolution operator (not forwarded to squash).
+
+    Returns
+    -------
+    mx.sym.Symbol
+        Squashed capsules of shape (batch_size, num_capsules, dim_vector).
+    """
+    out = mx.sym.Convolution(data=data,
+                             num_filter=dim_vector * n_channels,
+                             kernel=kernel,
+                             stride=strides,
+                             name=name
+                             )
+    # 0 keeps the batch axis, -1 infers the number of capsules
+    out = mx.sym.Reshape(data=out, shape=(0, -1, dim_vector))
+    # squash each capsule vector (axis 2 holds the dim_vector components)
+    out = squash(out, squash_axis=2)
+    return out
+
+
+class CapsuleLayer:
+    """
+    The capsule layer with dynamic routing.
+    [batch_size, input_num_capsule, input_dim_vector] => [batch_size, num_capsule, dim_vector]
+
+    The instance is callable: calling it on an input symbol returns the output symbol.
+    The variables it creates are always named 'Weight' and 'Bias'.
+    """
+
+    def __init__(self, num_capsule, dim_vector, batch_size, kernel_initializer, bias_initializer, num_routing=3):
+        """Store the layer configuration.
+
+        num_capsule / dim_vector: number and length of the output capsules.
+        batch_size: batch size baked into the 'Bias' shape and the weight tiling.
+        kernel_initializer / bias_initializer: initializers attached to 'Weight' / 'Bias'.
+        num_routing: number of routing iterations.
+        """
+        self.num_capsule = num_capsule
+        self.dim_vector = dim_vector
+        self.batch_size = batch_size
+        self.num_routing = num_routing
+        self.kernel_initializer = kernel_initializer
+        self.bias_initializer = bias_initializer
+
+    def __call__(self, data):
+        """Build the routing graph on top of `data` and return the output capsules symbol."""
+        # The input capsule layout is inferred assuming the network input is
+        # named 'data' and has shape (batch_size, 1, 28, 28).
+        _, out_shapes, __ = data.infer_shape(data=(self.batch_size, 1, 28, 28))
+        _, input_num_capsule, input_dim_vector = out_shapes[0]
+
+        # build w and bias
+        # W : (1, input_num_capsule, num_capsule, input_dim_vector, dim_vector)
+        # bias : (batch_size, input_num_capsule, num_capsule ,1, 1)
+        w = mx.sym.Variable('Weight',
+                            shape=(1, input_num_capsule, self.num_capsule, input_dim_vector, self.dim_vector),
+                            init=self.kernel_initializer)
+        bias = mx.sym.Variable('Bias',
+                               shape=(self.batch_size, input_num_capsule, self.num_capsule, 1, 1),
+                               init=self.bias_initializer)
+        # no gradient flows into 'Bias'; it only seeds the routing logits below
+        bias = mx.sym.BlockGrad(bias)
+        bias_ = bias
+
+        # input : (batch_size, input_num_capsule, input_dim_vector)
+        # inputs_expand : (batch_size, input_num_capsule, 1, input_dim_vector, 1)
+        inputs_expand = mx.sym.Reshape(data=data, shape=(0, 0, -4, -1, 1))
+        inputs_expand = mx.sym.Reshape(data=inputs_expand, shape=(0, 0, -4, 1, -1, 0))
+        # input_tiled (batch_size, input_num_capsule, num_capsule, input_dim_vector, 1)
+        inputs_tiled = mx.sym.tile(data=inputs_expand, reps=(1, 1, self.num_capsule, 1, 1))
+        # w_tiled : (batch_size, input_num_capsule, num_capsule, input_dim_vector, dim_vector)
+        w_tiled = mx.sym.tile(w, reps=(self.batch_size, 1, 1, 1, 1))
+
+        # inputs_hat = w_tiled^T x inputs_tiled : (batch_size, input_num_capsule, num_capsule, dim_vector, 1)
+        inputs_hat = mx.sym.linalg_gemm2(w_tiled, inputs_tiled, transpose_a=True)
+
+        # after the swap : (batch_size, input_num_capsule, num_capsule, 1, dim_vector)
+        inputs_hat = mx.sym.swapaxes(data=inputs_hat, dim1=3, dim2=4)
+        # gradient-blocked copy used in every routing iteration except the last
+        inputs_hat_stopped = inputs_hat
+        inputs_hat_stopped = mx.sym.BlockGrad(inputs_hat_stopped)
+
+        for i in range(0, self.num_routing):
+            # coupling coefficients: softmax of the routing logits over the output-capsule axis
+            c = mx.sym.softmax(bias_, axis=2, name='c' + str(i))
+            if i == self.num_routing - 1:
+                # last iteration: use the differentiable inputs_hat so gradients reach 'Weight'
+                outputs = squash(
+                    mx.sym.sum(mx.sym.broadcast_mul(c, inputs_hat, name='broadcast_mul_' + str(i)),
+                               axis=1, keepdims=True,
+                               name='sum_' + str(i)), name='output_' + str(i), squash_axis=4)
+            else:
+                outputs = squash(
+                    mx.sym.sum(mx.sym.broadcast_mul(c, inputs_hat_stopped, name='broadcast_mul_' + str(i)),
+                               axis=1, keepdims=True,
+                               name='sum_' + str(i)), name='output_' + str(i), squash_axis=4)
+                # NOTE(review): the logit update multiplies inputs_hat by `c`, and the
+                # `outputs` computed just above is not used by it. The CapsNet paper
+                # updates the logits with the agreement between the output capsule and
+                # inputs_hat - confirm whether this deviation is intended.
+                bias_ = bias_ + mx.sym.sum(mx.sym.broadcast_mul(c, inputs_hat_stopped, name='bias_broadcast_mul' + str(i)),
+                                           axis=4,
+                                           keepdims=True, name='bias_' + str(i))
+
+        # (batch_size, 1, num_capsule, 1, dim_vector) -> (batch_size, num_capsule, dim_vector)
+        outputs = mx.sym.Reshape(data=outputs, shape=(-1, self.num_capsule, self.dim_vector))
+        return outputs
diff --git a/example/capsnet/capsulenet.py b/example/capsnet/capsulenet.py
new file mode 100644
index 0000000..6b44c3d
--- /dev/null
+++ b/example/capsnet/capsulenet.py
@@ -0,0 +1,348 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import mxnet as mx
+import numpy as np
+import os
+import re
+import urllib
+import gzip
+import struct
+import scipy.ndimage as ndi
+from capsulelayers import primary_caps, CapsuleLayer
+
+from tensorboard import SummaryWriter
+
+def margin_loss(y_true, y_pred):
+    """Return the margin loss symbol for capsule lengths `y_pred`.
+
+    Per class the loss is
+        y_true * max(0, 0.9 - y_pred)^2 + 0.5 * (1 - y_true) * max(0, y_pred - 0.1)^2
+    which is summed over axis 1 and then averaged over the remaining axis.
+
+    y_true : symbol of targets (the caller passes one-hot labels).
+    y_pred : symbol of predictions with the same shape as `y_true`.
+    """
+    loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\
+        0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1))
+    return mx.sym.mean(data=mx.sym.sum(loss, 1))
+
+
+def capsnet(batch_size, n_class, num_routing,recon_loss_weight):
+    """Build the CapsNet training symbol for 1x28x28 inputs.
+
+    Parameters
+    ----------
+    batch_size : int
+        Batch size seen by one executor; it is baked into several shapes.
+    n_class : int
+        Number of classes used for the one-hot labels.
+    num_routing : int
+        Number of routing iterations in the digit capsule layer.
+    recon_loss_weight : float
+        Weight w of the reconstruction error; the total loss is
+        (1 - w) * margin_loss + w * reconstruction_error.
+
+    Returns
+    -------
+    mx.sym.Group
+        [capsule lengths (gradient blocked), loss, reconstruction error (gradient blocked)].
+    """
+    # data.shape = [batch_size, 1, 28, 28]
+    data = mx.sym.Variable('data')
+
+    # per-sample input shape; its product is the size of the reconstruction output
+    input_shape = (1, 28, 28)
+    # Conv2D layer
+    # net.shape = [batch_size, 256, 20, 20]
+    conv1 = mx.sym.Convolution(data=data,
+                               num_filter=256,
+                               kernel=(9, 9),
+                               layout='NCHW',
+                               name='conv1')
+    conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act')
+    # primary-caps convolution output: [batch_size, 256, 6, 6] (8 * 32 filters),
+    # reshaped into 8-dimensional capsules inside primary_caps
+
+    primarycaps = primary_caps(data=conv1,
+                               dim_vector=8,
+                               n_channels=32,
+                               kernel=(9, 9),
+                               strides=[2, 2],
+                               name='primarycaps')
+    # result is discarded; presumably kept as an early shape check
+    primarycaps.infer_shape(data=(batch_size, 1, 28, 28))
+    # CapsuleLayer
+    kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3)
+    bias_initializer = mx.init.Zero()
+    digitcaps = CapsuleLayer(num_capsule=10,
+                             dim_vector=16,
+                             batch_size=batch_size,
+                             kernel_initializer=kernel_initializer,
+                             bias_initializer=bias_initializer,
+                             num_routing=num_routing)(primarycaps)
+
+    # out_caps : (batch_size, 10), the length of every digit capsule
+    out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2))
+    # result is discarded; presumably kept as an early shape check
+    out_caps.infer_shape(data=(batch_size, 1, 28, 28))
+
+    y = mx.sym.Variable('softmax_label', shape=(batch_size,))
+    y_onehot = mx.sym.one_hot(y, n_class)
+    # y_reshaped : (batch_size, n_class, 1)
+    y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1))
+    # result is discarded; presumably kept as an early shape check
+    y_reshaped.infer_shape(softmax_label=(batch_size,))
+
+    # select the capsule of the true class: one_hot^T x digitcaps
+    # inputs_masked : (batch_size, 16)
+    inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True)
+    inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0))
+    # reconstruction decoder: 512 -> 1024 -> prod(input_shape), sigmoid output
+    x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon')
+    x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act')
+    x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2')
+    x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2')
+    x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3')
+    x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3')
+
+    # mean squared error between the reconstruction and the flattened input
+    data_flatten = mx.sym.flatten(data=data)
+    squared_error = mx.sym.square(x_recon-data_flatten)
+    recon_error = mx.sym.mean(squared_error)
+    # gradient-blocked copy, exposed as an output for monitoring only
+    recon_error_stopped = recon_error
+    recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped)
+    loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error)
+
+    # gradient-blocked copy of the capsule lengths, exposed as the prediction output
+    out_caps_blocked = out_caps
+    out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked)
+    return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped])
+
+
+def download_data(url, force_download=False):
+    fname = url.split("/")[-1]
+    if force_download or not os.path.exists(fname):
+        urllib.urlretrieve(url, fname)
+    return fname
+
+
+def read_data(label_url, image_url):
+    with gzip.open(download_data(label_url)) as flbl:
+        magic, num = struct.unpack(">II", flbl.read(8))
+        label = np.fromstring(flbl.read(), dtype=np.int8)
+    with gzip.open(download_data(image_url), 'rb') as fimg:
+        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
+        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
+    return label, image
+
+
+def to4d(img):
+    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
+
+
+class LossMetric(mx.metric.EvalMetric):
+    def __init__(self, batch_size, num_gpu):
+        super(LossMetric, self).__init__('LossMetric')
+        self.batch_size = batch_size
+        self.num_gpu = num_gpu
+        self.sum_metric = 0
+        self.num_inst = 0
+        self.loss = 0.0
+        self.batch_sum_metric = 0
+        self.batch_num_inst = 0
+        self.batch_loss = 0.0
+        self.recon_loss = 0.0
+        self.n_batch = 0
+
+    def update(self, labels, preds):
+        batch_sum_metric = 0
+        batch_num_inst = 0
+        for label, pred_outcaps in zip(labels[0], preds[0]):
+            label_np = int(label.asnumpy())
+            pred_label = int(np.argmax(pred_outcaps.asnumpy()))
+            batch_sum_metric += int(label_np == pred_label)
+            batch_num_inst += 1
+        batch_loss = preds[1].asnumpy()
+        recon_loss = preds[2].asnumpy()
+        self.sum_metric += batch_sum_metric
+        self.num_inst += batch_num_inst
+        self.loss += batch_loss
+        self.recon_loss += recon_loss
+        self.batch_sum_metric = batch_sum_metric
+        self.batch_num_inst = batch_num_inst
+        self.batch_loss = batch_loss
+        self.n_batch += 1 
+
+    def get_name_value(self):
+        acc = float(self.sum_metric)/float(self.num_inst)
+        mean_loss = self.loss / float(self.n_batch)
+        mean_recon_loss = self.recon_loss / float(self.n_batch)
+        return acc, mean_loss, mean_recon_loss
+
+    def get_batch_log(self, n_batch):
+        print("n_batch :"+str(n_batch)+" batch_acc:" +
+              str(float(self.batch_sum_metric) / float(self.batch_num_inst)) +
+              ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst)))
+        self.batch_sum_metric = 0
+        self.batch_num_inst = 0
+        self.batch_loss = 0.0
+
+    def reset(self):
+        self.sum_metric = 0
+        self.num_inst = 0
+        self.loss = 0.0
+        self.recon_loss = 0.0
+        self.n_batch = 0
+
+
+class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
+    """A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr`
+    dynamically based on performance on the validation set.
+    """
+
+    def __init__(self, learning_rate=0.001):
+        super(SimpleLRScheduler, self).__init__()
+        self.learning_rate = learning_rate
+
+    def __call__(self, num_update):
+        return self.learning_rate
+
+
+def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay):
+    summary_writer = SummaryWriter(args.tblog_dir)
+    lr_scheduler = SimpleLRScheduler(learning_rate)
+    optimizer_params = {'lr_scheduler': lr_scheduler}
+    module.init_params()
+    module.init_optimizer(kvstore=kvstore,
+                          optimizer=optimizer,
+                          optimizer_params=optimizer_params)
+    n_epoch = 0
+    while True:
+        if n_epoch >= num_epoch:
+            break
+        train_iter.reset()
+        val_iter.reset()
+        loss_metric.reset()
+        for n_batch, data_batch in enumerate(train_iter):
+            module.forward_backward(data_batch)
+            module.update()
+            module.update_metric(loss_metric, data_batch.label)
+            loss_metric.get_batch_log(n_batch)
+        train_acc, train_loss, train_recon_err = loss_metric.get_name_value()
+        loss_metric.reset()
+        for n_batch, data_batch in enumerate(val_iter):
+            module.forward(data_batch)
+            module.update_metric(loss_metric, data_batch.label)
+            loss_metric.get_batch_log(n_batch)
+        val_acc, val_loss, val_recon_err = loss_metric.get_name_value()
+
+        summary_writer.add_scalar('train_acc', train_acc, n_epoch)
+        summary_writer.add_scalar('train_loss', train_loss, n_epoch)
+        summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch)
+        summary_writer.add_scalar('val_acc', val_acc, n_epoch)
+        summary_writer.add_scalar('val_loss', val_loss, n_epoch)
+        summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch)
+
+        print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err))
+        print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err))
+        print('SAVE CHECKPOINT')
+
+        module.save_checkpoint(prefix=model_prefix, epoch=n_epoch)
+        n_epoch += 1
+        lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch)
+
+
+def apply_transform(x,
+                    transform_matrix,
+                    fill_mode='nearest',
+                    cval=0.):
+    x = np.rollaxis(x, 0, 0)
+    final_affine_matrix = transform_matrix[:2, :2]
+    final_offset = transform_matrix[:2, 2]
+    channel_images = [ndi.interpolation.affine_transform(
+        x_channel,
+        final_affine_matrix,
+        final_offset,
+        order=0,
+        mode=fill_mode,
+        cval=cval) for x_channel in x]
+    x = np.stack(channel_images, axis=0)
+    x = np.rollaxis(x, 0, 0 + 1)
+    return x
+
+
+def random_shift(x, width_shift_fraction, height_shift_fraction):
+    tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2]
+    ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1]
+    shift_matrix = np.array([[1, 0, tx],
+                             [0, 1, ty],
+                             [0, 0, 1]])
+    x = apply_transform(x, shift_matrix, 'nearest')
+    return x
+
+def _shuffle(data, idx):
+    """Shuffle the data."""
+    shuffle_data = []
+
+    for k, v in data:
+        shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context)))
+
+    return shuffle_data
+
+class MNISTCustomIter(mx.io.NDArrayIter):
+    
+    def reset(self):
+        # shuffle data
+        if self.is_train:
+            np.random.shuffle(self.idx)
+            self.data = _shuffle(self.data, self.idx)
+            self.label = _shuffle(self.label, self.idx)
+        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
+            self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+        else:
+            self.cursor = -self.batch_size
+    def set_is_train(self, is_train):
+        self.is_train = is_train
+    def next(self):
+        if self.iter_next():
+            if self.is_train:
+                data_raw_list = self.getdata()
+                data_shifted = []
+                for data_raw in data_raw_list[0]:
+                    data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1))
+                return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(),
+                                       pad=self.getpad(), index=None)
+            else:
+                 return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \
+                                  pad=self.getpad(), index=None)
+
+        else:
+            raise StopIteration
+
+
+if __name__ == "__main__":
+    # Read mnist data set
+    path = 'http://yann.lecun.com/exdb/mnist/'
+    (train_lbl, train_img) = read_data(
+        path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz')
+    (val_lbl, val_img) = read_data(
+        path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz')
+    # set batch size
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', default=100, type=int)
+    parser.add_argument('--devices', default='gpu0', type=str)
+    parser.add_argument('--num_epoch', default=100, type=int)
+    parser.add_argument('--lr', default=0.001, type=float)
+    parser.add_argument('--num_routing', default=3, type=int)
+    parser.add_argument('--model_prefix', default='capsnet', type=str)
+    parser.add_argument('--decay', default=0.9, type=float)
+    parser.add_argument('--tblog_dir', default='tblog', type=str)
+    parser.add_argument('--recon_loss_weight', default=0.392, type=float)
+    args = parser.parse_args()
+    for k, v in sorted(vars(args).items()):
+        print("{0}: {1}".format(k, v))
+    contexts = re.split(r'\W+', args.devices)
+    for i, ctx in enumerate(contexts):
+        if ctx[:3] == 'gpu':
+            contexts[i] = mx.context.gpu(int(ctx[3:]))
+        else:
+            contexts[i] = mx.context.cpu()
+    num_gpu = len(contexts)
+
+    if args.batch_size % num_gpu != 0:
+        raise Exception('num_gpu should be positive divisor of batch_size')
+
+    # generate train_iter, val_iter
+    train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=args.batch_size, shuffle=True)
+    train_iter.set_is_train(True)
+    val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=args.batch_size,)
+    val_iter.set_is_train(False)
+    # define capsnet
+    final_net = capsnet(batch_size=args.batch_size/num_gpu, n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight)
+    # set metric
+    loss_metric = LossMetric(args.batch_size/num_gpu, 1)
+
+    # run model
+    module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',))
+    module.bind(data_shapes=train_iter.provide_data,
+                label_shapes=val_iter.provide_label,
+                for_training=True)
+    do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr,
+                model_prefix=args.model_prefix, decay=args.decay)
diff --git a/example/capsnet/result.PNG b/example/capsnet/result.PNG
new file mode 100644
index 0000000..62885dd
Binary files /dev/null and b/example/capsnet/result.PNG differ

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].