You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/14 19:28:38 UTC
[incubator-mxnet] branch master updated: add CapsNet example (#8787)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8623bab add CapsNet example (#8787)
8623bab is described below
commit 8623bab8e2a495c69c8b83b12a7ee0ca35464632
Author: Soonhwan-Kwon <So...@users.noreply.github.com>
AuthorDate: Fri Dec 15 04:28:35 2017 +0900
add CapsNet example (#8787)
* add capsnet example's layer
* add capsnet example
* add recon_loss_weight option and tensorboard for plot
* update readme to install tensorboard
* fix print of loss scaled to 1/batchsize
---
example/capsnet/README.md | 66 ++++++++
example/capsnet/capsulelayers.py | 106 ++++++++++++
example/capsnet/capsulenet.py | 348 +++++++++++++++++++++++++++++++++++++++
example/capsnet/result.PNG | Bin 0 -> 31313 bytes
4 files changed, 520 insertions(+)
diff --git a/example/capsnet/README.md b/example/capsnet/README.md
new file mode 100644
index 0000000..49a6dd1
--- /dev/null
+++ b/example/capsnet/README.md
@@ -0,0 +1,66 @@
+**CapsNet-MXNet**
+=========================================
+
+This example is an MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):
+Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
+- The current `best test error is 0.29%` and `average test error is 0.303%`
+- The `average test error on paper is 0.25%`
+
+Log files for the error rate are uploaded to the [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).
+* * *
+## **Usage**
+Install scipy with pip
+```
+pip install scipy
+```
+Install tensorboard with pip
+```
+pip install tensorboard
+```
+
+On a single GPU
+```
+python capsulenet.py --devices gpu0
+```
+On multiple GPUs
+```
+python capsulenet.py --devices gpu0,gpu1
+```
+Full arguments
+```
+python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
+```
+
+* * *
+## **Prerequisites**
+
+MXNet version above (0.11.0)
+scipy version above (0.19.0)
+
+***
+## **Results**
+Training takes about 36 seconds per epoch (batch_size=100, 2 GTX 1080 GPUs)
+
+CapsNet classification test error on MNIST
+
+```
+python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
+```
+
+![](result.PNG)
+
+| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
+| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
+| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
+| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
+
+We achieved `the best test error rate=0.29%` and `average test error=0.303%`. This is the best accuracy and the fastest training time among the implementations compared (Keras and TensorFlow, as of 2017-11-23).
+The result on paper is `0.25% (average test error rate)`.
+
+| Implementation| test err(%) | ※train time/epoch | GPU Used|
+| :---: | :---: | :---: |:---: |
+| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
+| tensorflow | 0.49 | ※ 10 min | Unknown(4GB Memory) |
+| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
diff --git a/example/capsnet/capsulelayers.py b/example/capsnet/capsulelayers.py
new file mode 100644
index 0000000..5ac4fad
--- /dev/null
+++ b/example/capsnet/capsulelayers.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+
+def squash(data, squash_axis, name=''):
+ epsilon = 1e-08
+ s_squared_norm = mx.sym.sum(data=mx.sym.square(data, name='square_'+name),
+ axis=squash_axis, keepdims=True, name='s_squared_norm_'+name)
+ scale = s_squared_norm / (1 + s_squared_norm) / mx.sym.sqrt(data=(s_squared_norm+epsilon),
+ name='s_squared_norm_sqrt_'+name)
+ squashed_net = mx.sym.broadcast_mul(scale, data, name='squashed_net_'+name)
+ return squashed_net
+
+
+def primary_caps(data, dim_vector, n_channels, kernel, strides, name=''):
+ out = mx.sym.Convolution(data=data,
+ num_filter=dim_vector * n_channels,
+ kernel=kernel,
+ stride=strides,
+ name=name
+ )
+ out = mx.sym.Reshape(data=out, shape=(0, -1, dim_vector))
+ out = squash(out, squash_axis=2)
+ return out
+
+
+class CapsuleLayer:
+ """
+ The capsule layer with dynamic routing.
+ [batch_size, input_num_capsule, input_dim_vector] => [batch_size, num_capsule, dim_vector]
+ """
+
+ def __init__(self, num_capsule, dim_vector, batch_size, kernel_initializer, bias_initializer, num_routing=3):
+ self.num_capsule = num_capsule
+ self.dim_vector = dim_vector
+ self.batch_size = batch_size
+ self.num_routing = num_routing
+ self.kernel_initializer = kernel_initializer
+ self.bias_initializer = bias_initializer
+
+ def __call__(self, data):
+ _, out_shapes, __ = data.infer_shape(data=(self.batch_size, 1, 28, 28))
+ _, input_num_capsule, input_dim_vector = out_shapes[0]
+
+ # build w and bias
+ # W : (input_num_capsule, num_capsule, input_dim_vector, dim_vector)
+ # bias : (batch_size, input_num_capsule, num_capsule ,1, 1)
+ w = mx.sym.Variable('Weight',
+ shape=(1, input_num_capsule, self.num_capsule, input_dim_vector, self.dim_vector),
+ init=self.kernel_initializer)
+ bias = mx.sym.Variable('Bias',
+ shape=(self.batch_size, input_num_capsule, self.num_capsule, 1, 1),
+ init=self.bias_initializer)
+ bias = mx.sym.BlockGrad(bias)
+ bias_ = bias
+
+ # input : (batch_size, input_num_capsule, input_dim_vector)
+ # inputs_expand : (batch_size, input_num_capsule, 1, input_dim_vector, 1)
+ inputs_expand = mx.sym.Reshape(data=data, shape=(0, 0, -4, -1, 1))
+ inputs_expand = mx.sym.Reshape(data=inputs_expand, shape=(0, 0, -4, 1, -1, 0))
+ # input_tiled (batch_size, input_num_capsule, num_capsule, input_dim_vector, 1)
+ inputs_tiled = mx.sym.tile(data=inputs_expand, reps=(1, 1, self.num_capsule, 1, 1))
+ # w_tiled : [(1L, input_num_capsule, num_capsule, input_dim_vector, dim_vector)]
+ w_tiled = mx.sym.tile(w, reps=(self.batch_size, 1, 1, 1, 1))
+
+ # inputs_hat : [(1L, input_num_capsule, num_capsule, 1, dim_vector)]
+ inputs_hat = mx.sym.linalg_gemm2(w_tiled, inputs_tiled, transpose_a=True)
+
+ inputs_hat = mx.sym.swapaxes(data=inputs_hat, dim1=3, dim2=4)
+ inputs_hat_stopped = inputs_hat
+ inputs_hat_stopped = mx.sym.BlockGrad(inputs_hat_stopped)
+
+ for i in range(0, self.num_routing):
+ c = mx.sym.softmax(bias_, axis=2, name='c' + str(i))
+ if i == self.num_routing - 1:
+ outputs = squash(
+ mx.sym.sum(mx.sym.broadcast_mul(c, inputs_hat, name='broadcast_mul_' + str(i)),
+ axis=1, keepdims=True,
+ name='sum_' + str(i)), name='output_' + str(i), squash_axis=4)
+ else:
+ outputs = squash(
+ mx.sym.sum(mx.sym.broadcast_mul(c, inputs_hat_stopped, name='broadcast_mul_' + str(i)),
+ axis=1, keepdims=True,
+ name='sum_' + str(i)), name='output_' + str(i), squash_axis=4)
+ bias_ = bias_ + mx.sym.sum(mx.sym.broadcast_mul(c, inputs_hat_stopped, name='bias_broadcast_mul' + str(i)),
+ axis=4,
+ keepdims=True, name='bias_' + str(i))
+
+ outputs = mx.sym.Reshape(data=outputs, shape=(-1, self.num_capsule, self.dim_vector))
+ return outputs
diff --git a/example/capsnet/capsulenet.py b/example/capsnet/capsulenet.py
new file mode 100644
index 0000000..6b44c3d
--- /dev/null
+++ b/example/capsnet/capsulenet.py
@@ -0,0 +1,348 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import mxnet as mx
+import numpy as np
+import os
+import re
+import urllib
+import gzip
+import struct
+import scipy.ndimage as ndi
+from capsulelayers import primary_caps, CapsuleLayer
+
+from tensorboard import SummaryWriter
+
+def margin_loss(y_true, y_pred):
+ loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\
+ 0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1))
+ return mx.sym.mean(data=mx.sym.sum(loss, 1))
+
+
+def capsnet(batch_size, n_class, num_routing,recon_loss_weight):
+ # data.shape = [batch_size, 1, 28, 28]
+ data = mx.sym.Variable('data')
+
+ input_shape = (1, 28, 28)
+ # Conv2D layer
+ # net.shape = [batch_size, 256, 20, 20]
+ conv1 = mx.sym.Convolution(data=data,
+ num_filter=256,
+ kernel=(9, 9),
+ layout='NCHW',
+ name='conv1')
+ conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act')
+ # net.shape = [batch_size, 256, 6, 6]
+
+ primarycaps = primary_caps(data=conv1,
+ dim_vector=8,
+ n_channels=32,
+ kernel=(9, 9),
+ strides=[2, 2],
+ name='primarycaps')
+ primarycaps.infer_shape(data=(batch_size, 1, 28, 28))
+ # CapsuleLayer
+ kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3)
+ bias_initializer = mx.init.Zero()
+ digitcaps = CapsuleLayer(num_capsule=10,
+ dim_vector=16,
+ batch_size=batch_size,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ num_routing=num_routing)(primarycaps)
+
+ # out_caps : (batch_size, 10)
+ out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2))
+ out_caps.infer_shape(data=(batch_size, 1, 28, 28))
+
+ y = mx.sym.Variable('softmax_label', shape=(batch_size,))
+ y_onehot = mx.sym.one_hot(y, n_class)
+ y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1))
+ y_reshaped.infer_shape(softmax_label=(batch_size,))
+
+ # inputs_masked : (batch_size, 16)
+ inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True)
+ inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0))
+ x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon')
+ x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act')
+ x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2')
+ x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2')
+ x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3')
+ x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3')
+
+ data_flatten = mx.sym.flatten(data=data)
+ squared_error = mx.sym.square(x_recon-data_flatten)
+ recon_error = mx.sym.mean(squared_error)
+ recon_error_stopped = recon_error
+ recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped)
+ loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error)
+
+ out_caps_blocked = out_caps
+ out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked)
+ return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped])
+
+
+def download_data(url, force_download=False):
+ fname = url.split("/")[-1]
+ if force_download or not os.path.exists(fname):
+ urllib.urlretrieve(url, fname)
+ return fname
+
+
+def read_data(label_url, image_url):
+ with gzip.open(download_data(label_url)) as flbl:
+ magic, num = struct.unpack(">II", flbl.read(8))
+ label = np.fromstring(flbl.read(), dtype=np.int8)
+ with gzip.open(download_data(image_url), 'rb') as fimg:
+ magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
+ image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
+ return label, image
+
+
+def to4d(img):
+ return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
+
+
+class LossMetric(mx.metric.EvalMetric):
+ def __init__(self, batch_size, num_gpu):
+ super(LossMetric, self).__init__('LossMetric')
+ self.batch_size = batch_size
+ self.num_gpu = num_gpu
+ self.sum_metric = 0
+ self.num_inst = 0
+ self.loss = 0.0
+ self.batch_sum_metric = 0
+ self.batch_num_inst = 0
+ self.batch_loss = 0.0
+ self.recon_loss = 0.0
+ self.n_batch = 0
+
+ def update(self, labels, preds):
+ batch_sum_metric = 0
+ batch_num_inst = 0
+ for label, pred_outcaps in zip(labels[0], preds[0]):
+ label_np = int(label.asnumpy())
+ pred_label = int(np.argmax(pred_outcaps.asnumpy()))
+ batch_sum_metric += int(label_np == pred_label)
+ batch_num_inst += 1
+ batch_loss = preds[1].asnumpy()
+ recon_loss = preds[2].asnumpy()
+ self.sum_metric += batch_sum_metric
+ self.num_inst += batch_num_inst
+ self.loss += batch_loss
+ self.recon_loss += recon_loss
+ self.batch_sum_metric = batch_sum_metric
+ self.batch_num_inst = batch_num_inst
+ self.batch_loss = batch_loss
+ self.n_batch += 1
+
+ def get_name_value(self):
+ acc = float(self.sum_metric)/float(self.num_inst)
+ mean_loss = self.loss / float(self.n_batch)
+ mean_recon_loss = self.recon_loss / float(self.n_batch)
+ return acc, mean_loss, mean_recon_loss
+
+ def get_batch_log(self, n_batch):
+ print("n_batch :"+str(n_batch)+" batch_acc:" +
+ str(float(self.batch_sum_metric) / float(self.batch_num_inst)) +
+ ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst)))
+ self.batch_sum_metric = 0
+ self.batch_num_inst = 0
+ self.batch_loss = 0.0
+
+ def reset(self):
+ self.sum_metric = 0
+ self.num_inst = 0
+ self.loss = 0.0
+ self.recon_loss = 0.0
+ self.n_batch = 0
+
+
+class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
+ """A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr`
+ dynamically based on performance on the validation set.
+ """
+
+ def __init__(self, learning_rate=0.001):
+ super(SimpleLRScheduler, self).__init__()
+ self.learning_rate = learning_rate
+
+ def __call__(self, num_update):
+ return self.learning_rate
+
+
+def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay):
+ summary_writer = SummaryWriter(args.tblog_dir)
+ lr_scheduler = SimpleLRScheduler(learning_rate)
+ optimizer_params = {'lr_scheduler': lr_scheduler}
+ module.init_params()
+ module.init_optimizer(kvstore=kvstore,
+ optimizer=optimizer,
+ optimizer_params=optimizer_params)
+ n_epoch = 0
+ while True:
+ if n_epoch >= num_epoch:
+ break
+ train_iter.reset()
+ val_iter.reset()
+ loss_metric.reset()
+ for n_batch, data_batch in enumerate(train_iter):
+ module.forward_backward(data_batch)
+ module.update()
+ module.update_metric(loss_metric, data_batch.label)
+ loss_metric.get_batch_log(n_batch)
+ train_acc, train_loss, train_recon_err = loss_metric.get_name_value()
+ loss_metric.reset()
+ for n_batch, data_batch in enumerate(val_iter):
+ module.forward(data_batch)
+ module.update_metric(loss_metric, data_batch.label)
+ loss_metric.get_batch_log(n_batch)
+ val_acc, val_loss, val_recon_err = loss_metric.get_name_value()
+
+ summary_writer.add_scalar('train_acc', train_acc, n_epoch)
+ summary_writer.add_scalar('train_loss', train_loss, n_epoch)
+ summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch)
+ summary_writer.add_scalar('val_acc', val_acc, n_epoch)
+ summary_writer.add_scalar('val_loss', val_loss, n_epoch)
+ summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch)
+
+ print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err))
+ print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err))
+ print('SAVE CHECKPOINT')
+
+ module.save_checkpoint(prefix=model_prefix, epoch=n_epoch)
+ n_epoch += 1
+ lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch)
+
+
+def apply_transform(x,
+ transform_matrix,
+ fill_mode='nearest',
+ cval=0.):
+ x = np.rollaxis(x, 0, 0)
+ final_affine_matrix = transform_matrix[:2, :2]
+ final_offset = transform_matrix[:2, 2]
+ channel_images = [ndi.interpolation.affine_transform(
+ x_channel,
+ final_affine_matrix,
+ final_offset,
+ order=0,
+ mode=fill_mode,
+ cval=cval) for x_channel in x]
+ x = np.stack(channel_images, axis=0)
+ x = np.rollaxis(x, 0, 0 + 1)
+ return x
+
+
+def random_shift(x, width_shift_fraction, height_shift_fraction):
+ tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2]
+ ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1]
+ shift_matrix = np.array([[1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1]])
+ x = apply_transform(x, shift_matrix, 'nearest')
+ return x
+
+def _shuffle(data, idx):
+ """Shuffle the data."""
+ shuffle_data = []
+
+ for k, v in data:
+ shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context)))
+
+ return shuffle_data
+
+class MNISTCustomIter(mx.io.NDArrayIter):
+
+ def reset(self):
+ # shuffle data
+ if self.is_train:
+ np.random.shuffle(self.idx)
+ self.data = _shuffle(self.data, self.idx)
+ self.label = _shuffle(self.label, self.idx)
+ if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
+ self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+ else:
+ self.cursor = -self.batch_size
+ def set_is_train(self, is_train):
+ self.is_train = is_train
+ def next(self):
+ if self.iter_next():
+ if self.is_train:
+ data_raw_list = self.getdata()
+ data_shifted = []
+ for data_raw in data_raw_list[0]:
+ data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1))
+ return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(),
+ pad=self.getpad(), index=None)
+ else:
+ return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \
+ pad=self.getpad(), index=None)
+
+ else:
+ raise StopIteration
+
+
+if __name__ == "__main__":
+ # Read mnist data set
+ path = 'http://yann.lecun.com/exdb/mnist/'
+ (train_lbl, train_img) = read_data(
+ path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz')
+ (val_lbl, val_img) = read_data(
+ path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz')
+ # set batch size
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--batch_size', default=100, type=int)
+ parser.add_argument('--devices', default='gpu0', type=str)
+ parser.add_argument('--num_epoch', default=100, type=int)
+ parser.add_argument('--lr', default=0.001, type=float)
+ parser.add_argument('--num_routing', default=3, type=int)
+ parser.add_argument('--model_prefix', default='capsnet', type=str)
+ parser.add_argument('--decay', default=0.9, type=float)
+ parser.add_argument('--tblog_dir', default='tblog', type=str)
+ parser.add_argument('--recon_loss_weight', default=0.392, type=float)
+ args = parser.parse_args()
+ for k, v in sorted(vars(args).items()):
+ print("{0}: {1}".format(k, v))
+ contexts = re.split(r'\W+', args.devices)
+ for i, ctx in enumerate(contexts):
+ if ctx[:3] == 'gpu':
+ contexts[i] = mx.context.gpu(int(ctx[3:]))
+ else:
+ contexts[i] = mx.context.cpu()
+ num_gpu = len(contexts)
+
+ if args.batch_size % num_gpu != 0:
+ raise Exception('num_gpu should be positive divisor of batch_size')
+
+ # generate train_iter, val_iter
+ train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=args.batch_size, shuffle=True)
+ train_iter.set_is_train(True)
+ val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=args.batch_size,)
+ val_iter.set_is_train(False)
+ # define capsnet
+ final_net = capsnet(batch_size=args.batch_size/num_gpu, n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight)
+ # set metric
+ loss_metric = LossMetric(args.batch_size/num_gpu, 1)
+
+ # run model
+ module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',))
+ module.bind(data_shapes=train_iter.provide_data,
+ label_shapes=val_iter.provide_label,
+ for_training=True)
+ do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr,
+ model_prefix=args.model_prefix, decay=args.decay)
diff --git a/example/capsnet/result.PNG b/example/capsnet/result.PNG
new file mode 100644
index 0000000..62885dd
Binary files /dev/null and b/example/capsnet/result.PNG differ
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].