Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/06 23:27:35 UTC

[GitHub] indhub commented on a change in pull request #11002: [MXNET-433] Tutorial on saving and loading gluon models

URL: https://github.com/apache/incubator-mxnet/pull/11002#discussion_r193589063
 
 

 ##########
 File path: docs/tutorials/gluon/save_load_params.md
 ##########
 @@ -0,0 +1,269 @@
+# Saving and Loading Gluon Models
+
+Training large models takes a lot of time, so it is a good idea to save trained models to files to avoid training them again and again. There are a number of reasons to do this. For example, you might want to do inference on a machine that is different from the one where the model was trained. Sometimes a model's performance on the validation set decreases towards the end of training because of overfitting. If you save your model parameters after every epoch, you can decide at the end to use the model that performs best on the validation set.
+
+In this tutorial we will learn ways to save and load Gluon models. There are two ways to save/load Gluon models:
+
+**1. Save/load model parameters only**
+
+Parameters of any Gluon model can be saved using the `save_params` method and loaded using the `load_params` method. This does not save the model architecture. This approach is used to save the parameters of dynamic (non-hybrid) models. Model architecture cannot be saved for dynamic models because it changes during execution.
+
+**2. Save/load model parameters AND architecture**
+
+The model architecture of `Hybrid` models stays static and doesn't change during execution. Therefore, both model parameters AND architecture can be saved and loaded using the `export`, `load_checkpoint` and `load` methods.
+
+Let's look at the above methods in more detail. Let's start by importing the modules we'll need.
+
+```python
+from __future__ import print_function
+
+import mxnet as mx
+from mxnet import nd, autograd, gluon
+from mxnet.gluon.data.vision import transforms
+
+import numpy as np
+```
+
+## Setup: build and train a simple model
+
+We need a trained model before we can save it to a file. So let's go ahead and build a very simple convolutional network and train it on MNIST data.
+
+Let's define a helper function to build a LeNet model and another helper to train LeNet with MNIST.
+
+```python
+# Use GPU if one exists, else use CPU
+ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
+
+# MNIST images are 28x28. Total pixels in input layer is 28x28 = 784
+num_inputs = 784
+# Classify the images into one of the 10 digits
+num_outputs = 10
+# 64 images in a batch
+batch_size = 64
+
+# Load the training data
+train_data = gluon.data.DataLoader(gluon.data.vision.MNIST(train=True).transform_first(transforms.ToTensor()),
+                                   batch_size, shuffle=True)
+
+# Build a simple convolutional network
+def build_lenet(net):    
+    with net.name_scope():
+        # First convolution
+        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
+        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
+        # Second convolution
+        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
+        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
+        # Flatten the output before the fully connected layers
+        net.add(gluon.nn.Flatten())
+        # First fully connected layer with 512 neurons
+        net.add(gluon.nn.Dense(512, activation="relu"))
+        # Second fully connected layer with as many neurons as the number of classes
+        net.add(gluon.nn.Dense(num_outputs))
+        
+        return net
+
+# Train a given model using MNIST data
+def train_model(model):
+    # Initialize the parameters with Xavier initializer
+    model.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
 
 Review comment:
   Good catch. thanks!

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services