This is an automated email from the ASF dual-hosted git repository.
indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 68bc9b7 Updated capsnet example (#12934)
68bc9b7 is described below
commit 68bc9b7f444e76e42c02adfa97ec12149ba0d996
Author: Thomas Delteil <[email protected]>
AuthorDate: Thu Nov 8 08:38:42 2018 -0800
Updated capsnet example (#12934)
* Updated capsnet
* trigger CI
* Update README.md
---
example/capsnet/README.md | 132 ++++----
example/capsnet/capsulenet.py | 695 +++++++++++++++++++++---------------------
2 files changed, 413 insertions(+), 414 deletions(-)
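Two of the changes in the diff below appear to address Python 3 compatibility: `urllib.urlretrieve` (which does not exist under that name in Python 3) is replaced by `mx.test_utils.download`, and `args.batch_size/num_gpu` is wrapped in `int(...)` because `/` is true division in Python 3. The following is a minimal sketch of the second point only; `per_device_batch` is a hypothetical helper used for illustration and is not part of the commit.

```python
def per_device_batch(batch_size, num_devices):
    """Split a global batch evenly across devices and return an int."""
    if num_devices <= 0 or batch_size % num_devices != 0:
        raise ValueError('num_devices should be a positive divisor of batch_size')
    # Floor division keeps the result an int on Python 2 and Python 3,
    # whereas `/` returns a float on Python 3.
    return batch_size // num_devices


print(per_device_batch(100, 2))  # 50
```

The commit itself uses `int(args.batch_size/num_gpu)`, which gives the same value here because the script has already checked that the batch size is divisible by the number of devices.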
diff --git a/example/capsnet/README.md b/example/capsnet/README.md
index 49a6dd1..500c7df 100644
--- a/example/capsnet/README.md
+++ b/example/capsnet/README.md
@@ -1,66 +1,66 @@
-**CapsNet-MXNet**
-=========================================
-
-This example is MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):
-Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
-- The current `best test error is 0.29%` and `average test error is 0.303%`
-- The `average test error on paper is 0.25%`
-
-Log files for the error rate are uploaded in [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).
-* * *
-## **Usage**
-Install scipy with pip
-```
-pip install scipy
-```
-Install tensorboard with pip
-```
-pip install tensorboard
-```
-
-On Single gpu
-```
-python capsulenet.py --devices gpu0
-```
-On Multi gpus
-```
-python capsulenet.py --devices gpu0,gpu1
-```
-Full arguments
-```
-python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
-```
-
-* * *
-## **Prerequisities**
-
-MXNet version above (0.11.0)
-scipy version above (0.19.0)
-
-***
-## **Results**
-Train time takes about 36 seconds for each epoch (batch_size=100, 2 gtx 1080 gpus)
-
-CapsNet classification test error on MNIST
-
-```
-python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
-```
-
-
-
-| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
-| :---: | :---: | :---: | :---: | :---: | :---: |
-| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
-| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
-| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
-| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
-
-We achieved `the best test error rate=0.29%` and `average test error=0.303%`. It is the best accuracy and fastest training time result among other implementations(Keras, Tensorflow at 2017-11-23).
-The result on paper is `0.25% (average test error rate)`.
-
-| Implementation| test err(%) | ※train time/epoch | GPU Used|
-| :---: | :---: | :---: |:---: |
-| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
-| tensorflow | 0.49 | ※ 10 min | Unknown(4GB Memory) |
-| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
+**CapsNet-MXNet**
+=========================================
+
+This example is MXNet implementation of [CapsNet](https://arxiv.org/abs/1710.09829):
+Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
+- The current `best test error is 0.29%` and `average test error is 0.303%`
+- The `average test error on paper is 0.25%`
+
+Log files for the error rate are uploaded in [repository](https://github.com/samsungsds-rnd/capsnet.mxnet).
+* * *
+## **Usage**
+Install scipy with pip
+```
+pip install scipy
+```
+Install tensorboard and mxboard with pip
+```
+pip install mxboard tensorflow
+```
+
+On Single gpu
+```
+python capsulenet.py --devices gpu0
+```
+On Multi gpus
+```
+python capsulenet.py --devices gpu0,gpu1
+```
+Full arguments
+```
+python capsulenet.py --batch_size 100 --devices gpu0,gpu1 --num_epoch 100 --lr 0.001 --num_routing 3 --model_prefix capsnet
+```
+
+* * *
+## **Prerequisites**
+
+MXNet version above (1.2.0)
+scipy version above (0.19.0)
+
+***
+## **Results**
+Train time takes about 36 seconds for each epoch (batch_size=100, 2 gtx 1080 gpus)
+
+CapsNet classification test error on MNIST:
+
+```
+python capsulenet.py --devices gpu0,gpu1 --lr 0.0005 --decay 0.99 --model_prefix lr_0_0005_decay_0_99 --batch_size 100 --num_routing 3 --num_epoch 200
+```
+
+
+
+| Trial | Epoch | train err(%) | test err(%) | train loss | test loss |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| 1 | 120 | 0.06 | 0.31 | 0.0056 | 0.0064 |
+| 2 | 167 | 0.03 | 0.29 | 0.0048 | 0.0058 |
+| 3 | 182 | 0.04 | 0.31 | 0.0046 | 0.0058 |
+| average | - | 0.043 | 0.303 | 0.005 | 0.006 |
+
+We achieved `the best test error rate=0.29%` and `average test error=0.303%`. It is the best accuracy and fastest training time result among other implementations(Keras, Tensorflow at 2017-11-23).
+The result on paper is `0.25% (average test error rate)`.
+
+| Implementation| test err(%) | ※train time/epoch | GPU Used|
+| :---: | :---: | :---: |:---: |
+| MXNet | 0.29 | 36 sec | 2 GTX 1080 |
+| tensorflow | 0.49 | ※ 10 min | Unknown(4GB Memory) |
+| Keras | 0.30 | 55 sec | 2 GTX 1080 Ti |
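The README usage lines pass devices as a string such as `gpu0,gpu1`. As a rough sketch of how `capsulenet.py` interprets that string (it splits on non-word characters and reads the digits after `gpu`), the snippet below mirrors the logic without importing MXNet. `parse_devices` and its `(kind, index)` tuples are illustrative assumptions; the script itself builds `mx.context.gpu(...)` and `mx.context.cpu()` objects.

```python
import re


def parse_devices(devices):
    """Turn a string such as 'gpu0,gpu1' into (kind, index) pairs."""
    parsed = []
    for name in re.split(r'\W+', devices):
        if name[:3] == 'gpu':
            # 'gpu1' -> ('gpu', 1); the digits after 'gpu' are the device id
            parsed.append(('gpu', int(name[3:])))
        else:
            # anything else falls back to the CPU, as in the script
            parsed.append(('cpu', 0))
    return parsed


print(parse_devices('gpu0,gpu1'))  # [('gpu', 0), ('gpu', 1)]
```

The number of parsed entries is what the script calls `num_gpu`, and it must divide `--batch_size` evenly.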
diff --git a/example/capsnet/capsulenet.py b/example/capsnet/capsulenet.py
index 6b44c3d..6710875 100644
--- a/example/capsnet/capsulenet.py
+++ b/example/capsnet/capsulenet.py
@@ -1,348 +1,347 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import mxnet as mx
-import numpy as np
-import os
-import re
-import urllib
-import gzip
-import struct
-import scipy.ndimage as ndi
-from capsulelayers import primary_caps, CapsuleLayer
-
-from tensorboard import SummaryWriter
-
-def margin_loss(y_true, y_pred):
- loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\
- 0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1))
- return mx.sym.mean(data=mx.sym.sum(loss, 1))
-
-
-def capsnet(batch_size, n_class, num_routing,recon_loss_weight):
- # data.shape = [batch_size, 1, 28, 28]
- data = mx.sym.Variable('data')
-
- input_shape = (1, 28, 28)
- # Conv2D layer
- # net.shape = [batch_size, 256, 20, 20]
- conv1 = mx.sym.Convolution(data=data,
- num_filter=256,
- kernel=(9, 9),
- layout='NCHW',
- name='conv1')
- conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act')
- # net.shape = [batch_size, 256, 6, 6]
-
- primarycaps = primary_caps(data=conv1,
- dim_vector=8,
- n_channels=32,
- kernel=(9, 9),
- strides=[2, 2],
- name='primarycaps')
- primarycaps.infer_shape(data=(batch_size, 1, 28, 28))
- # CapsuleLayer
- kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3)
- bias_initializer = mx.init.Zero()
- digitcaps = CapsuleLayer(num_capsule=10,
- dim_vector=16,
- batch_size=batch_size,
- kernel_initializer=kernel_initializer,
- bias_initializer=bias_initializer,
- num_routing=num_routing)(primarycaps)
-
- # out_caps : (batch_size, 10)
- out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2))
- out_caps.infer_shape(data=(batch_size, 1, 28, 28))
-
- y = mx.sym.Variable('softmax_label', shape=(batch_size,))
- y_onehot = mx.sym.one_hot(y, n_class)
- y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1))
- y_reshaped.infer_shape(softmax_label=(batch_size,))
-
- # inputs_masked : (batch_size, 16)
- inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True)
- inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0))
- x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon')
- x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act')
- x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2')
- x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2')
- x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3')
- x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3')
-
- data_flatten = mx.sym.flatten(data=data)
- squared_error = mx.sym.square(x_recon-data_flatten)
- recon_error = mx.sym.mean(squared_error)
- recon_error_stopped = recon_error
- recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped)
- loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error)
-
- out_caps_blocked = out_caps
- out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked)
- return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped])
-
-
-def download_data(url, force_download=False):
- fname = url.split("/")[-1]
- if force_download or not os.path.exists(fname):
- urllib.urlretrieve(url, fname)
- return fname
-
-
-def read_data(label_url, image_url):
- with gzip.open(download_data(label_url)) as flbl:
- magic, num = struct.unpack(">II", flbl.read(8))
- label = np.fromstring(flbl.read(), dtype=np.int8)
- with gzip.open(download_data(image_url), 'rb') as fimg:
- magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
- image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
- return label, image
-
-
-def to4d(img):
- return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
-
-
-class LossMetric(mx.metric.EvalMetric):
- def __init__(self, batch_size, num_gpu):
- super(LossMetric, self).__init__('LossMetric')
- self.batch_size = batch_size
- self.num_gpu = num_gpu
- self.sum_metric = 0
- self.num_inst = 0
- self.loss = 0.0
- self.batch_sum_metric = 0
- self.batch_num_inst = 0
- self.batch_loss = 0.0
- self.recon_loss = 0.0
- self.n_batch = 0
-
- def update(self, labels, preds):
- batch_sum_metric = 0
- batch_num_inst = 0
- for label, pred_outcaps in zip(labels[0], preds[0]):
- label_np = int(label.asnumpy())
- pred_label = int(np.argmax(pred_outcaps.asnumpy()))
- batch_sum_metric += int(label_np == pred_label)
- batch_num_inst += 1
- batch_loss = preds[1].asnumpy()
- recon_loss = preds[2].asnumpy()
- self.sum_metric += batch_sum_metric
- self.num_inst += batch_num_inst
- self.loss += batch_loss
- self.recon_loss += recon_loss
- self.batch_sum_metric = batch_sum_metric
- self.batch_num_inst = batch_num_inst
- self.batch_loss = batch_loss
- self.n_batch += 1
-
- def get_name_value(self):
- acc = float(self.sum_metric)/float(self.num_inst)
- mean_loss = self.loss / float(self.n_batch)
- mean_recon_loss = self.recon_loss / float(self.n_batch)
- return acc, mean_loss, mean_recon_loss
-
- def get_batch_log(self, n_batch):
- print("n_batch :"+str(n_batch)+" batch_acc:" +
- str(float(self.batch_sum_metric) / float(self.batch_num_inst)) +
- ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst)))
- self.batch_sum_metric = 0
- self.batch_num_inst = 0
- self.batch_loss = 0.0
-
- def reset(self):
- self.sum_metric = 0
- self.num_inst = 0
- self.loss = 0.0
- self.recon_loss = 0.0
- self.n_batch = 0
-
-
-class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
- """A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr`
- dynamically based on performance on the validation set.
- """
-
- def __init__(self, learning_rate=0.001):
- super(SimpleLRScheduler, self).__init__()
- self.learning_rate = learning_rate
-
- def __call__(self, num_update):
- return self.learning_rate
-
-
-def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay):
- summary_writer = SummaryWriter(args.tblog_dir)
- lr_scheduler = SimpleLRScheduler(learning_rate)
- optimizer_params = {'lr_scheduler': lr_scheduler}
- module.init_params()
- module.init_optimizer(kvstore=kvstore,
- optimizer=optimizer,
- optimizer_params=optimizer_params)
- n_epoch = 0
- while True:
- if n_epoch >= num_epoch:
- break
- train_iter.reset()
- val_iter.reset()
- loss_metric.reset()
- for n_batch, data_batch in enumerate(train_iter):
- module.forward_backward(data_batch)
- module.update()
- module.update_metric(loss_metric, data_batch.label)
- loss_metric.get_batch_log(n_batch)
- train_acc, train_loss, train_recon_err = loss_metric.get_name_value()
- loss_metric.reset()
- for n_batch, data_batch in enumerate(val_iter):
- module.forward(data_batch)
- module.update_metric(loss_metric, data_batch.label)
- loss_metric.get_batch_log(n_batch)
- val_acc, val_loss, val_recon_err = loss_metric.get_name_value()
-
- summary_writer.add_scalar('train_acc', train_acc, n_epoch)
- summary_writer.add_scalar('train_loss', train_loss, n_epoch)
- summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch)
- summary_writer.add_scalar('val_acc', val_acc, n_epoch)
- summary_writer.add_scalar('val_loss', val_loss, n_epoch)
- summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch)
-
- print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err))
- print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err))
- print('SAVE CHECKPOINT')
-
- module.save_checkpoint(prefix=model_prefix, epoch=n_epoch)
- n_epoch += 1
- lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch)
-
-
-def apply_transform(x,
- transform_matrix,
- fill_mode='nearest',
- cval=0.):
- x = np.rollaxis(x, 0, 0)
- final_affine_matrix = transform_matrix[:2, :2]
- final_offset = transform_matrix[:2, 2]
- channel_images = [ndi.interpolation.affine_transform(
- x_channel,
- final_affine_matrix,
- final_offset,
- order=0,
- mode=fill_mode,
- cval=cval) for x_channel in x]
- x = np.stack(channel_images, axis=0)
- x = np.rollaxis(x, 0, 0 + 1)
- return x
-
-
-def random_shift(x, width_shift_fraction, height_shift_fraction):
- tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2]
- ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1]
- shift_matrix = np.array([[1, 0, tx],
- [0, 1, ty],
- [0, 0, 1]])
- x = apply_transform(x, shift_matrix, 'nearest')
- return x
-
-def _shuffle(data, idx):
- """Shuffle the data."""
- shuffle_data = []
-
- for k, v in data:
- shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context)))
-
- return shuffle_data
-
-class MNISTCustomIter(mx.io.NDArrayIter):
-
- def reset(self):
- # shuffle data
- if self.is_train:
- np.random.shuffle(self.idx)
- self.data = _shuffle(self.data, self.idx)
- self.label = _shuffle(self.label, self.idx)
- if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
- self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
- else:
- self.cursor = -self.batch_size
- def set_is_train(self, is_train):
- self.is_train = is_train
- def next(self):
- if self.iter_next():
- if self.is_train:
- data_raw_list = self.getdata()
- data_shifted = []
- for data_raw in data_raw_list[0]:
- data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1))
- return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(),
- pad=self.getpad(), index=None)
- else:
- return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \
- pad=self.getpad(), index=None)
-
- else:
- raise StopIteration
-
-
-if __name__ == "__main__":
- # Read mnist data set
- path = 'http://yann.lecun.com/exdb/mnist/'
- (train_lbl, train_img) = read_data(
- path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz')
- (val_lbl, val_img) = read_data(
- path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz')
- # set batch size
- import argparse
- parser = argparse.ArgumentParser()
- parser.add_argument('--batch_size', default=100, type=int)
- parser.add_argument('--devices', default='gpu0', type=str)
- parser.add_argument('--num_epoch', default=100, type=int)
- parser.add_argument('--lr', default=0.001, type=float)
- parser.add_argument('--num_routing', default=3, type=int)
- parser.add_argument('--model_prefix', default='capsnet', type=str)
- parser.add_argument('--decay', default=0.9, type=float)
- parser.add_argument('--tblog_dir', default='tblog', type=str)
- parser.add_argument('--recon_loss_weight', default=0.392, type=float)
- args = parser.parse_args()
- for k, v in sorted(vars(args).items()):
- print("{0}: {1}".format(k, v))
- contexts = re.split(r'\W+', args.devices)
- for i, ctx in enumerate(contexts):
- if ctx[:3] == 'gpu':
- contexts[i] = mx.context.gpu(int(ctx[3:]))
- else:
- contexts[i] = mx.context.cpu()
- num_gpu = len(contexts)
-
- if args.batch_size % num_gpu != 0:
- raise Exception('num_gpu should be positive divisor of batch_size')
-
- # generate train_iter, val_iter
- train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=args.batch_size, shuffle=True)
- train_iter.set_is_train(True)
- val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=args.batch_size,)
- val_iter.set_is_train(False)
- # define capsnet
- final_net = capsnet(batch_size=args.batch_size/num_gpu, n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight)
- # set metric
- loss_metric = LossMetric(args.batch_size/num_gpu, 1)
-
- # run model
- module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',))
- module.bind(data_shapes=train_iter.provide_data,
- label_shapes=val_iter.provide_label,
- for_training=True)
- do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr,
- model_prefix=args.model_prefix, decay=args.decay)
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import mxnet as mx
+import numpy as np
+import os
+import re
+import gzip
+import struct
+import scipy.ndimage as ndi
+from capsulelayers import primary_caps, CapsuleLayer
+
+from mxboard import SummaryWriter
+
+def margin_loss(y_true, y_pred):
+ loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\
+ 0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1))
+ return mx.sym.mean(data=mx.sym.sum(loss, 1))
+
+
+def capsnet(batch_size, n_class, num_routing,recon_loss_weight):
+ # data.shape = [batch_size, 1, 28, 28]
+ data = mx.sym.Variable('data')
+
+ input_shape = (1, 28, 28)
+ # Conv2D layer
+ # net.shape = [batch_size, 256, 20, 20]
+ conv1 = mx.sym.Convolution(data=data,
+ num_filter=256,
+ kernel=(9, 9),
+ layout='NCHW',
+ name='conv1')
+ conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act')
+ # net.shape = [batch_size, 256, 6, 6]
+
+ primarycaps = primary_caps(data=conv1,
+ dim_vector=8,
+ n_channels=32,
+ kernel=(9, 9),
+ strides=[2, 2],
+ name='primarycaps')
+ primarycaps.infer_shape(data=(batch_size, 1, 28, 28))
+ # CapsuleLayer
+ kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3)
+ bias_initializer = mx.init.Zero()
+ digitcaps = CapsuleLayer(num_capsule=10,
+ dim_vector=16,
+ batch_size=batch_size,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ num_routing=num_routing)(primarycaps)
+
+ # out_caps : (batch_size, 10)
+ out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), 2))
+ out_caps.infer_shape(data=(batch_size, 1, 28, 28))
+
+ y = mx.sym.Variable('softmax_label', shape=(batch_size,))
+ y_onehot = mx.sym.one_hot(y, n_class)
+ y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1))
+ y_reshaped.infer_shape(softmax_label=(batch_size,))
+
+ # inputs_masked : (batch_size, 16)
+ inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True)
+ inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0))
+ x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon')
+ x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act')
+ x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2')
+ x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2')
+ x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3')
+ x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3')
+
+ data_flatten = mx.sym.flatten(data=data)
+ squared_error = mx.sym.square(x_recon-data_flatten)
+ recon_error = mx.sym.mean(squared_error)
+ recon_error_stopped = recon_error
+ recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped)
+ loss = mx.symbol.MakeLoss((1-recon_loss_weight)*margin_loss(y_onehot, out_caps)+recon_loss_weight*recon_error)
+
+ out_caps_blocked = out_caps
+ out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked)
+ return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped])
+
+
+def download_data(url, force_download=False):
+ fname = url.split("/")[-1]
+ if force_download or not os.path.exists(fname):
+ mx.test_utils.download(url, fname)
+ return fname
+
+
+def read_data(label_url, image_url):
+ with gzip.open(download_data(label_url)) as flbl:
+ magic, num = struct.unpack(">II", flbl.read(8))
+ label = np.fromstring(flbl.read(), dtype=np.int8)
+ with gzip.open(download_data(image_url), 'rb') as fimg:
+ magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
+ image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
+ return label, image
+
+
+def to4d(img):
+ return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
+
+
+class LossMetric(mx.metric.EvalMetric):
+ def __init__(self, batch_size, num_gpu):
+ super(LossMetric, self).__init__('LossMetric')
+ self.batch_size = batch_size
+ self.num_gpu = num_gpu
+ self.sum_metric = 0
+ self.num_inst = 0
+ self.loss = 0.0
+ self.batch_sum_metric = 0
+ self.batch_num_inst = 0
+ self.batch_loss = 0.0
+ self.recon_loss = 0.0
+ self.n_batch = 0
+
+ def update(self, labels, preds):
+ batch_sum_metric = 0
+ batch_num_inst = 0
+ for label, pred_outcaps in zip(labels[0], preds[0]):
+ label_np = int(label.asnumpy())
+ pred_label = int(np.argmax(pred_outcaps.asnumpy()))
+ batch_sum_metric += int(label_np == pred_label)
+ batch_num_inst += 1
+ batch_loss = preds[1].asnumpy()
+ recon_loss = preds[2].asnumpy()
+ self.sum_metric += batch_sum_metric
+ self.num_inst += batch_num_inst
+ self.loss += batch_loss
+ self.recon_loss += recon_loss
+ self.batch_sum_metric = batch_sum_metric
+ self.batch_num_inst = batch_num_inst
+ self.batch_loss = batch_loss
+ self.n_batch += 1
+
+ def get_name_value(self):
+ acc = float(self.sum_metric)/float(self.num_inst)
+ mean_loss = self.loss / float(self.n_batch)
+ mean_recon_loss = self.recon_loss / float(self.n_batch)
+ return acc, mean_loss, mean_recon_loss
+
+ def get_batch_log(self, n_batch):
+ print("n_batch :"+str(n_batch)+" batch_acc:" +
+ str(float(self.batch_sum_metric) / float(self.batch_num_inst)) +
+ ' batch_loss:' + str(float(self.batch_loss)/float(self.batch_num_inst)))
+ self.batch_sum_metric = 0
+ self.batch_num_inst = 0
+ self.batch_loss = 0.0
+
+ def reset(self):
+ self.sum_metric = 0
+ self.num_inst = 0
+ self.loss = 0.0
+ self.recon_loss = 0.0
+ self.n_batch = 0
+
+
+class SimpleLRScheduler(mx.lr_scheduler.LRScheduler):
+ """A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr`
+ dynamically based on performance on the validation set.
+ """
+
+ def __init__(self, learning_rate=0.001):
+ super(SimpleLRScheduler, self).__init__()
+ self.learning_rate = learning_rate
+
+ def __call__(self, num_update):
+ return self.learning_rate
+
+
+def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay):
+ summary_writer = SummaryWriter(args.tblog_dir)
+ lr_scheduler = SimpleLRScheduler(learning_rate)
+ optimizer_params = {'lr_scheduler': lr_scheduler}
+ module.init_params()
+ module.init_optimizer(kvstore=kvstore,
+ optimizer=optimizer,
+ optimizer_params=optimizer_params)
+ n_epoch = 0
+ while True:
+ if n_epoch >= num_epoch:
+ break
+ train_iter.reset()
+ val_iter.reset()
+ loss_metric.reset()
+ for n_batch, data_batch in enumerate(train_iter):
+ module.forward_backward(data_batch)
+ module.update()
+ module.update_metric(loss_metric, data_batch.label)
+ loss_metric.get_batch_log(n_batch)
+ train_acc, train_loss, train_recon_err = loss_metric.get_name_value()
+ loss_metric.reset()
+ for n_batch, data_batch in enumerate(val_iter):
+ module.forward(data_batch)
+ module.update_metric(loss_metric, data_batch.label)
+ loss_metric.get_batch_log(n_batch)
+ val_acc, val_loss, val_recon_err = loss_metric.get_name_value()
+
+ summary_writer.add_scalar('train_acc', train_acc, n_epoch)
+ summary_writer.add_scalar('train_loss', train_loss, n_epoch)
+ summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch)
+ summary_writer.add_scalar('val_acc', val_acc, n_epoch)
+ summary_writer.add_scalar('val_loss', val_loss, n_epoch)
+ summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch)
+
+ print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, train_recon_err))
+ print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err))
+ print('SAVE CHECKPOINT')
+
+ module.save_checkpoint(prefix=model_prefix, epoch=n_epoch)
+ n_epoch += 1
+ lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch)
+
+
+def apply_transform(x,
+ transform_matrix,
+ fill_mode='nearest',
+ cval=0.):
+ x = np.rollaxis(x, 0, 0)
+ final_affine_matrix = transform_matrix[:2, :2]
+ final_offset = transform_matrix[:2, 2]
+ channel_images = [ndi.interpolation.affine_transform(
+ x_channel,
+ final_affine_matrix,
+ final_offset,
+ order=0,
+ mode=fill_mode,
+ cval=cval) for x_channel in x]
+ x = np.stack(channel_images, axis=0)
+ x = np.rollaxis(x, 0, 0 + 1)
+ return x
+
+
+def random_shift(x, width_shift_fraction, height_shift_fraction):
+ tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[2]
+ ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[1]
+ shift_matrix = np.array([[1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1]])
+ x = apply_transform(x, shift_matrix, 'nearest')
+ return x
+
+def _shuffle(data, idx):
+ """Shuffle the data."""
+ shuffle_data = []
+
+ for k, v in data:
+ shuffle_data.append((k, mx.ndarray.array(v.asnumpy()[idx], v.context)))
+
+ return shuffle_data
+
+class MNISTCustomIter(mx.io.NDArrayIter):
+
+ def reset(self):
+ # shuffle data
+ if self.is_train:
+ np.random.shuffle(self.idx)
+ self.data = _shuffle(self.data, self.idx)
+ self.label = _shuffle(self.label, self.idx)
+ if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
+ self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+ else:
+ self.cursor = -self.batch_size
+ def set_is_train(self, is_train):
+ self.is_train = is_train
+ def next(self):
+ if self.iter_next():
+ if self.is_train:
+ data_raw_list = self.getdata()
+ data_shifted = []
+ for data_raw in data_raw_list[0]:
+ data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1))
+ return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(),
+ pad=self.getpad(), index=None)
+ else:
+ return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \
+ pad=self.getpad(), index=None)
+
+ else:
+ raise StopIteration
+
+
+if __name__ == "__main__":
+ # Read mnist data set
+ path = 'http://yann.lecun.com/exdb/mnist/'
+ (train_lbl, train_img) = read_data(
+ path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz')
+ (val_lbl, val_img) = read_data(
+ path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz')
+ # set batch size
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--batch_size', default=100, type=int)
+ parser.add_argument('--devices', default='gpu0', type=str)
+ parser.add_argument('--num_epoch', default=100, type=int)
+ parser.add_argument('--lr', default=0.001, type=float)
+ parser.add_argument('--num_routing', default=3, type=int)
+ parser.add_argument('--model_prefix', default='capsnet', type=str)
+ parser.add_argument('--decay', default=0.9, type=float)
+ parser.add_argument('--tblog_dir', default='tblog', type=str)
+ parser.add_argument('--recon_loss_weight', default=0.392, type=float)
+ args = parser.parse_args()
+ for k, v in sorted(vars(args).items()):
+ print("{0}: {1}".format(k, v))
+ contexts = re.split(r'\W+', args.devices)
+ for i, ctx in enumerate(contexts):
+ if ctx[:3] == 'gpu':
+ contexts[i] = mx.context.gpu(int(ctx[3:]))
+ else:
+ contexts[i] = mx.context.cpu()
+ num_gpu = len(contexts)
+
+ if args.batch_size % num_gpu != 0:
+ raise Exception('num_gpu should be positive divisor of batch_size')
+
+ # generate train_iter, val_iter
+ train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=int(args.batch_size), shuffle=True)
+ train_iter.set_is_train(True)
+ val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=int(args.batch_size),)
+ val_iter.set_is_train(False)
+ # define capsnet
+ final_net = capsnet(batch_size=int(args.batch_size/num_gpu), n_class=10, num_routing=args.num_routing, recon_loss_weight=args.recon_loss_weight)
+ # set metric
+ loss_metric = LossMetric(args.batch_size/num_gpu, 1)
+
+ # run model
+ module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',))
+ module.bind(data_shapes=train_iter.provide_data,
+ label_shapes=val_iter.provide_label,
+ for_training=True)
+ do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr,
+ model_prefix=args.model_prefix, decay=args.decay)
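For readers who want to check the `margin_loss` function in the diff above by hand, here is a plain-Python sketch of the same arithmetic on nested lists instead of MXNet symbols. The constants 0.9, 0.1 and 0.5 come from the code in the diff; the parameter names `m_plus`, `m_minus` and `lam` are labels assumed here for illustration, and this is not the implementation used by the example.

```python
def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss summed over classes and averaged over the batch.

    y_true: list of one-hot rows; y_pred: list of capsule-length rows.
    """
    per_sample = []
    for t_row, p_row in zip(y_true, y_pred):
        total = 0.0
        for t, p in zip(t_row, p_row):
            # penalty when the correct class has length below m_plus
            total += t * max(0.0, m_plus - p) ** 2
            # down-weighted penalty when a wrong class has length above m_minus
            total += lam * (1 - t) * max(0.0, p - m_minus) ** 2
        per_sample.append(total)
    return sum(per_sample) / len(per_sample)


# One sample, three classes, class 0 is correct:
# (0.9 - 0.5)^2 + 0.5 * (0.3 - 0.1)^2 + 0 is approximately 0.18
print(margin_loss([[1, 0, 0]], [[0.5, 0.3, 0.1]]))
```

A prediction that already satisfies both margins, such as `[[0.9, 0.1, 0.1]]` for the same label, gives a loss of zero.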