Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/27 01:17:12 UTC

[GitHub] eric-haibin-lin closed pull request #10391: [MXNET-139] Tutorial for mixed precision training with float16

URL: https://github.com/apache/incubator-mxnet/pull/10391
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/docs/faq/float16.md b/docs/faq/float16.md
new file mode 100644
index 00000000000..cbb308f6918
--- /dev/null
+++ b/docs/faq/float16.md
@@ -0,0 +1,166 @@
+# Mixed precision training using float16
+
+In this tutorial you will walk through how to train deep neural networks with mixed precision on supported hardware. You will first see how to use float16 with both the Gluon and Symbolic APIs, and then some techniques for achieving good performance and accuracy.
+
+## Background
+The computational resources required for training deep neural networks have been increasing because of the growing complexity of architectures and size of models. Mixed precision training reduces these requirements by using lower precision arithmetic. In this approach you train using 16 bit floating point numbers (half precision) while using 32 bit floating point numbers (single precision) for the output buffers of float16 computation. This combination of single and half precision gives rise to the name mixed precision. It can achieve the same accuracy as training with single precision, while decreasing the required memory and the training or inference time.
+
+The float16 data type is the 16 bit floating point format (binary16) defined by the IEEE 754 standard. Its precision varies across its dynamic range: the gap between adjacent representable values goes from about 0.0000000596046 (2^-24, the finest, for values closest to 0) to 32 (the coarsest, for values between 32768 and 65504, the largest finite float16 value). Despite its reduced precision compared to single precision (float32), float16 has many advantages. The most obvious is that it halves the memory needed for the model, which allows training larger models and using larger batch sizes. The reduced memory footprint also lowers the pressure on memory bandwidth and reduces communication costs. On hardware with specialized support for float16 computation you can also greatly improve the speed of training and inference. The Volta range of Graphics Processing Units (GPUs) from Nvidia has [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensorcore/) which perform efficient float16 computation. A Tensor Core accumulates half precision products into single or half precision outputs. For the rest of this tutorial we assume that we are working with Nvidia's Tensor Cores on a Volta GPU.
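To make the figures above concrete, here is a small sketch that inspects float16 with NumPy. NumPy is an assumption here; it is not used elsewhere in this tutorial. Its `float16` type follows the same IEEE 754 half precision format. The helper `float16_gaps` is hypothetical and only for illustration.

```python
import numpy as np


def float16_gaps(values):
    """Return the distance from each value to the next larger float16 value."""
    return [float(np.spacing(np.float16(v))) for v in values]


info = np.finfo(np.float16)
print("largest finite value:", float(info.max))    # 65504.0
print("smallest normal value:", float(info.tiny))  # 2**-14

# Near 0 the gap is the smallest subnormal (2**-24); from 32768 upwards it is 32.
print(float16_gaps([0.0, 1.0, 1024.0, 32768.0]))

with np.errstate(over="ignore"):
    print(float(np.float16(2.0 ** -26)))  # too small: rounds to 0.0
    print(float(np.float16(70000.0)))     # too large: becomes inf
```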
+
+## Prerequisites
+- A GPU from Nvidia's Volta range
+- CUDA 9 or higher
+- cuDNN v7 or higher
+
+This tutorial also assumes that you understand how to train a network with float32. Please refer to other tutorials [here](http://mxnet.incubator.apache.org/tutorials/index.html) to get started with MXNet and/or Gluon. This tutorial focuses on the changes needed to switch from float32 to mixed precision and tips on achieving the best performance with mixed precision.
+
+## Using the Gluon API
+
+### Training or Inference
+
+With Gluon, you need to take care of three things to convert a model to support float16.
+
+1. Cast the Gluon Block to float16. This casts the parameters of its layers and changes the type of input it expects. It is as simple as calling the [cast](https://mxnet.incubator.apache.org/api/python/gluon/gluon.html#mxnet.gluon.Block.cast) method of the Block representing the network. Note that `cast` modifies the Block in place and does not return it.
+```python
+net.cast('float16')
+```
+
+2. Ensure the data input to the network is of float16 type. If your DataLoader or Iterator produces output in another datatype, then you would have to cast your data. There are different ways you can do this. The easiest would be to use the [`astype`](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.astype) method of ndarrays.
+```
+data = data.astype('float16', copy=False)
+```
+
+If you are using images and DataLoader, you can also use a [Cast transform](https://mxnet.incubator.apache.org/api/python/gluon/data.html#mxnet.gluon.data.vision.transforms.Cast).
+
+3. It is preferable to use the **multi_precision mode of the optimizer** when training in float16. In this mode the optimizer maintains a master copy of the weights in float32 even when the training (i.e. forward and backward pass) is in float16. This increases the precision of the weight updates and can lead to faster convergence for some networks. (This is discussed further towards the end of this tutorial.)
+
+```python
+optimizer = mx.optimizer.create('sgd', multi_precision=True, lr=0.01)
+```
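The benefit of a float32 master copy can be illustrated without MXNet. The following is a simplified NumPy sketch, not MXNet's actual optimizer implementation. `sgd_steps` is a hypothetical helper that applies the same small SGD update in one of two ways. It either updates a float16 weight directly, or it updates a float32 master copy that is then cast to float16, roughly what `multi_precision=True` does.

```python
import numpy as np


def sgd_steps(lr, grad, steps, keep_master_copy):
    """Run plain SGD on a single weight that is used in float16.

    keep_master_copy=True: apply each update to a float32 copy and re-derive
    the float16 weight from it (a rough model of multi_precision=True).
    keep_master_copy=False: apply each update to the float16 weight directly.
    """
    weight16 = np.float16(1.0)
    master32 = np.float32(weight16)
    for _ in range(steps):
        update = np.float32(lr) * np.float32(grad)
        if keep_master_copy:
            master32 = master32 - update
            weight16 = np.float16(master32)
        else:
            weight16 = np.float16(weight16 - np.float16(update))
    return float(weight16)


print(sgd_steps(0.01, 0.01, 100, keep_master_copy=False))  # stays at 1.0
print(sgd_steps(0.01, 0.01, 100, keep_master_copy=True))   # close to 0.99
```

With a learning rate and gradient of 0.01, each update (0.0001) is smaller than half the float16 spacing just below 1.0 (about 0.00049). The float16-only weight therefore never moves, while the float32 master copy accumulates the updates.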
+
+You can play around with mixed precision using the image classification example [here](https://github.com/apache/incubator-mxnet/blob/master/example/gluon/image_classification.py). We suggest using the Caltech101 dataset option in that example and using a Resnet50_v1 network so you can quickly see the performance improvement and how the accuracy is unaffected. Here's a starter command to run this.
+
+```
+python image_classification.py --model resnet50_v1 --dataset caltech101 --gpus 0 --num-worker 30 --dtype float16
+```
+
+
+### Fine-tuning
+
+You can also fine-tune a model in float16 that was originally trained in float32. For example, to use a model from the ModelZoo pretrained on the Imagenet dataset, you would first fetch the pretrained network and then cast that network to float16.
+
+```python
+from mxnet.gluon.model_zoo import vision as models
+pretrained_net = models.get_model(name='resnet50_v2', ctx=ctx, pretrained=True, classes=1000)
+pretrained_net.cast('float16')
+```
+Then, if you have another Resnet50_v2 model you want to fine-tune (here with 101 output classes), you can assign the pretrained features to that network and then cast it.
+
+```python
+net = models.get_model(name='resnet50_v2', ctx=ctx, pretrained=False, classes=101)
+net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
+net.features = pretrained_net.features
+net.cast('float16')
+```
+
+## Using the Symbolic API
+
+Training a network in float16 with the Symbolic API involves the following steps.
+1. Add a layer at the beginning of the network, to cast the data to float16. This will ensure that all the following layers compute in float16.
+2. It is advisable to cast the output of the layers before softmax to float32, so that the softmax computation is done in float32. This is because softmax involves large reductions, and keeping them in float32 gives a more precise answer.
+3. It is advisable to use the multi-precision mode of the optimizer for more precise weight updates. This is discussed in some detail below. Here's how you would enable this mode when creating an optimizer.
+
+```python
+optimizer = mx.optimizer.create('sgd', multi_precision=True, lr=0.01)
+```
+
+There are a few examples of building such networks which can handle float16 input in [examples/image-classification/symbols/](https://github.com/apache/incubator-mxnet/tree/master/example/image-classification/symbols). Specifically you could look at the [resnet](https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/symbols/resnet.py) example.
+
+An illustration of the relevant section of the code is below.
+```python
+import mxnet as mx
+import numpy as np
+
+dtype = 'float16'
+data = mx.sym.Variable(name="data")
+if dtype == 'float16':
+    data = mx.sym.Cast(data=data, dtype=np.float16)
+
+# the rest of the network; `net` stands for the function that builds it
+net_out = net(data)
+
+if dtype == 'float16':
+    net_out = mx.sym.Cast(data=net_out, dtype=np.float32)
+output = mx.sym.SoftmaxOutput(data=net_out, name='softmax')
+```
+
+We have an example script which shows how to train Resnet50 on Imagenet using float16 [here](https://github.com/apache/incubator-mxnet/tree/master/example/image-classification/train_imagenet.py).
+
+Here's how you can use the above script to train Resnet50 v1 model with synthetic data using float16, so you can try it out even if you don't have the Imagenet dataset handy.
+```
+python train_imagenet.py --network resnet-v1 --num-layers 50 --benchmark 1 --gpus 0 --batch-size 256 --dtype float16
+```
+
+There's a similar example for fine-tuning [here](https://github.com/apache/incubator-mxnet/tree/master/example/image-classification/fine-tune.py). The following command shows how to use that script to fine-tune a Resnet50 model trained on Imagenet for the Caltech 256 dataset using float16.
+```
+python fine-tune.py --network resnet --num-layers 50 --pretrained-model imagenet1k-resnet-50 --data-train ~/data/caltech-256/caltech256-train.rec --data-val ~/data/caltech-256/caltech256-val.rec --num-examples 15420 --num-classes 256 --gpus 0 --batch-size 64 --dtype float16
+```
+
+## Example training results
+Here is a plot comparing the training curves of a Resnet50 v1 network on the Imagenet 2012 dataset. These training jobs used Gluon and ran for 95 epochs with a batch size of 1024, using a learning rate of 0.4 decayed by a factor of 1 at epochs 30, 60 and 90. The only changes made for the float16 job compared to the float32 job were that the network and data were cast to float16, and the multi-precision mode was used for the optimizer. The final accuracies at the 95th epoch were **76.598% for float16** and **76.486% for float32**. The difference is within normal random variation, and there is no reason to expect float16 to have better accuracy than float32 in general. This run was approximately **65% faster** to train with float16.
+
+![Training curves of Resnet50 v1 on Imagenet 2012](https://raw.githubusercontent.com/rahul003/web-data/03929a8beb8ac574f2392ed34cc6d4b2f052826a/mxnet/tutorials/mixed-precision/resnet50v1b_imagenet_fp16_fp32_training.png)
+
+## Things to keep in mind
+
+### For performance
+
+Performance gains from float16 typically range from 1.6x to 2x for convolutional networks like Resnet, and can reach about 3x for networks with LSTMs. The gain you see depends on several factors, described below.
+
+1. Nvidia Tensor Cores essentially perform the computation D = A * B + C, where A and B are half precision matrices, while C and D can be either half or full precision. Tensor Cores are most efficient when the dimensions of these matrices are multiples of 8, which means they cannot be used in all cases for fast float16 computation. When training models like Resnet50 on the Cifar10 dataset, some of the tensors involved are small, and Tensor Cores cannot always be used. The computation then falls back to slower algorithms, and float16 turns out to be slower than float32 on a single GPU. Note that when using multiple GPUs, float16 can still be faster than float32 because of the reduced communication costs.
+
+2. When you scale up the batch size, ensure that IO and data pre-processing are not your bottleneck. If you see a slowdown, this is the first thing to check.
+
+3. For the reason given in point 1, it is advisable to use batch sizes that are multiples of 8 when training with float16. As always, batch sizes that are powers of 2 tend to perform better than nearby sizes.
+
+4. You can check whether your program is using Tensor Cores for fast float16 computation by profiling with `nvprof`.
+Operations with `s884cudnn` in their names indicate the use of Tensor Cores.
+
+5. When not limited by GPU memory, it can help to set the environment variable `MXNET_CUDNN_AUTOTUNE_DEFAULT` to 2. This configures MXNet to run tuning tests and choose the fastest convolution algorithm, even one whose memory requirements exceed the default CUDA workspace size.
+
+6. Note that float16 on CPU might not be supported by all operators, and in most cases float16 on CPU is much slower than float32.
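Points 1 and 3 above come down to choosing sizes that are multiples of 8. Below is a minimal sketch of two hypothetical helpers, not part of MXNet, that check and pad such dimensions.

```python
def round_up_to_multiple(n, multiple=8):
    """Return the smallest value >= n that is a multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


def tensor_core_friendly(*dims, multiple=8):
    """Return True if every given dimension is a multiple of `multiple`."""
    return all(d % multiple == 0 for d in dims)


print(round_up_to_multiple(250))        # 256
print(tensor_core_friendly(256, 1000))  # True
print(tensor_core_friendly(256, 101))   # False
```

For instance, the 101-class output layer in the Caltech101 examples is not a multiple of 8, whereas the 1000-class Imagenet output is.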
+
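Point 5 can also be applied from inside a training script. This is a minimal sketch. It assumes the variable only needs to be in the process environment before MXNet is imported. Exporting it in the shell before launching the script works as well.

```python
import os

# Assumption: set this before `import mxnet` so MXNet sees it when it runs.
# A value of 2 asks MXNet to autotune and pick the fastest convolution
# algorithm even if it needs more than the default workspace (point 5 above).
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "2"

# import mxnet as mx  # import MXNet only after the variable is set
```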
+
+### For accuracy
+
+#### Multi precision mode
+When training in float16, it is advisable to still store the master copy of the weights in float32 for better accuracy. The higher precision of float32 helps in cases where a weight update would become 0 if represented in float16. This mode can be activated by setting the `multi_precision` parameter of the optimizer to `True`, as in the examples above. Not all networks need this mode to reach the same accuracy as float32, but it is nevertheless recommended. Note that for distributed training, this mode is currently slightly slower than training without `multi_precision`, but still much faster than training in float32.
+
+#### Large reductions
+Since float16 has low precision for large numbers, it is best to leave layers which perform large reductions in float32. These include BatchNorm and Softmax. Both the Gluon and Module APIs ensure by default that BatchNorm performs its reduction in float32. Softmax is also set to use float32 during float16 training in Gluon, but in the Module API you need to cast to float32 before softmax, as the symbolic example code above shows.
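A plain running sum shows why reductions are sensitive. The sketch below uses NumPy rather than MXNet. `running_sum` is a hypothetical helper that adds ones one at a time into a float16 or float32 accumulator. Real BatchNorm and Softmax kernels do not necessarily sum this way, so treat it only as an illustration of float16 precision loss.

```python
import numpy as np


def running_sum(values, accumulator_dtype):
    """Add values one at a time into an accumulator of the given dtype."""
    total = accumulator_dtype(0)
    for v in values:
        total = accumulator_dtype(total + accumulator_dtype(v))
    return float(total)


values = np.ones(4096, dtype=np.float16)
print(running_sum(values, np.float16))  # 2048.0: the sum stalls
print(running_sum(values, np.float32))  # 4096.0
```

Once the float16 total reaches 2048, the gap between neighbouring float16 values is 2. Adding 1 therefore rounds back to 2048, and the sum stops growing.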
+
+#### Loss scaling
+For some networks, switching the training to float16 alone was not found to be enough to reach the same accuracy as training with float32. This is because some activation gradients are too small to be represented in the float16 range. Such networks can be made to reach float32 accuracy with a couple of changes.
+
+Activation gradients generally use only a small part of the float16 representable range. You can therefore shift the gradients into the float16 range by scaling up the loss by a factor `S` before the backward pass: by the chain rule, this scales every gradient by `S` as well. You then scale the gradients back down by `1/S` before updating the weights. This allows training in float16 to use the same hyperparameters as float32 training.
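Here is a minimal NumPy simulation of this idea, not MXNet code. `fp16_backward` is a hypothetical helper with three steps. It scales a float32 gradient by `S`, stores it in float16 as the backward pass would, and unscales it in float32, which is the role of the `rescale_grad` setting shown below.

```python
import numpy as np


def fp16_backward(grad_fp32, loss_scale):
    """Simulate a gradient that passes through float16 during backprop.

    The gradient is multiplied by loss_scale (the effect of scaling the loss),
    stored in float16, converted back to float32 and divided by loss_scale.
    """
    scaled = np.float16(np.float32(grad_fp32) * np.float32(loss_scale))
    return float(np.float32(scaled) / np.float32(loss_scale))


true_grad = 2e-8
print(fp16_backward(true_grad, loss_scale=1))    # 0.0: underflows in float16
print(fp16_backward(true_grad, loss_scale=128))  # close to 2e-8
```

With `S = 1`, a gradient of 2e-8 falls below half the smallest float16 subnormal (about 2.98e-8) and becomes 0. With `S = 128` it survives with a relative error of roughly 0.1%.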
+
+Here's how you can configure the loss to be scaled up by 128 and rescale the gradient down before updating the weights.
+
+*Gluon*
+```python
+loss = gluon.loss.SoftmaxCrossEntropyLoss(weight=128)
+optimizer = mx.optimizer.create('sgd', multi_precision=True, rescale_grad=1.0/128)
+```
+*Module*
+```python
+output = mx.sym.SoftmaxOutput(data=net_out, name='softmax', grad_scale=128.0)
+optimizer = mx.optimizer.create('sgd', multi_precision=True, rescale_grad=1.0/128)
+```
+
+Networks like Multibox SSD, R-CNN, bigLSTM and Seq2seq were found to exhibit such behavior.
+You can choose a constant scaling factor, as long as the largest absolute gradient multiplied by this factor stays within the float16 range. Powers of 2 such as 64, 128, 256 and 512 are generally chosen. Refer to the linked articles below for more details.
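One way to apply this rule is sketched below. It assumes you have measured the largest absolute gradient, for example from a float32 run. `pick_loss_scale` is a hypothetical helper, not an MXNet API, and it only covers a constant scaling factor.

```python
import numpy as np

FLOAT16_MAX = float(np.finfo(np.float16).max)  # 65504.0, largest finite float16


def pick_loss_scale(max_abs_grad, candidates=(64, 128, 256, 512)):
    """Return the largest candidate that keeps max_abs_grad * scale within float16.

    Falls back to 1 (no scaling) if even the smallest candidate would overflow.
    """
    for scale in sorted(candidates, reverse=True):
        if max_abs_grad * scale <= FLOAT16_MAX:
            return scale
    return 1


print(pick_loss_scale(10.0))   # 512
print(pick_loss_scale(200.0))  # 256
```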
+
+## Video Tutorial
+
+We also have a video tutorial on using Mixed Precision with MXNet. You can check it out [here](https://www.youtube.com/watch?v=pR4KMh1lGC0).
+
+## References
+1. [Training with Mixed Precision User Guide](http://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
+2. [Mixed Precision Training at ICLR 2018](https://arxiv.org/pdf/1710.03740.pdf)
+3. [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/)
+
diff --git a/docs/faq/index.md b/docs/faq/index.md
index ebe16e2bb99..c351bc90dba 100644
--- a/docs/faq/index.md
+++ b/docs/faq/index.md
@@ -32,6 +32,7 @@ and full working examples, visit the [tutorials section](../tutorials/index.md).
 
 * [What are the best setup and data-handling tips and tricks for improving speed?](http://mxnet.io/faq/perf.html)
 
+* [How do I use mixed precision with MXNet or Gluon?](http://mxnet.io/faq/float16.html)
 
 ## Deployment Environments
 * [Can I run MXNet on smart or mobile devices?](http://mxnet.io/faq/smart_device.html)
diff --git a/example/gluon/data.py b/example/gluon/data.py
index 56e89065afe..5a7f9e6ca09 100644
--- a/example/gluon/data.py
+++ b/example/gluon/data.py
@@ -88,6 +88,40 @@ def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='
     val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
     return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
 
+def get_caltech101_data():
+    url = "https://s3.us-east-2.amazonaws.com/mxnet-public/101_ObjectCategories.tar.gz"
+    dataset_name = "101_ObjectCategories"
+    data_folder = "data"
+    if not os.path.isdir(data_folder):
+        os.makedirs(data_folder)
+    tar_path = mx.gluon.utils.download(url, path=data_folder)
+    if (not os.path.isdir(os.path.join(data_folder, dataset_name)) or
+        not os.path.isdir(os.path.join(data_folder, "{}_test".format(dataset_name)))):
+        tar = tarfile.open(tar_path, "r:gz")
+        tar.extractall(data_folder)
+        tar.close()
+        print('Data extracted')
+    training_path = os.path.join(data_folder, dataset_name)
+    testing_path = os.path.join(data_folder, "{}_test".format(dataset_name))
+    return training_path, testing_path
+
+def get_caltech101_iterator(batch_size, num_workers, dtype):
+    def transform(image, label):
+        # resize the shorter edge to 224, the longer edge will be greater or equal to 224
+        resized = mx.image.resize_short(image, 224)
+        # center and crop an area of size (224,224)
+        cropped, crop_info = mx.image.center_crop(resized, 224)
+        # transpose the channels to be (3,224,224)
+        transposed = mx.nd.transpose(cropped, (2, 0, 1))
+        image = mx.nd.cast(transposed, dtype)
+        return image, label
+
+    training_path, testing_path = get_caltech101_data()
+    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
+    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)
+
+    train_data = gluon.data.DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
+    test_data = gluon.data.DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
+    return DataLoaderIter(train_data, dtype), DataLoaderIter(test_data, dtype)
 
 class DummyIter(mx.io.DataIter):
     def __init__(self, batch_size, data_shape, batches = 100):
diff --git a/example/gluon/image_classification.py b/example/gluon/image_classification.py
index b21e943f17f..fe0a346f42d 100644
--- a/example/gluon/image_classification.py
+++ b/example/gluon/image_classification.py
@@ -47,13 +47,13 @@
 # CLI
 parser = argparse.ArgumentParser(description='Train a model for image classification.')
 parser.add_argument('--dataset', type=str, default='cifar10',
-                    help='dataset to use. options are mnist, cifar10, imagenet and dummy.')
+                    help='dataset to use. options are mnist, cifar10, caltech101, imagenet and dummy.')
 parser.add_argument('--data-dir', type=str, default='',
-                    help='training directory of imagenet images, contains train/val subdirs.')
+                    help='training directory of imagenet images, contains train/val subdirs.')
+parser.add_argument('--num-worker', '-j', dest='num_workers', default=4, type=int,
+                    help='number of workers for dataloader')
 parser.add_argument('--batch-size', type=int, default=32,
                     help='training batch size per device (CPU/GPU).')
-parser.add_argument('--num-worker', '-j', dest='num_workers', default=4, type=int,
-                    help='number of workers of dataloader.')
 parser.add_argument('--gpus', type=str, default='',
                     help='ordinates of gpus to use, can be "0,1,2" or empty for cpu only.')
 parser.add_argument('--epochs', type=int, default=120,
@@ -104,13 +104,14 @@
 logger.info('Starting new image-classification task:, %s',opt)
 mx.random.seed(opt.seed)
 model_name = opt.model
-dataset_classes = {'mnist': 10, 'cifar10': 10, 'imagenet': 1000, 'dummy': 1000}
+dataset_classes = {'mnist': 10, 'cifar10': 10, 'caltech101': 101, 'imagenet': 1000, 'dummy': 1000}
 batch_size, dataset, classes = opt.batch_size, opt.dataset, dataset_classes[opt.dataset]
 context = [mx.gpu(int(i)) for i in opt.gpus.split(',')] if opt.gpus.strip() else [mx.cpu()]
 num_gpus = len(context)
 batch_size *= max(1, num_gpus)
 lr_steps = [int(x) for x in opt.lr_steps.split(',') if x.strip()]
 metric = CompositeEvalMetric([Accuracy(), TopKAccuracy(5)])
+kv = mx.kv.create(opt.kvstore)
 
 def get_model(model, ctx, opt):
     """Model initialization."""
@@ -133,37 +134,39 @@ def get_model(model, ctx, opt):
 
 net = get_model(opt.model, context, opt)
 
-def get_data_iters(dataset, batch_size, num_workers=1, rank=0):
+def get_data_iters(dataset, batch_size, opt):
     """get dataset iterators"""
     if dataset == 'mnist':
         train_data, val_data = get_mnist_iterator(batch_size, (1, 28, 28),
-                                                  num_parts=num_workers, part_index=rank)
+                                                  num_parts=kv.num_workers, part_index=kv.rank)
     elif dataset == 'cifar10':
         train_data, val_data = get_cifar10_iterator(batch_size, (3, 32, 32),
-                                                    num_parts=num_workers, part_index=rank)
+                                                    num_parts=kv.num_workers, part_index=kv.rank)
     elif dataset == 'imagenet':
+        shape_dim = 299 if model_name == 'inceptionv3' else 224
+
         if not opt.data_dir:
-            raise ValueError('Dir containing raw images in train/val is required for imagenet, plz specify "--data-dir"')
-        if model_name == 'inceptionv3':
-            train_data, val_data = get_imagenet_iterator(opt.data_dir, batch_size, opt.num_workers, 299, opt.dtype)
-        else:
-            train_data, val_data = get_imagenet_iterator(opt.data_dir, batch_size, opt.num_workers, 224, opt.dtype)
+            raise ValueError('Dir containing raw images in train/val is required for imagenet. '
+                             'Please specify "--data-dir"')
+
+        train_data, val_data = get_imagenet_iterator(opt.data_dir, batch_size,
+                                                     opt.num_workers, shape_dim, opt.dtype)
+    elif dataset == 'caltech101':
+        train_data, val_data = get_caltech101_iterator(batch_size, opt.num_workers, opt.dtype)
     elif dataset == 'dummy':
-        if model_name == 'inceptionv3':
-            train_data, val_data = dummy_iterator(batch_size, (3, 299, 299))
-        else:
-            train_data, val_data = dummy_iterator(batch_size, (3, 224, 224))
+        shape_dim = 299 if model_name == 'inceptionv3' else 224
+        train_data, val_data = dummy_iterator(batch_size, (3, shape_dim, shape_dim))
     return train_data, val_data
 
 def test(ctx, val_data):
     metric.reset()
     val_data.reset()
     for batch in val_data:
-        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
-        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
-        outputs = []
-        for x in data:
-            outputs.append(net(x))
+        data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype, copy=False),
+                                          ctx_list=ctx, batch_axis=0)
+        label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype, copy=False),
+                                           ctx_list=ctx, batch_axis=0)
+        outputs = [net(X) for X in data]
         metric.update(label, outputs)
     return metric.get()
 
@@ -187,16 +190,17 @@ def save_checkpoint(epoch, top1, best_acc):
 def train(opt, ctx):
     if isinstance(ctx, mx.Context):
         ctx = [ctx]
-    kv = mx.kv.create(opt.kvstore)
-    train_data, val_data = get_data_iters(dataset, batch_size, kv.num_workers, kv.rank)
+
+    train_data, val_data = get_data_iters(dataset, batch_size, opt)
     net.collect_params().reset_ctx(ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd',
-                            {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum,
-                             'multi_precision': True},
-                            kvstore = kv)
+                            optimizer_params={'learning_rate': opt.lr,
+                                              'wd': opt.wd,
+                                              'momentum': opt.momentum,
+                                              'multi_precision': True},
+                            kvstore=kv)
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
 
-
     total_time = 0
     num_epochs = 0
     best_acc = [0]
@@ -253,13 +257,16 @@ def main():
         profiler.set_state('run')
     if opt.mode == 'symbolic':
         data = mx.sym.var('data')
+        if opt.dtype == 'float16':
+            data = mx.sym.Cast(data=data, dtype=np.float16)
         out = net(data)
+        if opt.dtype == 'float16':
+            out = mx.sym.Cast(data=out, dtype=np.float32)
         softmax = mx.sym.SoftmaxOutput(out, name='softmax')
         mod = mx.mod.Module(softmax, context=context)
-        kv = mx.kv.create(opt.kvstore)
-        train_data, val_data = get_data_iters(dataset, batch_size, kv.num_workers, kv.rank)
+        train_data, val_data = get_data_iters(dataset, batch_size, opt)
         mod.fit(train_data,
-                eval_data = val_data,
+                eval_data=val_data,
                 num_epoch=opt.epochs,
                 kvstore=kv,
                 batch_end_callback = mx.callback.Speedometer(batch_size, max(1, opt.log_interval)),
diff --git a/example/image-classification/benchmark_score.py b/example/image-classification/benchmark_score.py
index 82903b63238..0d47d859d00 100644
--- a/example/image-classification/benchmark_score.py
+++ b/example/image-classification/benchmark_score.py
@@ -27,7 +27,7 @@
 import numpy as np
 logging.basicConfig(level=logging.DEBUG)
 
-def get_symbol(network, batch_size):
+def get_symbol(network, batch_size, dtype):
     image_shape = (3,299,299) if network == 'inception-v3' else (3,224,224)
     num_layers = 0
     if 'resnet' in network:
@@ -37,14 +37,15 @@ def get_symbol(network, batch_size):
         num_layers = int(network.split('-')[1])
         network = 'vgg'
     net = import_module('symbols.'+network)
-    sym = net.get_symbol(num_classes = 1000,
-                         image_shape = ','.join([str(i) for i in image_shape]),
-                         num_layers  = num_layers)
+    sym = net.get_symbol(num_classes=1000,
+                         image_shape=','.join([str(i) for i in image_shape]),
+                         num_layers=num_layers,
+                         dtype=dtype)
     return (sym, [('data', (batch_size,)+image_shape)])
 
-def score(network, dev, batch_size, num_batches):
+def score(network, dev, batch_size, num_batches, dtype):
     # get mod
-    sym, data_shape = get_symbol(network, batch_size)
+    sym, data_shape = get_symbol(network, batch_size, dtype)
     mod = mx.mod.Module(symbol=sym, context=dev)
     mod.bind(for_training     = False,
              inputs_need_grad = False,
@@ -74,11 +75,17 @@ def score(network, dev, batch_size, num_batches):
     devs.append(mx.cpu())
 
     batch_sizes = [1, 2, 4, 8, 16, 32]
-
     for net in networks:
         logging.info('network: %s', net)
         for d in devs:
             logging.info('device: %s', d)
             for b in batch_sizes:
-                speed = score(network=net, dev=d, batch_size=b, num_batches=10)
-                logging.info('batch size %2d, image/sec: %f', b, speed)
+                for dtype in ['float32', 'float16']:
+                    if d == mx.cpu() and dtype == 'float16':
+                        # float16 is not supported on CPU
+                        continue
+                    elif net in ['inception-bn', 'alexnet'] and dtype == 'float16':
+                        logging.info('{} does not support float16'.format(net))
+                    else:
+                        speed = score(network=net, dev=d, batch_size=b, num_batches=10, dtype=dtype)
+                        logging.info('batch size %2d, dtype %s, image/sec: %f', b, dtype, speed)
diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index ca1c09ab08a..d5ca87f081b 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -202,6 +202,8 @@ def fit(args, network, data_loader, **kwargs):
     # learning rate
     lr, lr_scheduler = _get_lr_scheduler(args, kv)
 
+    network = mx.symbol.load('%s-symbol.json')
+
     # create model
     model = mx.mod.Module(
         context=devs,
diff --git a/example/image-classification/fine-tune.py b/example/image-classification/fine-tune.py
index a5fb2434d95..2a0c0ec99f6 100644
--- a/example/image-classification/fine-tune.py
+++ b/example/image-classification/fine-tune.py
@@ -22,8 +22,10 @@
 from common import find_mxnet
 from common import data, fit, modelzoo
 import mxnet as mx
+import numpy as np
 
-def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
+
+def get_fine_tune_model(symbol, arg_params, num_classes, layer_name, dtype='float32'):
     """
     symbol: the pre-trained network symbol
     arg_params: the argument parameters of the pre-trained model
@@ -33,11 +35,12 @@ def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
     all_layers = symbol.get_internals()
     net = all_layers[layer_name+'_output']
     net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc')
+    if dtype == 'float16':
+        net = mx.sym.Cast(data=net, dtype=np.float32)
     net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
     new_args = dict({k:arg_params[k] for k in arg_params if 'fc' not in k})
     return (net, new_args)
 
-
 if __name__ == "__main__":
     # parse args
     parser = argparse.ArgumentParser(description="fine-tune a dataset",
@@ -46,18 +49,24 @@ def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
     data.add_data_args(parser)
     aug = data.add_data_aug_args(parser)
     parser.add_argument('--pretrained-model', type=str,
-                        help='the pre-trained model')
+                        help='the pre-trained model. Can be the prefix of local model files '
+                             'or a model name from common/modelzoo')
     parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
-                        help='the name of the layer before the last fullc layer')
+                        help='the name of the layer before the last fullc layer')
+
     # use less augmentations for fine-tune
     data.set_data_aug_level(parser, 1)
     # use a small learning rate and less regularizations
-    parser.set_defaults(image_shape='3,224,224', num_epochs=30,
-                        lr=.01, lr_step_epochs='20', wd=0, mom=0)
-
+    parser.set_defaults(image_shape='3,224,224',
+                        num_epochs=30,
+                        lr=.01,
+                        lr_step_epochs='20',
+                        wd=0,
+                        mom=0)
     args = parser.parse_args()
 
-    # load pretrained model
+
+    # load pretrained model and params
     dir_path = os.path.dirname(os.path.realpath(__file__))
     (prefix, epoch) = modelzoo.download_model(
         args.pretrained_model, os.path.join(dir_path, 'model'))
@@ -65,10 +74,26 @@ def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
         (prefix, epoch) = (args.pretrained_model, args.load_epoch)
     sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
 
-    # remove the last fullc layer
-    (new_sym, new_args) = get_fine_tune_model(
-        sym, arg_params, args.num_classes, args.layer_before_fullc)
+    if args.dtype != 'float32':
+        # load symbol of trained network, so we can cast it to support other dtype
+        # fine tuning a network in a datatype which was not used for training originally,
+        # requires access to the code used to generate the symbol used to train that model.
+        # we then need to modify the symbol to add a layer at the beginning
+        # to cast data to that dtype. We also need to cast output of layers before softmax
+        # to float32 so that softmax can still be in float32.
+        # if the network chosen from the symbols/ folder doesn't have a cast for the new
+        # datatype, it would still train in fp32, so such networks are rejected below
+        if args.network not in ['inception-v3', 'inception-v4', 'resnet-v1',
+                                'resnet', 'resnext', 'vgg']:
+            raise ValueError('Given network does not have support for dtypes other than float32. '
+                             'Please add a cast layer at the beginning to train in that mode.')
+        from importlib import import_module
+        net = import_module('symbols.'+args.network)
+        sym = net.get_symbol(**vars(args))
 
+    # remove the last fullc layer and add a new softmax layer
+    (new_sym, new_args) = get_fine_tune_model(sym, arg_params, args.num_classes,
+                                              args.layer_before_fullc, args.dtype)
     # train
     fit.fit(args        = args,
             network     = new_sym,


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services