You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/30 19:57:54 UTC

[incubator-mxnet] branch master updated: add variational autoencoder example (#7190)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 62eaa36  add variational autoencoder example (#7190)
62eaa36 is described below

commit 62eaa36fd282d95d8e629d9d3681d336f71c514b
Author: xiaoyulu2014 <xi...@stats.ox.ac.uk>
AuthorDate: Wed Aug 30 20:57:51 2017 +0100

    add variational autoencoder example (#7190)
    
    * add variational encoder example
    
    * formatting
---
 example/vae/VAE.py            |  129 +++++
 example/vae/VAE_example.ipynb | 1167 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 1296 insertions(+)

diff --git a/example/vae/VAE.py b/example/vae/VAE.py
new file mode 100644
index 0000000..c0d5ec0
--- /dev/null
+++ b/example/vae/VAE.py
@@ -0,0 +1,129 @@
+import mxnet as mx
+import numpy as np
+import os
+import logging
+
+
+class VAE:
+    '''This class implements the Variational Auto Encoder'''
+    
+    def Bernoulli(x_hat,loss_label):
+        return(-mx.symbol.sum(mx.symbol.broadcast_mul(loss_label,mx.symbol.log(x_hat)) + mx.symbol.broadcast_mul(1-loss_label,mx.symbol.log(1-x_hat)),axis=1))
+
+    
+    def __init__(self,n_latent=5,num_hidden_ecoder=400,num_hidden_decoder=400,x_train=None,x_valid=None,batch_size=100,learning_rate=0.001,weight_decay=0.01,num_epoch=100,optimizer='sgd',model_prefix=None, initializer = mx.init.Normal(0.01),likelihood=Bernoulli):
+        
+
+        self.n_latent = n_latent                            #dimension of the latent space Z
+        self.num_hidden_ecoder = num_hidden_ecoder          #number of hidden units in the encoder
+        self.num_hidden_decoder = num_hidden_decoder        #number of hidden units in the decoder
+        self.batch_size = batch_size                        #mini batch size
+        self.learning_rate = learning_rate                  #learning rate during training
+        self.weight_decay = weight_decay                    #weight decay during training, for regulariization of parameters
+        self.num_epoch = num_epoch                          #total number of training epoch
+        self.optimizer = optimizer
+
+
+
+        #train the model
+        self.model, self.training_loss = VAE.train_vae(x_train,x_valid,batch_size,n_latent,num_hidden_ecoder,num_hidden_decoder,learning_rate,weight_decay,num_epoch,optimizer,model_prefix,likelihood,initializer)
+        #save model parameters (i.e. weights and biases)
+        self.arg_params = self.model.get_params()[0]
+        #save loss(ELBO) for the training set 
+        nd_iter = mx.io.NDArrayIter(data={'data':x_train},label={'loss_label':x_train},batch_size = batch_size)     
+
+        #if saved parameters, can access them at specific iteration e.g. last epoch using
+        #   sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, self.num_epoch)
+        #   assert sym.tojson() == output.tojson()
+        #   self.arg_params = arg_params 
+    def train_vae(x_train,x_valid,batch_size,n_latent,num_hidden_ecoder,num_hidden_decoder,learning_rate,weight_decay,num_epoch,optimizer,model_prefix,likelihood,initializer):
+        [N,features] = np.shape(x_train)          #number of examples and features
+
+        #create data iterator to feed into NN
+        nd_iter = mx.io.NDArrayIter(data={'data':x_train},label={'loss_label':x_train},batch_size = batch_size)
+        if x_valid is not None:
+            nd_iter_val = mx.io.NDArrayIter(data={'data':x_valid},label={'loss_label':x_valid},batch_size = batch_size)
+        else:
+            nd_iter_val = None
+        data = mx.sym.var('data')
+        loss_label = mx.sym.var('loss_label')
+
+
+        #build network architucture
+        encoder_h  = mx.sym.FullyConnected(data=data, name="encoder_h",num_hidden=num_hidden_ecoder)
+        act_h = mx.sym.Activation(data=encoder_h, act_type="tanh",name="activation_h")
+
+        
+        mu  = mx.sym.FullyConnected(data=act_h, name="mu",num_hidden = n_latent)
+        logvar  = mx.sym.FullyConnected(data=act_h, name="logvar",num_hidden = n_latent)
+        #latent manifold
+        z = mu + mx.symbol.broadcast_mul(mx.symbol.exp(0.5*logvar),mx.symbol.random_normal(loc=0, scale=1,shape=(batch_size,n_latent))) 
+        decoder_z = mx.sym.FullyConnected(data=z, name="decoder_z",num_hidden=num_hidden_decoder)
+        act_z = mx.sym.Activation(data=decoder_z, act_type="tanh",name="actication_z")
+
+        decoder_x = mx.sym.FullyConnected(data=act_z, name="decoder_x",num_hidden=features)
+        act_x = mx.sym.Activation(data=decoder_x, act_type="sigmoid",name='activation_x')
+
+        KL = -0.5*mx.symbol.sum(1+logvar-pow( mu,2)-mx.symbol.exp(logvar),axis=1)
+
+        #compute minus ELBO to minimize 
+        loss = likelihood(act_x,loss_label)+KL
+        output = mx.symbol.MakeLoss(sum(loss),name='loss')
+
+        #train the model
+        nd_iter.reset()
+        logging.getLogger().setLevel(logging.DEBUG)  # logging to stdout
+
+        model = mx.mod.Module(
+            symbol = output ,
+            data_names=['data'],
+            label_names = ['loss_label'])
+
+             #initialize the weights and bias 
+
+
+        training_loss = list()
+        def log_to_list(period, lst):
+                def _callback(param):
+                        """The checkpoint function."""
+                        if param.nbatch % period == 0:
+                                name, value = param.eval_metric.get()
+                                lst.append(value)
+                return _callback
+
+        model.fit(nd_iter,  # train data
+                    initializer = initializer,
+                    eval_data = nd_iter_val,
+                    optimizer = optimizer,  # use SGD to train
+                    optimizer_params = {'learning_rate':learning_rate,'wd':weight_decay},  
+                    epoch_end_callback  = None if model_prefix==None else mx.callback.do_checkpoint(model_prefix, 1),   #save parameters for each epoch if model_prefix is supplied
+                    batch_end_callback = log_to_list(int(N/batch_size),training_loss),  #this can save the training loss
+                    num_epoch = num_epoch,
+                    eval_metric = 'Loss')
+
+        return model,training_loss
+
+
+    def encoder(model,x):
+        params = model.arg_params
+        encoder_n = np.shape(params['encoder_h_bias'].asnumpy())[0]
+        encoder_h = np.dot(params['encoder_h_weight'].asnumpy(),np.transpose(x)) + np.reshape(params['encoder_h_bias'].asnumpy(),(encoder_n,1))
+        act_h = np.tanh(encoder_h)
+        mu = np.transpose(np.dot(params['mu_weight'].asnumpy(),act_h)) + params['mu_bias'].asnumpy()
+        logvar = np.transpose(np.dot(params['logvar_weight'].asnumpy(),act_h)) + params['logvar_bias'].asnumpy()
+        return mu,logvar
+
+    def sampler(mu,logvar):
+        z = mu + np.multiply(np.exp(0.5*logvar),np.random.normal(loc=0, scale=1,size=np.shape(logvar))) 
+        return z
+
+
+
+    def decoder(model,z):
+        params = model.arg_params
+        decoder_n = np.shape(params['decoder_z_bias'].asnumpy())[0]
+        decoder_z = np.dot(params['decoder_z_weight'].asnumpy(),np.transpose(z)) + np.reshape(params['decoder_z_bias'].asnumpy(),(decoder_n,1))
+        act_z = np.tanh(decoder_z)
+        decoder_x = np.transpose(np.dot(params['decoder_x_weight'].asnumpy(),act_z)) + params['decoder_x_bias'].asnumpy()
+        reconstructed_x = 1/(1+np.exp(-decoder_x))
+        return reconstructed_x
\ No newline at end of file
diff --git a/example/vae/VAE_example.ipynb b/example/vae/VAE_example.ipynb
new file mode 100644
index 0000000..c29348a
--- /dev/null
+++ b/example/vae/VAE_example.ipynb
@@ -0,0 +1,1167 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "import mxnet as mx\n",
+    "import numpy as np\n",
+    "import os\n",
+    "import logging\n",
+    "import matplotlib.pyplot as plt\n",
+    "import matplotlib.cm as cm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Building a Variational Autoencoder in MXNet\n",
+    "\n",
+    "#### Xiaoyu Lu,  July 5th, 2017\n",
+    "\n",
+    "This tutorial guides you through the process of building a variational autoencoder in MXNet. In this notebook we'll focus on an example using the MNIST handwritten digit recognition dataset. Refer to [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114/) for more details on the model description.\n",
+    "\n",
+    "## 1. Loading the Data\n",
+    "\n",
+    "We first load the MNIST dataset, which contains 60,000 training and 10,000 test examples. The following code loads the data. These images are stored in a 4-D matrix with shape (`batch_size, num_channels, width, height`). For the MNIST dataset, there is only one color channel, and both width and height are 28, so we flatten each image into a vector of 28x28 = 784 features; the images are reshaped back to 28x28 arrays only for plotting. See below for a visualization.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "mnist = mx.test_utils.get_mnist()\n",
+    "image = np.reshape(mnist['train_data'],(60000,28*28))\n",
+    "label = image\n",
+    "image_test = np.reshape(mnist['test_data'],(10000,28*28))\n",
+    "label_test = image_test\n",
+    "[N,features] = np.shape(image)          #number of examples and features"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAADFCAYAAACxSv92AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFzNJREFUeJzt3Xuw1XW5x/HPg8rdGMEtoYfj9gJqOZyNbdBGOYOCRFp5\noRwtHUpHgkzFQHFwTNMsrGMKURYIgdnxWIi3xjlSdPIymYlhBwEvpRBX2YxQQolcnvMH68wQ6/nF\nWnvd9lrf92umYe/P/u7fen6sZ+8efq7fd5m7CwAAAEhRp1oXAAAAANQKwzAAAACSxTAMAACAZDEM\nAwAAIFkMwwAAAEgWwzAAAACSxTAMAACAZDEMAwAAIFkMwwAAAEjWwaV8s5mNljRd0kGS7nP3af9s\n/eGHH+7Nzc2lPCSgl156abO7N1XzMeldlGrVqlXavHmzVfMx6VuUA79zUa8K7d12D8NmdpCk70k6\nW9Ja [...]
+      "text/plain": [
+       "<matplotlib.figure.Figure at 0x11234e9e8>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "f, (ax1, ax2, ax3, ax4) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))\n",
+    "ax1.imshow(np.reshape(image[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax2.imshow(np.reshape(image[1,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax3.imshow(np.reshape(image[2,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax4.imshow(np.reshape(image[3,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can optionally save the parameters under the path prefix stored in the variable `model_prefix`. We first create data iterators for MXNet, with each batch of data containing 100 images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "model_prefix = None\n",
+    "\n",
+    "batch_size = 100\n",
+    "nd_iter = mx.io.NDArrayIter(data={'data':image},label={'loss_label':label},\n",
+    "                            batch_size = batch_size)\n",
+    "nd_iter_test = mx.io.NDArrayIter(data={'data':image_test},label={'loss_label':label_test},\n",
+    "                            batch_size = batch_size)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2.  Building the Network Architecture\n",
+    "\n",
+    "### 2.1 Gaussian MLP as encoder\n",
+    "Next we construct the neural network. As in the paper, we use a *Multilayer Perceptron (MLP)* for both the encoder and the decoder. For the encoder, a Gaussian MLP is used:\n",
+    "\n",
+    "\\begin{align}\n",
+    "\\log q_{\\phi}(z|x) &= \\log \\mathcal{N}(z;\\mu,\\sigma^2I) \\\\\n",
+    "\\textit{ where } \\mu &= W_2h+b_2, \\log \\sigma^2 = W_3h+b_3\\\\\n",
+    "h &= \\tanh(W_1x+b_1)\n",
+    "\\end{align}\n",
+    "\n",
+    "where $\\{W_1,W_2,W_3,b_1,b_2,b_3\\}$ are the weights and biases of the MLP.\n",
+    "Note below that `mu` and `logvar` are symbols; we can use `get_internals()` to access their outputs, after which we can sample the latent variable $z$.\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "## define data and loss labels as symbols \n",
+    "data = mx.sym.var('data')\n",
+    "loss_label = mx.sym.var('loss_label')\n",
+    "\n",
+    "## define fully connected and activation layers for the encoder, where we used tanh activation function.\n",
+    "encoder_h  = mx.sym.FullyConnected(data=data, name=\"encoder_h\",num_hidden=400)\n",
+    "act_h = mx.sym.Activation(data=encoder_h, act_type=\"tanh\",name=\"activation_h\")\n",
+    "\n",
+    "## define mu and log variance which are the fully connected layers of the previous activation layer\n",
+    "mu  = mx.sym.FullyConnected(data=act_h, name=\"mu\",num_hidden = 5)\n",
+    "logvar  = mx.sym.FullyConnected(data=act_h, name=\"logvar\",num_hidden = 5)\n",
+    "\n",
+    "## sample the latent variables z according to Normal(mu,var)\n",
+    "z = mu + np.multiply(mx.symbol.exp(0.5*logvar),mx.symbol.random_normal(loc=0, scale=1,shape=np.shape(logvar.get_internals()[\"logvar_output\"])))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2.2 Bernoulli MLP as decoder\n",
+    "\n",
+    "In this case let $p_\\theta(x|z)$ be a multivariate Bernoulli whose probabilities are computed from $z$ with a feed forward neural network with a single hidden layer:\n",
+    "\n",
+    "\\begin{align}\n",
+    "\\log p(x|z) &= \\sum_{i=1}^D x_i\\log y_i + (1-x_i)\\log (1-y_i) \\\\\n",
+    "\\textit{ where }  y &= f_\\sigma(W_5\\tanh (W_4z+b_4)+b_5)\n",
+    "\\end{align}\n",
+    "\n",
+    "where $f_\\sigma(\\cdot)$ is the elementwise sigmoid activation function, and $\\{W_4,W_5,b_4,b_5\\}$ are the weights and biases of the decoder MLP. A Bernoulli likelihood is suitable for this type of data, but you can easily extend the model to other likelihood types by passing a function as the `likelihood` argument of the `VAE` class; see section 4 for details."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# define fully connected and tanh activation layers for the decoder\n",
+    "decoder_z = mx.sym.FullyConnected(data=z, name=\"decoder_z\",num_hidden=400)\n",
+    "act_z = mx.sym.Activation(data=decoder_z, act_type=\"tanh\",name=\"activation_z\")\n",
+    "\n",
+    "# define the output layer with sigmoid activation function, where the dimension is equal to the input dimension\n",
+    "decoder_x = mx.sym.FullyConnected(data=act_z, name=\"decoder_x\",num_hidden=features)\n",
+    "y = mx.sym.Activation(data=decoder_x, act_type=\"sigmoid\",name='activation_x')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2.3 Joint Loss Function for the Encoder and the Decoder\n",
+    "\n",
+    "The variational lower bound can be estimated as:\n",
+    "\n",
+    "\\begin{align}\n",
+    "\\mathcal{L}(\\theta,\\phi;x^{(i)}) \\approx \\frac{1}{2}\\sum_{j=1}^{J}\\left(1+\\log ((\\sigma_j^{(i)})^2)-(\\mu_j^{(i)})^2-(\\sigma_j^{(i)})^2\\right) + \\log p_\\theta(x^{(i)}|z^{(i)})\n",
+    "\\end{align}\n",
+    "\n",
+    "where $J$ is the dimension of the latent space, the first term is the negative KL divergence of the approximate posterior from the prior, and the second term is the expected negative reconstruction error, estimated here with a single sample $z^{(i)}$. We would like to maximize this lower bound, so we define the loss to be $-\\mathcal{L}$ for MXNet to minimize."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# define the objective loss function that needs to be minimized\n",
+    "KL = 0.5*mx.symbol.sum(1+logvar-pow( mu,2)-mx.symbol.exp(logvar),axis=1)\n",
+    "loss = -mx.symbol.sum(mx.symbol.broadcast_mul(loss_label,mx.symbol.log(y)) + mx.symbol.broadcast_mul(1-loss_label,mx.symbol.log(1-y)),axis=1)-KL\n",
+    "output = mx.symbol.MakeLoss(sum(loss),name='loss')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Training the model\n",
+    "Now we can define the model and train it, we initilize the weights and the biases to be Gaussian(0,0.01), and use stochastic gradient descent for optimization. To warm start the training, one may initilize with pre-trainined parameters `arg_params` using `init=mx.initializer.Load(arg_params)`. To save intermediate results, we can optionally use `epoch_end_callback  = mx.callback.do_checkpoint(model_prefix, 1)` which saves the parameters to the path given by model_prefix, and with pe [...]
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# set up the log\n",
+    "nd_iter.reset()\n",
+    "logging.getLogger().setLevel(logging.DEBUG)  \n",
+    "\n",
+    "#define function to trave back training loss\n",
+    "def log_to_list(period, lst):\n",
+    "    def _callback(param):\n",
+    "        \"\"\"The checkpoint function.\"\"\"\n",
+    "        if param.nbatch % period == 0:\n",
+    "            name, value = param.eval_metric.get()\n",
+    "            lst.append(value)\n",
+    "    return _callback\n",
+    "\n",
+    "# define the model\n",
+    "model = mx.mod.Module(\n",
+    "    symbol = output ,\n",
+    "    data_names=['data'],\n",
+    "    label_names = ['loss_label'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:root:Epoch[0] Train-loss=375.023381\n",
+      "INFO:root:Epoch[0] Time cost=6.127\n",
+      "INFO:root:Epoch[1] Train-loss=212.780315\n",
+      "INFO:root:Epoch[1] Time cost=6.409\n",
+      "INFO:root:Epoch[2] Train-loss=208.209400\n",
+      "INFO:root:Epoch[2] Time cost=6.619\n",
+      "INFO:root:Epoch[3] Train-loss=206.146854\n",
+      "INFO:root:Epoch[3] Time cost=6.648\n",
+      "INFO:root:Epoch[4] Train-loss=204.530598\n",
+      "INFO:root:Epoch[4] Time cost=7.000\n",
+      "INFO:root:Epoch[5] Train-loss=202.799992\n",
+      "INFO:root:Epoch[5] Time cost=6.778\n",
+      "INFO:root:Epoch[6] Train-loss=200.333474\n",
+      "INFO:root:Epoch[6] Time cost=7.187\n",
+      "INFO:root:Epoch[7] Train-loss=197.506393\n",
+      "INFO:root:Epoch[7] Time cost=6.712\n",
+      "INFO:root:Epoch[8] Train-loss=195.969775\n",
+      "INFO:root:Epoch[8] Time cost=6.896\n",
+      "INFO:root:Epoch[9] Train-loss=195.418288\n",
+      "INFO:root:Epoch[9] Time cost=6.887\n",
+      "INFO:root:Epoch[10] Train-loss=194.739763\n",
+      "INFO:root:Epoch[10] Time cost=6.745\n",
+      "INFO:root:Epoch[11] Train-loss=194.380536\n",
+      "INFO:root:Epoch[11] Time cost=6.706\n",
+      "INFO:root:Epoch[12] Train-loss=193.955462\n",
+      "INFO:root:Epoch[12] Time cost=6.592\n",
+      "INFO:root:Epoch[13] Train-loss=193.493671\n",
+      "INFO:root:Epoch[13] Time cost=6.775\n",
+      "INFO:root:Epoch[14] Train-loss=192.958739\n",
+      "INFO:root:Epoch[14] Time cost=6.600\n",
+      "INFO:root:Epoch[15] Train-loss=191.928542\n",
+      "INFO:root:Epoch[15] Time cost=6.586\n",
+      "INFO:root:Epoch[16] Train-loss=189.797939\n",
+      "INFO:root:Epoch[16] Time cost=6.700\n",
+      "INFO:root:Epoch[17] Train-loss=186.672446\n",
+      "INFO:root:Epoch[17] Time cost=6.869\n",
+      "INFO:root:Epoch[18] Train-loss=184.616599\n",
+      "INFO:root:Epoch[18] Time cost=7.144\n",
+      "INFO:root:Epoch[19] Train-loss=183.305978\n",
+      "INFO:root:Epoch[19] Time cost=6.997\n",
+      "INFO:root:Epoch[20] Train-loss=181.944634\n",
+      "INFO:root:Epoch[20] Time cost=6.481\n",
+      "INFO:root:Epoch[21] Train-loss=181.005329\n",
+      "INFO:root:Epoch[21] Time cost=6.754\n",
+      "INFO:root:Epoch[22] Train-loss=178.363118\n",
+      "INFO:root:Epoch[22] Time cost=7.000\n",
+      "INFO:root:Epoch[23] Train-loss=176.363421\n",
+      "INFO:root:Epoch[23] Time cost=6.923\n",
+      "INFO:root:Epoch[24] Train-loss=174.573954\n",
+      "INFO:root:Epoch[24] Time cost=6.510\n",
+      "INFO:root:Epoch[25] Train-loss=173.245940\n",
+      "INFO:root:Epoch[25] Time cost=6.926\n",
+      "INFO:root:Epoch[26] Train-loss=172.082522\n",
+      "INFO:root:Epoch[26] Time cost=6.733\n",
+      "INFO:root:Epoch[27] Train-loss=171.123084\n",
+      "INFO:root:Epoch[27] Time cost=6.616\n",
+      "INFO:root:Epoch[28] Train-loss=170.239300\n",
+      "INFO:root:Epoch[28] Time cost=7.004\n",
+      "INFO:root:Epoch[29] Train-loss=169.538416\n",
+      "INFO:root:Epoch[29] Time cost=6.341\n",
+      "INFO:root:Epoch[30] Train-loss=168.952901\n",
+      "INFO:root:Epoch[30] Time cost=6.736\n",
+      "INFO:root:Epoch[31] Train-loss=168.169076\n",
+      "INFO:root:Epoch[31] Time cost=6.616\n",
+      "INFO:root:Epoch[32] Train-loss=167.208973\n",
+      "INFO:root:Epoch[32] Time cost=6.446\n",
+      "INFO:root:Epoch[33] Train-loss=165.732213\n",
+      "INFO:root:Epoch[33] Time cost=6.405\n",
+      "INFO:root:Epoch[34] Train-loss=163.606801\n",
+      "INFO:root:Epoch[34] Time cost=6.139\n",
+      "INFO:root:Epoch[35] Train-loss=161.985880\n",
+      "INFO:root:Epoch[35] Time cost=6.678\n",
+      "INFO:root:Epoch[36] Train-loss=160.763072\n",
+      "INFO:root:Epoch[36] Time cost=8.749\n",
+      "INFO:root:Epoch[37] Train-loss=160.025193\n",
+      "INFO:root:Epoch[37] Time cost=6.519\n",
+      "INFO:root:Epoch[38] Train-loss=159.319723\n",
+      "INFO:root:Epoch[38] Time cost=7.584\n",
+      "INFO:root:Epoch[39] Train-loss=158.670701\n",
+      "INFO:root:Epoch[39] Time cost=6.874\n",
+      "INFO:root:Epoch[40] Train-loss=158.225733\n",
+      "INFO:root:Epoch[40] Time cost=6.402\n",
+      "INFO:root:Epoch[41] Train-loss=157.741337\n",
+      "INFO:root:Epoch[41] Time cost=8.617\n",
+      "INFO:root:Epoch[42] Train-loss=157.301411\n",
+      "INFO:root:Epoch[42] Time cost=6.515\n",
+      "INFO:root:Epoch[43] Train-loss=156.765170\n",
+      "INFO:root:Epoch[43] Time cost=6.447\n",
+      "INFO:root:Epoch[44] Train-loss=156.389668\n",
+      "INFO:root:Epoch[44] Time cost=6.130\n",
+      "INFO:root:Epoch[45] Train-loss=155.815434\n",
+      "INFO:root:Epoch[45] Time cost=6.155\n",
+      "INFO:root:Epoch[46] Train-loss=155.432254\n",
+      "INFO:root:Epoch[46] Time cost=6.158\n",
+      "INFO:root:Epoch[47] Train-loss=155.114027\n",
+      "INFO:root:Epoch[47] Time cost=6.749\n",
+      "INFO:root:Epoch[48] Train-loss=154.612441\n",
+      "INFO:root:Epoch[48] Time cost=6.255\n",
+      "INFO:root:Epoch[49] Train-loss=154.137659\n",
+      "INFO:root:Epoch[49] Time cost=7.813\n",
+      "INFO:root:Epoch[50] Train-loss=153.634072\n",
+      "INFO:root:Epoch[50] Time cost=7.408\n",
+      "INFO:root:Epoch[51] Train-loss=153.417397\n",
+      "INFO:root:Epoch[51] Time cost=7.747\n",
+      "INFO:root:Epoch[52] Train-loss=152.851887\n",
+      "INFO:root:Epoch[52] Time cost=8.587\n",
+      "INFO:root:Epoch[53] Train-loss=152.575068\n",
+      "INFO:root:Epoch[53] Time cost=7.554\n",
+      "INFO:root:Epoch[54] Train-loss=152.084419\n",
+      "INFO:root:Epoch[54] Time cost=6.628\n",
+      "INFO:root:Epoch[55] Train-loss=151.724836\n",
+      "INFO:root:Epoch[55] Time cost=6.535\n",
+      "INFO:root:Epoch[56] Train-loss=151.302525\n",
+      "INFO:root:Epoch[56] Time cost=7.148\n",
+      "INFO:root:Epoch[57] Train-loss=150.960916\n",
+      "INFO:root:Epoch[57] Time cost=7.195\n",
+      "INFO:root:Epoch[58] Train-loss=150.603895\n",
+      "INFO:root:Epoch[58] Time cost=6.649\n",
+      "INFO:root:Epoch[59] Train-loss=150.237795\n",
+      "INFO:root:Epoch[59] Time cost=6.222\n",
+      "INFO:root:Epoch[60] Train-loss=149.936080\n",
+      "INFO:root:Epoch[60] Time cost=8.450\n",
+      "INFO:root:Epoch[61] Train-loss=149.514617\n",
+      "INFO:root:Epoch[61] Time cost=6.113\n",
+      "INFO:root:Epoch[62] Train-loss=149.229345\n",
+      "INFO:root:Epoch[62] Time cost=6.088\n",
+      "INFO:root:Epoch[63] Train-loss=148.893769\n",
+      "INFO:root:Epoch[63] Time cost=6.558\n",
+      "INFO:root:Epoch[64] Train-loss=148.526837\n",
+      "INFO:root:Epoch[64] Time cost=7.590\n",
+      "INFO:root:Epoch[65] Train-loss=148.249951\n",
+      "INFO:root:Epoch[65] Time cost=6.180\n",
+      "INFO:root:Epoch[66] Train-loss=147.940414\n",
+      "INFO:root:Epoch[66] Time cost=6.242\n",
+      "INFO:root:Epoch[67] Train-loss=147.621304\n",
+      "INFO:root:Epoch[67] Time cost=8.501\n",
+      "INFO:root:Epoch[68] Train-loss=147.294314\n",
+      "INFO:root:Epoch[68] Time cost=7.645\n",
+      "INFO:root:Epoch[69] Train-loss=147.074479\n",
+      "INFO:root:Epoch[69] Time cost=7.092\n",
+      "INFO:root:Epoch[70] Train-loss=146.796387\n",
+      "INFO:root:Epoch[70] Time cost=6.914\n",
+      "INFO:root:Epoch[71] Train-loss=146.508842\n",
+      "INFO:root:Epoch[71] Time cost=6.606\n",
+      "INFO:root:Epoch[72] Train-loss=146.230444\n",
+      "INFO:root:Epoch[72] Time cost=7.755\n",
+      "INFO:root:Epoch[73] Train-loss=145.970296\n",
+      "INFO:root:Epoch[73] Time cost=6.409\n",
+      "INFO:root:Epoch[74] Train-loss=145.711610\n",
+      "INFO:root:Epoch[74] Time cost=6.334\n",
+      "INFO:root:Epoch[75] Train-loss=145.460053\n",
+      "INFO:root:Epoch[75] Time cost=7.269\n",
+      "INFO:root:Epoch[76] Train-loss=145.156451\n",
+      "INFO:root:Epoch[76] Time cost=6.744\n",
+      "INFO:root:Epoch[77] Train-loss=144.957674\n",
+      "INFO:root:Epoch[77] Time cost=7.100\n",
+      "INFO:root:Epoch[78] Train-loss=144.729749\n",
+      "INFO:root:Epoch[78] Time cost=6.242\n",
+      "INFO:root:Epoch[79] Train-loss=144.481728\n",
+      "INFO:root:Epoch[79] Time cost=6.865\n",
+      "INFO:root:Epoch[80] Train-loss=144.236061\n",
+      "INFO:root:Epoch[80] Time cost=6.632\n",
+      "INFO:root:Epoch[81] Train-loss=144.030473\n",
+      "INFO:root:Epoch[81] Time cost=6.764\n",
+      "INFO:root:Epoch[82] Train-loss=143.776374\n",
+      "INFO:root:Epoch[82] Time cost=6.564\n",
+      "INFO:root:Epoch[83] Train-loss=143.538847\n",
+      "INFO:root:Epoch[83] Time cost=6.181\n",
+      "INFO:root:Epoch[84] Train-loss=143.326444\n",
+      "INFO:root:Epoch[84] Time cost=6.220\n",
+      "INFO:root:Epoch[85] Train-loss=143.078987\n",
+      "INFO:root:Epoch[85] Time cost=6.823\n",
+      "INFO:root:Epoch[86] Train-loss=142.877117\n",
+      "INFO:root:Epoch[86] Time cost=7.755\n",
+      "INFO:root:Epoch[87] Train-loss=142.667316\n",
+      "INFO:root:Epoch[87] Time cost=6.068\n",
+      "INFO:root:Epoch[88] Train-loss=142.461755\n",
+      "INFO:root:Epoch[88] Time cost=6.111\n",
+      "INFO:root:Epoch[89] Train-loss=142.270438\n",
+      "INFO:root:Epoch[89] Time cost=6.221\n",
+      "INFO:root:Epoch[90] Train-loss=142.047086\n",
+      "INFO:root:Epoch[90] Time cost=8.061\n",
+      "INFO:root:Epoch[91] Train-loss=141.855774\n",
+      "INFO:root:Epoch[91] Time cost=6.433\n",
+      "INFO:root:Epoch[92] Train-loss=141.688955\n",
+      "INFO:root:Epoch[92] Time cost=7.153\n",
+      "INFO:root:Epoch[93] Train-loss=141.442910\n",
+      "INFO:root:Epoch[93] Time cost=7.113\n",
+      "INFO:root:Epoch[94] Train-loss=141.279274\n",
+      "INFO:root:Epoch[94] Time cost=7.152\n",
+      "INFO:root:Epoch[95] Train-loss=141.086522\n",
+      "INFO:root:Epoch[95] Time cost=6.472\n",
+      "INFO:root:Epoch[96] Train-loss=140.901925\n",
+      "INFO:root:Epoch[96] Time cost=6.767\n",
+      "INFO:root:Epoch[97] Train-loss=140.722496\n",
+      "INFO:root:Epoch[97] Time cost=7.044\n",
+      "INFO:root:Epoch[98] Train-loss=140.579295\n",
+      "INFO:root:Epoch[98] Time cost=7.040\n",
+      "INFO:root:Epoch[99] Train-loss=140.386067\n",
+      "INFO:root:Epoch[99] Time cost=6.669\n"
+     ]
+    }
+   ],
+   "source": [
+    "# training the model, save training loss as a list.\n",
+    "training_loss=list()\n",
+    "\n",
+    "# initilize the parameters for training using Normal.\n",
+    "init = mx.init.Normal(0.01)\n",
+    "model.fit(nd_iter,  # train data\n",
+    "              initializer=init,\n",
+    "              #if eval_data is supplied, test loss will also be reported\n",
+    "              #eval_data = nd_iter_test,\n",
+    "              optimizer='sgd',  # use SGD to train\n",
+    "              optimizer_params={'learning_rate':1e-3,'wd':1e-2},  \n",
+    "              epoch_end_callback  = None if model_prefix==None else mx.callback.do_checkpoint(model_prefix, 1),   #save parameters for each epoch if model_prefix is supplied\n",
+    "              batch_end_callback = log_to_list(N/batch_size,training_loss), \n",
+    "              num_epoch=100,\n",
+    "              eval_metric = 'Loss')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEWCAYAAACnlKo3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmcZGV97/HPt6v3bfZ9BmaEEdlRRwSjUSO5EGIC90oM\nMYkYt/iSxJhrouISlStJNIkmvhKNxusVIooaF9C4gUaMMYADso6AA8MwM8zSs/Z0Ty+1/O4f53RT\n9PSpGmqmunq6v+/Xq15z6jmnqp7ndM3zq2c5z1FEYGZmVoumRmfAzMyOXw4iZmZWMwcRMzOrmYOI\nmZnVzEHEzMxq5iBiZmY1cxCxY0rSP0t677E+diaQ1CHpG5IOSPpyAz7/XZI+fbTHSnqJpK3HNneZ\n+XhM0gVT8VlWm+ZGZ8CmD0mPAa+PiFtqfY+IeFM9jp0hLgOWAAsiojDVHx4Rf1mPY58OSauBTUBL\nI86B [...]
+      "text/plain": [
+       "<matplotlib.figure.Figure at 0x145e16898>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "ELBO = [-training_loss[i] for i in range(len(training_loss))]\n",
+    "plt.plot(ELBO)\n",
+    "plt.ylabel('ELBO');plt.xlabel('epoch');plt.title(\"training curve for mini batches\")\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As expected, the ELBO is monotonically increasing over the epochs, and we reproduced the results given in the paper [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114/). Now we can extract/load the parameters and then feed the network forward to calculate $y$, which is the reconstructed image, and we can also calculate the ELBO for the test set. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 80,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "arg_params = model.get_params()[0]\n",
+    "\n",
+    "#if saved the parameters, can load them at e.g. 100th epoch\n",
+    "#sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 100)\n",
+    "#assert sym.tojson() == output.tojson()\n",
+    "\n",
+    "e = y.bind(mx.cpu(), {'data': nd_iter_test.data[0][1],\n",
+    "                     'encoder_h_weight': arg_params['encoder_h_weight'],\n",
+    "                     'encoder_h_bias': arg_params['encoder_h_bias'],\n",
+    "                     'mu_weight': arg_params['mu_weight'],\n",
+    "                     'mu_bias': arg_params['mu_bias'],\n",
+    "                     'logvar_weight':arg_params['logvar_weight'],\n",
+    "                     'logvar_bias':arg_params['logvar_bias'],\n",
+    "                     'decoder_z_weight':arg_params['decoder_z_weight'],\n",
+    "                     'decoder_z_bias':arg_params['decoder_z_bias'],\n",
+    "                     'decoder_x_weight':arg_params['decoder_x_weight'],\n",
+    "                     'decoder_x_bias':arg_params['decoder_x_bias'],                \n",
+    "                     'loss_label':label})\n",
+    "\n",
+    "x_fit = e.forward()\n",
+    "x_construction = x_fit[0].asnumpy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAADSCAYAAACvmc1VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XuUXGWZ7/HfkyvkSkKHJARICDAkxMkBaUm4CZrAwQku\nRAdnogzxNjhneRhdOjOiZ80ZdIYj5yxvzNHlDB4QFEcFuSsjIANyRxIMQhJCQi6E0Em6SUI693R4\nzx+1w5Sp56F7d1d1d9X+ftbKSvev3q7au/dT1W9X7+fdllISAAAAUEQD+noDAAAAgL7CZBgAAACF\nxWQYAAAAhcVkGAAAAIXFZBgAAACFxWQYAAAAhcVkuB8ws/9nZl/u6+0AesLMpphZMrNBwe1LzOzc\nXt4soFPULuoRdVs9TIYlmdn2sn9vmtmuss8/WuvHTyl9KqX0v2r9OKgfZrbGzOb29XZUU0ppRkrp\n4b7e [...]
+      "text/plain": [
+       "<matplotlib.figure.Figure at 0x141020f98>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# plot a true test image next to several reconstructed (learned) test images\n",
+    "f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))\n",
+    "ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax1.set_title('True image')\n",
+    "ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax2.set_title('Learned image')\n",
+    "ax3.imshow(np.reshape(x_construction[999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax3.set_title('Learned image')\n",
+    "ax4.imshow(np.reshape(x_construction[9999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax4.set_title('Learned image')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[('loss', 139.73684648437501)]"
+      ]
+     },
+     "execution_count": 37,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# calculate the loss on the test set; the ELBO is the negative of this loss\n",
+    "metric = mx.metric.Loss()\n",
+    "model.score(nd_iter_test, metric)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. All together: MXNet-based class VAE"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "from VAE import VAE"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "One can directly call the class `VAE` to do the training. The outputs are the learned model and training loss.\n",
+    "```VAE(n_latent=5,num_hidden_ecoder=400,num_hidden_decoder=400,x_train=None,x_valid=None,batch_size=100,learning_rate=0.001,weight_decay=0.01,num_epoch=100,optimizer='sgd',model_prefix=None, initializer = mx.init.Normal(0.01),likelihood=Bernoulli)```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:root:Epoch[0] Train-loss=377.146422\n",
+      "INFO:root:Epoch[0] Time cost=5.989\n",
+      "INFO:root:Epoch[1] Train-loss=211.998043\n",
+      "INFO:root:Epoch[1] Time cost=6.303\n",
+      "INFO:root:Epoch[2] Train-loss=207.103096\n",
+      "INFO:root:Epoch[2] Time cost=7.368\n",
+      "INFO:root:Epoch[3] Train-loss=204.958183\n",
+      "INFO:root:Epoch[3] Time cost=7.530\n",
+      "INFO:root:Epoch[4] Train-loss=203.342700\n",
+      "INFO:root:Epoch[4] Time cost=8.887\n",
+      "INFO:root:Epoch[5] Train-loss=201.649251\n",
+      "INFO:root:Epoch[5] Time cost=9.147\n",
+      "INFO:root:Epoch[6] Train-loss=199.782661\n",
+      "INFO:root:Epoch[6] Time cost=8.924\n",
+      "INFO:root:Epoch[7] Train-loss=198.044015\n",
+      "INFO:root:Epoch[7] Time cost=8.920\n",
+      "INFO:root:Epoch[8] Train-loss=195.732077\n",
+      "INFO:root:Epoch[8] Time cost=8.857\n",
+      "INFO:root:Epoch[9] Train-loss=194.070547\n",
+      "INFO:root:Epoch[9] Time cost=9.216\n",
+      "INFO:root:Epoch[10] Train-loss=193.186871\n",
+      "INFO:root:Epoch[10] Time cost=8.966\n",
+      "INFO:root:Epoch[11] Train-loss=192.700208\n",
+      "INFO:root:Epoch[11] Time cost=8.843\n",
+      "INFO:root:Epoch[12] Train-loss=192.191504\n",
+      "INFO:root:Epoch[12] Time cost=8.152\n",
+      "INFO:root:Epoch[13] Train-loss=191.842837\n",
+      "INFO:root:Epoch[13] Time cost=6.180\n",
+      "INFO:root:Epoch[14] Train-loss=191.310450\n",
+      "INFO:root:Epoch[14] Time cost=6.067\n",
+      "INFO:root:Epoch[15] Train-loss=190.520681\n",
+      "INFO:root:Epoch[15] Time cost=6.058\n",
+      "INFO:root:Epoch[16] Train-loss=189.784146\n",
+      "INFO:root:Epoch[16] Time cost=6.046\n",
+      "INFO:root:Epoch[17] Train-loss=188.515020\n",
+      "INFO:root:Epoch[17] Time cost=6.062\n",
+      "INFO:root:Epoch[18] Train-loss=187.530712\n",
+      "INFO:root:Epoch[18] Time cost=6.088\n",
+      "INFO:root:Epoch[19] Train-loss=186.194826\n",
+      "INFO:root:Epoch[19] Time cost=6.491\n",
+      "INFO:root:Epoch[20] Train-loss=185.492288\n",
+      "INFO:root:Epoch[20] Time cost=6.182\n",
+      "INFO:root:Epoch[21] Train-loss=184.922654\n",
+      "INFO:root:Epoch[21] Time cost=6.058\n",
+      "INFO:root:Epoch[22] Train-loss=184.677911\n",
+      "INFO:root:Epoch[22] Time cost=6.042\n",
+      "INFO:root:Epoch[23] Train-loss=183.921396\n",
+      "INFO:root:Epoch[23] Time cost=5.994\n",
+      "INFO:root:Epoch[24] Train-loss=183.600690\n",
+      "INFO:root:Epoch[24] Time cost=6.038\n",
+      "INFO:root:Epoch[25] Train-loss=183.388476\n",
+      "INFO:root:Epoch[25] Time cost=6.025\n",
+      "INFO:root:Epoch[26] Train-loss=182.972208\n",
+      "INFO:root:Epoch[26] Time cost=6.014\n",
+      "INFO:root:Epoch[27] Train-loss=182.561678\n",
+      "INFO:root:Epoch[27] Time cost=6.064\n",
+      "INFO:root:Epoch[28] Train-loss=182.475261\n",
+      "INFO:root:Epoch[28] Time cost=5.983\n",
+      "INFO:root:Epoch[29] Train-loss=182.308808\n",
+      "INFO:root:Epoch[29] Time cost=6.371\n",
+      "INFO:root:Epoch[30] Train-loss=182.135900\n",
+      "INFO:root:Epoch[30] Time cost=6.038\n",
+      "INFO:root:Epoch[31] Train-loss=181.978367\n",
+      "INFO:root:Epoch[31] Time cost=6.924\n",
+      "INFO:root:Epoch[32] Train-loss=181.677153\n",
+      "INFO:root:Epoch[32] Time cost=8.205\n",
+      "INFO:root:Epoch[33] Train-loss=181.677775\n",
+      "INFO:root:Epoch[33] Time cost=6.017\n",
+      "INFO:root:Epoch[34] Train-loss=181.257998\n",
+      "INFO:root:Epoch[34] Time cost=6.056\n",
+      "INFO:root:Epoch[35] Train-loss=181.125288\n",
+      "INFO:root:Epoch[35] Time cost=6.020\n",
+      "INFO:root:Epoch[36] Train-loss=181.018858\n",
+      "INFO:root:Epoch[36] Time cost=6.035\n",
+      "INFO:root:Epoch[37] Train-loss=180.785110\n",
+      "INFO:root:Epoch[37] Time cost=6.049\n",
+      "INFO:root:Epoch[38] Train-loss=180.452598\n",
+      "INFO:root:Epoch[38] Time cost=6.083\n",
+      "INFO:root:Epoch[39] Train-loss=180.362733\n",
+      "INFO:root:Epoch[39] Time cost=6.198\n",
+      "INFO:root:Epoch[40] Train-loss=180.060788\n",
+      "INFO:root:Epoch[40] Time cost=6.049\n",
+      "INFO:root:Epoch[41] Train-loss=180.022728\n",
+      "INFO:root:Epoch[41] Time cost=6.135\n",
+      "INFO:root:Epoch[42] Train-loss=179.648499\n",
+      "INFO:root:Epoch[42] Time cost=6.055\n",
+      "INFO:root:Epoch[43] Train-loss=179.507952\n",
+      "INFO:root:Epoch[43] Time cost=6.108\n",
+      "INFO:root:Epoch[44] Train-loss=179.303132\n",
+      "INFO:root:Epoch[44] Time cost=6.020\n",
+      "INFO:root:Epoch[45] Train-loss=178.945211\n",
+      "INFO:root:Epoch[45] Time cost=6.004\n",
+      "INFO:root:Epoch[46] Train-loss=178.808598\n",
+      "INFO:root:Epoch[46] Time cost=6.016\n",
+      "INFO:root:Epoch[47] Train-loss=178.550906\n",
+      "INFO:root:Epoch[47] Time cost=6.050\n",
+      "INFO:root:Epoch[48] Train-loss=178.403674\n",
+      "INFO:root:Epoch[48] Time cost=6.115\n",
+      "INFO:root:Epoch[49] Train-loss=178.237544\n",
+      "INFO:root:Epoch[49] Time cost=6.004\n",
+      "INFO:root:Epoch[50] Train-loss=178.033747\n",
+      "INFO:root:Epoch[50] Time cost=6.051\n",
+      "INFO:root:Epoch[51] Train-loss=177.802884\n",
+      "INFO:root:Epoch[51] Time cost=6.028\n",
+      "INFO:root:Epoch[52] Train-loss=177.533980\n",
+      "INFO:root:Epoch[52] Time cost=6.052\n",
+      "INFO:root:Epoch[53] Train-loss=177.490143\n",
+      "INFO:root:Epoch[53] Time cost=6.019\n",
+      "INFO:root:Epoch[54] Train-loss=177.136637\n",
+      "INFO:root:Epoch[54] Time cost=6.014\n",
+      "INFO:root:Epoch[55] Train-loss=177.062524\n",
+      "INFO:root:Epoch[55] Time cost=6.024\n",
+      "INFO:root:Epoch[56] Train-loss=176.869033\n",
+      "INFO:root:Epoch[56] Time cost=6.065\n",
+      "INFO:root:Epoch[57] Train-loss=176.704606\n",
+      "INFO:root:Epoch[57] Time cost=6.037\n",
+      "INFO:root:Epoch[58] Train-loss=176.470091\n",
+      "INFO:root:Epoch[58] Time cost=6.012\n",
+      "INFO:root:Epoch[59] Train-loss=176.261440\n",
+      "INFO:root:Epoch[59] Time cost=6.215\n",
+      "INFO:root:Epoch[60] Train-loss=176.133904\n",
+      "INFO:root:Epoch[60] Time cost=6.042\n",
+      "INFO:root:Epoch[61] Train-loss=175.941920\n",
+      "INFO:root:Epoch[61] Time cost=6.000\n",
+      "INFO:root:Epoch[62] Train-loss=175.731296\n",
+      "INFO:root:Epoch[62] Time cost=6.025\n",
+      "INFO:root:Epoch[63] Train-loss=175.613303\n",
+      "INFO:root:Epoch[63] Time cost=6.002\n",
+      "INFO:root:Epoch[64] Train-loss=175.438844\n",
+      "INFO:root:Epoch[64] Time cost=5.982\n",
+      "INFO:root:Epoch[65] Train-loss=175.254716\n",
+      "INFO:root:Epoch[65] Time cost=6.016\n",
+      "INFO:root:Epoch[66] Train-loss=175.090210\n",
+      "INFO:root:Epoch[66] Time cost=6.008\n",
+      "INFO:root:Epoch[67] Train-loss=174.895443\n",
+      "INFO:root:Epoch[67] Time cost=6.008\n",
+      "INFO:root:Epoch[68] Train-loss=174.701321\n",
+      "INFO:root:Epoch[68] Time cost=6.418\n",
+      "INFO:root:Epoch[69] Train-loss=174.553292\n",
+      "INFO:root:Epoch[69] Time cost=6.072\n",
+      "INFO:root:Epoch[70] Train-loss=174.349379\n",
+      "INFO:root:Epoch[70] Time cost=6.048\n",
+      "INFO:root:Epoch[71] Train-loss=174.174641\n",
+      "INFO:root:Epoch[71] Time cost=6.036\n",
+      "INFO:root:Epoch[72] Train-loss=173.966333\n",
+      "INFO:root:Epoch[72] Time cost=6.017\n",
+      "INFO:root:Epoch[73] Train-loss=173.798454\n",
+      "INFO:root:Epoch[73] Time cost=6.018\n",
+      "INFO:root:Epoch[74] Train-loss=173.635657\n",
+      "INFO:root:Epoch[74] Time cost=5.985\n",
+      "INFO:root:Epoch[75] Train-loss=173.423795\n",
+      "INFO:root:Epoch[75] Time cost=6.016\n",
+      "INFO:root:Epoch[76] Train-loss=173.273981\n",
+      "INFO:root:Epoch[76] Time cost=6.018\n",
+      "INFO:root:Epoch[77] Train-loss=173.073401\n",
+      "INFO:root:Epoch[77] Time cost=5.996\n",
+      "INFO:root:Epoch[78] Train-loss=172.888044\n",
+      "INFO:root:Epoch[78] Time cost=6.035\n",
+      "INFO:root:Epoch[79] Train-loss=172.694943\n",
+      "INFO:root:Epoch[79] Time cost=8.492\n",
+      "INFO:root:Epoch[80] Train-loss=172.504260\n",
+      "INFO:root:Epoch[80] Time cost=7.380\n",
+      "INFO:root:Epoch[81] Train-loss=172.323245\n",
+      "INFO:root:Epoch[81] Time cost=6.063\n",
+      "INFO:root:Epoch[82] Train-loss=172.131274\n",
+      "INFO:root:Epoch[82] Time cost=6.209\n",
+      "INFO:root:Epoch[83] Train-loss=171.932986\n",
+      "INFO:root:Epoch[83] Time cost=6.060\n",
+      "INFO:root:Epoch[84] Train-loss=171.755262\n",
+      "INFO:root:Epoch[84] Time cost=6.068\n",
+      "INFO:root:Epoch[85] Train-loss=171.556803\n",
+      "INFO:root:Epoch[85] Time cost=6.004\n",
+      "INFO:root:Epoch[86] Train-loss=171.384773\n",
+      "INFO:root:Epoch[86] Time cost=6.059\n",
+      "INFO:root:Epoch[87] Train-loss=171.185034\n",
+      "INFO:root:Epoch[87] Time cost=6.001\n",
+      "INFO:root:Epoch[88] Train-loss=170.995980\n",
+      "INFO:root:Epoch[88] Time cost=6.143\n",
+      "INFO:root:Epoch[89] Train-loss=170.818701\n",
+      "INFO:root:Epoch[89] Time cost=6.690\n",
+      "INFO:root:Epoch[90] Train-loss=170.629929\n",
+      "INFO:root:Epoch[90] Time cost=6.869\n",
+      "INFO:root:Epoch[91] Train-loss=170.450824\n",
+      "INFO:root:Epoch[91] Time cost=7.156\n",
+      "INFO:root:Epoch[92] Train-loss=170.261806\n",
+      "INFO:root:Epoch[92] Time cost=6.972\n",
+      "INFO:root:Epoch[93] Train-loss=170.070318\n",
+      "INFO:root:Epoch[93] Time cost=6.595\n",
+      "INFO:root:Epoch[94] Train-loss=169.906993\n",
+      "INFO:root:Epoch[94] Time cost=6.561\n",
+      "INFO:root:Epoch[95] Train-loss=169.734455\n",
+      "INFO:root:Epoch[95] Time cost=6.744\n",
+      "INFO:root:Epoch[96] Train-loss=169.564318\n",
+      "INFO:root:Epoch[96] Time cost=6.601\n",
+      "INFO:root:Epoch[97] Train-loss=169.373926\n",
+      "INFO:root:Epoch[97] Time cost=6.725\n",
+      "INFO:root:Epoch[98] Train-loss=169.215408\n",
+      "INFO:root:Epoch[98] Time cost=6.391\n",
+      "INFO:root:Epoch[99] Train-loss=169.039854\n",
+      "INFO:root:Epoch[99] Time cost=6.677\n",
+      "INFO:root:Epoch[100] Train-loss=168.869222\n",
+      "INFO:root:Epoch[100] Time cost=6.370\n",
+      "INFO:root:Epoch[101] Train-loss=168.703175\n",
+      "INFO:root:Epoch[101] Time cost=6.607\n",
+      "INFO:root:Epoch[102] Train-loss=168.523054\n",
+      "INFO:root:Epoch[102] Time cost=6.368\n",
+      "INFO:root:Epoch[103] Train-loss=168.365964\n",
+      "INFO:root:Epoch[103] Time cost=10.267\n",
+      "INFO:root:Epoch[104] Train-loss=168.181174\n",
+      "INFO:root:Epoch[104] Time cost=11.132\n",
+      "INFO:root:Epoch[105] Train-loss=168.021498\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:root:Epoch[105] Time cost=10.187\n",
+      "INFO:root:Epoch[106] Train-loss=167.858251\n",
+      "INFO:root:Epoch[106] Time cost=10.676\n",
+      "INFO:root:Epoch[107] Train-loss=167.690670\n",
+      "INFO:root:Epoch[107] Time cost=10.973\n",
+      "INFO:root:Epoch[108] Train-loss=167.535069\n",
+      "INFO:root:Epoch[108] Time cost=10.108\n",
+      "INFO:root:Epoch[109] Train-loss=167.373971\n",
+      "INFO:root:Epoch[109] Time cost=11.013\n",
+      "INFO:root:Epoch[110] Train-loss=167.207507\n",
+      "INFO:root:Epoch[110] Time cost=11.427\n",
+      "INFO:root:Epoch[111] Train-loss=167.043077\n",
+      "INFO:root:Epoch[111] Time cost=10.349\n",
+      "INFO:root:Epoch[112] Train-loss=166.884060\n",
+      "INFO:root:Epoch[112] Time cost=13.129\n",
+      "INFO:root:Epoch[113] Train-loss=166.746976\n",
+      "INFO:root:Epoch[113] Time cost=11.255\n",
+      "INFO:root:Epoch[114] Train-loss=166.572499\n",
+      "INFO:root:Epoch[114] Time cost=10.037\n",
+      "INFO:root:Epoch[115] Train-loss=166.445170\n",
+      "INFO:root:Epoch[115] Time cost=10.406\n",
+      "INFO:root:Epoch[116] Train-loss=166.284912\n",
+      "INFO:root:Epoch[116] Time cost=10.170\n",
+      "INFO:root:Epoch[117] Train-loss=166.171475\n",
+      "INFO:root:Epoch[117] Time cost=10.034\n",
+      "INFO:root:Epoch[118] Train-loss=166.015457\n",
+      "INFO:root:Epoch[118] Time cost=10.047\n",
+      "INFO:root:Epoch[119] Train-loss=165.882208\n",
+      "INFO:root:Epoch[119] Time cost=10.008\n",
+      "INFO:root:Epoch[120] Train-loss=165.753836\n",
+      "INFO:root:Epoch[120] Time cost=10.056\n",
+      "INFO:root:Epoch[121] Train-loss=165.626045\n",
+      "INFO:root:Epoch[121] Time cost=10.704\n",
+      "INFO:root:Epoch[122] Train-loss=165.492859\n",
+      "INFO:root:Epoch[122] Time cost=10.609\n",
+      "INFO:root:Epoch[123] Train-loss=165.361132\n",
+      "INFO:root:Epoch[123] Time cost=10.027\n",
+      "INFO:root:Epoch[124] Train-loss=165.256487\n",
+      "INFO:root:Epoch[124] Time cost=11.225\n",
+      "INFO:root:Epoch[125] Train-loss=165.119995\n",
+      "INFO:root:Epoch[125] Time cost=11.266\n",
+      "INFO:root:Epoch[126] Train-loss=165.012773\n",
+      "INFO:root:Epoch[126] Time cost=10.547\n",
+      "INFO:root:Epoch[127] Train-loss=164.898748\n",
+      "INFO:root:Epoch[127] Time cost=10.339\n",
+      "INFO:root:Epoch[128] Train-loss=164.775702\n",
+      "INFO:root:Epoch[128] Time cost=10.875\n",
+      "INFO:root:Epoch[129] Train-loss=164.692449\n",
+      "INFO:root:Epoch[129] Time cost=8.412\n",
+      "INFO:root:Epoch[130] Train-loss=164.564323\n",
+      "INFO:root:Epoch[130] Time cost=7.239\n",
+      "INFO:root:Epoch[131] Train-loss=164.468273\n",
+      "INFO:root:Epoch[131] Time cost=10.096\n",
+      "INFO:root:Epoch[132] Train-loss=164.328320\n",
+      "INFO:root:Epoch[132] Time cost=9.680\n",
+      "INFO:root:Epoch[133] Train-loss=164.256156\n",
+      "INFO:root:Epoch[133] Time cost=10.707\n",
+      "INFO:root:Epoch[134] Train-loss=164.151625\n",
+      "INFO:root:Epoch[134] Time cost=13.835\n",
+      "INFO:root:Epoch[135] Train-loss=164.046402\n",
+      "INFO:root:Epoch[135] Time cost=10.049\n",
+      "INFO:root:Epoch[136] Train-loss=163.960676\n",
+      "INFO:root:Epoch[136] Time cost=9.625\n",
+      "INFO:root:Epoch[137] Train-loss=163.873193\n",
+      "INFO:root:Epoch[137] Time cost=9.845\n",
+      "INFO:root:Epoch[138] Train-loss=163.783837\n",
+      "INFO:root:Epoch[138] Time cost=9.618\n",
+      "INFO:root:Epoch[139] Train-loss=163.658903\n",
+      "INFO:root:Epoch[139] Time cost=10.411\n",
+      "INFO:root:Epoch[140] Train-loss=163.588920\n",
+      "INFO:root:Epoch[140] Time cost=9.633\n",
+      "INFO:root:Epoch[141] Train-loss=163.493254\n",
+      "INFO:root:Epoch[141] Time cost=10.668\n",
+      "INFO:root:Epoch[142] Train-loss=163.401188\n",
+      "INFO:root:Epoch[142] Time cost=10.644\n",
+      "INFO:root:Epoch[143] Train-loss=163.334470\n",
+      "INFO:root:Epoch[143] Time cost=9.665\n",
+      "INFO:root:Epoch[144] Train-loss=163.235133\n",
+      "INFO:root:Epoch[144] Time cost=9.612\n",
+      "INFO:root:Epoch[145] Train-loss=163.168029\n",
+      "INFO:root:Epoch[145] Time cost=9.578\n",
+      "INFO:root:Epoch[146] Train-loss=163.092392\n",
+      "INFO:root:Epoch[146] Time cost=10.215\n",
+      "INFO:root:Epoch[147] Train-loss=163.014362\n",
+      "INFO:root:Epoch[147] Time cost=12.296\n",
+      "INFO:root:Epoch[148] Train-loss=162.891574\n",
+      "INFO:root:Epoch[148] Time cost=9.578\n",
+      "INFO:root:Epoch[149] Train-loss=162.831664\n",
+      "INFO:root:Epoch[149] Time cost=9.536\n",
+      "INFO:root:Epoch[150] Train-loss=162.768784\n",
+      "INFO:root:Epoch[150] Time cost=9.607\n",
+      "INFO:root:Epoch[151] Train-loss=162.695416\n",
+      "INFO:root:Epoch[151] Time cost=9.681\n",
+      "INFO:root:Epoch[152] Train-loss=162.620814\n",
+      "INFO:root:Epoch[152] Time cost=9.464\n",
+      "INFO:root:Epoch[153] Train-loss=162.527031\n",
+      "INFO:root:Epoch[153] Time cost=9.518\n",
+      "INFO:root:Epoch[154] Train-loss=162.466575\n",
+      "INFO:root:Epoch[154] Time cost=9.562\n",
+      "INFO:root:Epoch[155] Train-loss=162.409388\n",
+      "INFO:root:Epoch[155] Time cost=9.483\n",
+      "INFO:root:Epoch[156] Train-loss=162.308957\n",
+      "INFO:root:Epoch[156] Time cost=9.545\n",
+      "INFO:root:Epoch[157] Train-loss=162.211725\n",
+      "INFO:root:Epoch[157] Time cost=9.542\n",
+      "INFO:root:Epoch[158] Train-loss=162.141098\n",
+      "INFO:root:Epoch[158] Time cost=9.768\n",
+      "INFO:root:Epoch[159] Train-loss=162.124311\n",
+      "INFO:root:Epoch[159] Time cost=7.155\n",
+      "INFO:root:Epoch[160] Train-loss=162.013039\n",
+      "INFO:root:Epoch[160] Time cost=6.147\n",
+      "INFO:root:Epoch[161] Train-loss=161.954485\n",
+      "INFO:root:Epoch[161] Time cost=9.121\n",
+      "INFO:root:Epoch[162] Train-loss=161.913859\n",
+      "INFO:root:Epoch[162] Time cost=9.936\n",
+      "INFO:root:Epoch[163] Train-loss=161.830799\n",
+      "INFO:root:Epoch[163] Time cost=8.612\n",
+      "INFO:root:Epoch[164] Train-loss=161.768672\n",
+      "INFO:root:Epoch[164] Time cost=9.722\n",
+      "INFO:root:Epoch[165] Train-loss=161.689120\n",
+      "INFO:root:Epoch[165] Time cost=9.478\n",
+      "INFO:root:Epoch[166] Train-loss=161.598279\n",
+      "INFO:root:Epoch[166] Time cost=9.466\n",
+      "INFO:root:Epoch[167] Train-loss=161.551172\n",
+      "INFO:root:Epoch[167] Time cost=9.419\n",
+      "INFO:root:Epoch[168] Train-loss=161.488880\n",
+      "INFO:root:Epoch[168] Time cost=9.457\n",
+      "INFO:root:Epoch[169] Train-loss=161.410458\n",
+      "INFO:root:Epoch[169] Time cost=9.504\n",
+      "INFO:root:Epoch[170] Train-loss=161.340681\n",
+      "INFO:root:Epoch[170] Time cost=9.866\n",
+      "INFO:root:Epoch[171] Train-loss=161.281700\n",
+      "INFO:root:Epoch[171] Time cost=9.526\n",
+      "INFO:root:Epoch[172] Train-loss=161.215523\n",
+      "INFO:root:Epoch[172] Time cost=9.511\n",
+      "INFO:root:Epoch[173] Train-loss=161.152452\n",
+      "INFO:root:Epoch[173] Time cost=9.498\n",
+      "INFO:root:Epoch[174] Train-loss=161.058544\n",
+      "INFO:root:Epoch[174] Time cost=9.561\n",
+      "INFO:root:Epoch[175] Train-loss=161.036475\n",
+      "INFO:root:Epoch[175] Time cost=9.463\n",
+      "INFO:root:Epoch[176] Train-loss=161.009996\n",
+      "INFO:root:Epoch[176] Time cost=9.629\n",
+      "INFO:root:Epoch[177] Train-loss=160.853546\n",
+      "INFO:root:Epoch[177] Time cost=9.518\n",
+      "INFO:root:Epoch[178] Train-loss=160.860520\n",
+      "INFO:root:Epoch[178] Time cost=9.395\n",
+      "INFO:root:Epoch[179] Train-loss=160.810621\n",
+      "INFO:root:Epoch[179] Time cost=9.452\n",
+      "INFO:root:Epoch[180] Train-loss=160.683071\n",
+      "INFO:root:Epoch[180] Time cost=9.411\n",
+      "INFO:root:Epoch[181] Train-loss=160.674101\n",
+      "INFO:root:Epoch[181] Time cost=8.784\n",
+      "INFO:root:Epoch[182] Train-loss=160.554823\n",
+      "INFO:root:Epoch[182] Time cost=7.265\n",
+      "INFO:root:Epoch[183] Train-loss=160.536528\n",
+      "INFO:root:Epoch[183] Time cost=6.108\n",
+      "INFO:root:Epoch[184] Train-loss=160.525913\n",
+      "INFO:root:Epoch[184] Time cost=6.349\n",
+      "INFO:root:Epoch[185] Train-loss=160.399412\n",
+      "INFO:root:Epoch[185] Time cost=7.364\n",
+      "INFO:root:Epoch[186] Train-loss=160.380027\n",
+      "INFO:root:Epoch[186] Time cost=7.651\n",
+      "INFO:root:Epoch[187] Train-loss=160.272921\n",
+      "INFO:root:Epoch[187] Time cost=7.309\n",
+      "INFO:root:Epoch[188] Train-loss=160.243907\n",
+      "INFO:root:Epoch[188] Time cost=7.162\n",
+      "INFO:root:Epoch[189] Train-loss=160.194351\n",
+      "INFO:root:Epoch[189] Time cost=8.941\n",
+      "INFO:root:Epoch[190] Train-loss=160.130400\n",
+      "INFO:root:Epoch[190] Time cost=10.242\n",
+      "INFO:root:Epoch[191] Train-loss=160.073841\n",
+      "INFO:root:Epoch[191] Time cost=10.528\n",
+      "INFO:root:Epoch[192] Train-loss=160.021623\n",
+      "INFO:root:Epoch[192] Time cost=9.482\n",
+      "INFO:root:Epoch[193] Train-loss=159.938673\n",
+      "INFO:root:Epoch[193] Time cost=9.465\n",
+      "INFO:root:Epoch[194] Train-loss=159.885823\n",
+      "INFO:root:Epoch[194] Time cost=9.523\n",
+      "INFO:root:Epoch[195] Train-loss=159.886516\n",
+      "INFO:root:Epoch[195] Time cost=9.599\n",
+      "INFO:root:Epoch[196] Train-loss=159.797400\n",
+      "INFO:root:Epoch[196] Time cost=8.675\n",
+      "INFO:root:Epoch[197] Train-loss=159.705562\n",
+      "INFO:root:Epoch[197] Time cost=9.551\n",
+      "INFO:root:Epoch[198] Train-loss=159.738354\n",
+      "INFO:root:Epoch[198] Time cost=9.919\n",
+      "INFO:root:Epoch[199] Train-loss=159.619932\n",
+      "INFO:root:Epoch[199] Time cost=10.121\n"
+     ]
+    }
+   ],
+   "source": [
+    "# can initialize weights and biases with the learned parameters \n",
+    "#init = mx.initializer.Load(params)\n",
+    "\n",
+    "# call the VAE; the output contains the learned model and the training loss\n",
+    "out = VAE(n_latent=2,x_train=image,x_valid=None,num_epoch=200) "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# encode test images to obtain mu and logvar which are used for sampling\n",
+    "[mu,logvar] = VAE.encoder(out,image_test)\n",
+    "#sample in the latent space\n",
+    "z = VAE.sampler(mu,logvar)\n",
+    "# decode from the latent space to obtain reconstructed images\n",
+    "x_construction = VAE.decoder(out,z)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAADSCAYAAACvmc1VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XuU3GWd5/HPl9xJQsilE5JAEgIBQgQCtshNyQpBRVE5\nq86gO6COg7PHdeTozIjumRWdYWX3eJ11jw4uCOIFdcHb6OhGRA0YkA6JQIgxkgSTkPu1cyfJs3/U\nL0yR+n5J/7qrurvqeb/OyUn3p55UP7/+fav6SXU935+llAQAAADk6Li+ngAAAADQV1gMAwAAIFss\nhgEAAJAtFsMAAADIFothAAAAZIvFMAAAALLFYrgfMLP/Y2Yf6+t5AD1hZtPMLJnZwOD2JWY2p5en\nBRwTtYtmRN3WD4thSWa2q+rPYTPbW/X5Oxv99VNK700p/fdGfx00DzNbZWZX9vU86imlNCul9Mu+\nngca [...]
+      "text/plain": [
+       "<matplotlib.figure.Figure at 0x11e9ff7f0>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "f, ((ax1, ax2, ax3, ax4)) = plt.subplots(1,4,  sharex='col', sharey='row',figsize=(12,3))\n",
+    "ax1.imshow(np.reshape(image_test[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax1.set_title('True image')\n",
+    "ax2.imshow(np.reshape(x_construction[0,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax2.set_title('Learned image')\n",
+    "ax3.imshow(np.reshape(x_construction[999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax3.set_title('Learned image')\n",
+    "ax4.imshow(np.reshape(x_construction[9999,:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "ax4.set_title('Learned image')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEICAYAAABLdt/UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXuUZXV157/7vmjurQbjacQRqFuK4izGqIlE42saLZNg\ngxLHLJUU0s7MsuTmMcSBPLBWBieZMq9ZQSfIYI3iOH1vjMwYHzE4JATGrPggNAhMRHTQrioUH3QB\nAl1IN1V7/rj3FKdOncfvnHve9/tZa6/uuvc8fufcc75nn/3bv/0TVQUhhJDqUMu7AYQQQpKFwk4I\nIRWDwk4IIRWDwk4IIRWDwk4IIRWDwk4IIRWDwk5yRUSWReR1ebeDkCpBYSelQURURJ6b0LbOFpHv\nJLEtQooGhZ0QQioGhZ0UBhF5qYh8WUQeFpHvichVItIafff3o8XuFJHHROSto8/PE5E7Rut8SURe\n6Nje [...]
+      "text/plain": [
+       "<matplotlib.figure.Figure at 0x12fab6eb8>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAAC3CAYAAAD3oFO8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmQXeV55/HfA4jFyAIJLQgEEmKRWAIiFhgHs4Ultkmw\nQ7m8ToaUE5OZSjLOjCsVT6pmxjVVmXFSWWpqqiYTMvaAPTF2cELhcpwhBisDGErQCAm0IIRAWAK0\ngQFhK2x6549uOfee99fdr+7St889309Vl3Qenb73vec55/Sr2+9zn0gpCQAAAGiiwwY9AAAAAGBQ\nmAwDAACgsZgMAwAAoLGYDAMAAKCxmAwDAACgsZgMAwAAoLGYDAMAAKCxmAwXilF/GBEvjX39YUTE\noMeFQxcRV0XEqoh4NSK2DXo86FxE/G5ErI+IfRHxbET87qDHhM5ExL+NiGci4rWIeCEi/iwijhj0\nuNCZ [...]
+      "text/plain": [
+       "<matplotlib.figure.Figure at 0x11f9cd438>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "z1 = z[:,0]\n",
+    "z2 = z[:,1]\n",
+    "\n",
+    "fig = plt.figure()\n",
+    "ax = fig.add_subplot(111)\n",
+    "ax.plot(z1,z2,'ko')\n",
+    "plt.title(\"latent space\")\n",
+    "\n",
+    "#np.where((z1>3) & (z2<2) & (z2>0))\n",
+    "#select the points from the latent space\n",
+    "a_vec = [2,5,7,789,25,9993]\n",
+    "for i in range(len(a_vec)):\n",
+    "    ax.plot(z1[a_vec[i]],z2[a_vec[i]],'ro')  \n",
+    "    ax.annotate('z%d' %i, xy=(z1[a_vec[i]],z2[a_vec[i]]), \n",
+    "                xytext=(z1[a_vec[i]],z2[a_vec[i]]),color = 'r',fontsize=15)\n",
+    "\n",
+    "\n",
+    "f, ((ax0, ax1, ax2, ax3, ax4,ax5)) = plt.subplots(1,6,  sharex='col', sharey='row',figsize=(12,2.5))\n",
+    "for i in range(len(a_vec)):\n",
+    "    eval('ax%d' %(i)).imshow(np.reshape(x_construction[a_vec[i],:],(28,28)), interpolation='nearest', cmap=cm.Greys)\n",
+    "    eval('ax%d' %(i)).set_title('z%d'%i)\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Above is a plot of points in the 2D latent space and their corresponding decoded images. Points that are close together in the latent space are mapped to the same digit by the decoder, and we can see how the decoded digit evolves from left to right."
+   ]
+  }
+ ],
+ "metadata": {
+  "anaconda-cloud": {},
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].