You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/26 05:10:43 UTC

[GitHub] KellenSunderland commented on a change in pull request #12933: Update autoencoder example

KellenSunderland commented on a change in pull request #12933: Update autoencoder example
URL: https://github.com/apache/incubator-mxnet/pull/12933#discussion_r236123126
 
 

 ##########
 File path: example/autoencoder/convolutional_autoencoder.ipynb
 ##########
 @@ -0,0 +1,587 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Convolutional Autoencoder"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "![](https://cdn-images-1.medium.com/max/800/1*LSYNW5m3TN7xRX61BZhoZA.png)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this example, we demonstrate how to create a convolutional autoencoder in Gluon."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import mxnet as mx\n",
+    "from mxnet import autograd, nd, gluon"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Data\n",
+    "\n",
+    "We will use the FashionMNIST dataset, which has a similar format to MNIST but is richer and has more variance."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size = 512\n",
+    "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "transform = lambda x,y: (x.transpose((2,0,1)).astype('float32')/255., y)\n",
+    "\n",
+    "train_dataset = gluon.data.vision.FashionMNIST(train=True)\n",
+    "test_dataset = gluon.data.vision.FashionMNIST(train=False)\n",
+    "\n",
+    "train_dataset_t = train_dataset.transform(transform)\n",
+    "test_dataset_t = test_dataset.transform(transform)\n",
+    "\n",
+    "train_data = gluon.data.DataLoader(train_dataset_t, batch_size=batch_size, last_batch='rollover', shuffle=True, num_workers=5)\n",
+    "test_data = gluon.data.DataLoader(test_dataset_t, batch_size=batch_size, last_batch='rollover', shuffle=True, num_workers=5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 1440x720 with 10 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "plt.figure(figsize=(20,10))\n",
+    "for i in range(10):\n",
+    "    ax = plt.subplot(1, 10, i+1)\n",
+    "    ax.imshow(train_dataset[i][0].squeeze().asnumpy(), cmap='gray')\n",
+    "    ax.axis('off')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net = gluon.nn.HybridSequential(prefix='autoencoder_')\n",
+    "with net.name_scope():\n",
+    "    # Encoder 1x28x28 -> 32x1x1\n",
+    "    encoder = gluon.nn.HybridSequential(prefix='encoder_')\n",
+    "    with encoder.name_scope():\n",
+    "        encoder.add(\n",
+    "            gluon.nn.Conv2D(channels=4, kernel_size=3, padding=1, strides=(2,2), activation='relu'),\n",
+    "            gluon.nn.BatchNorm(),\n",
+    "            gluon.nn.Conv2D(channels=8, kernel_size=3, padding=1, strides=(2,2), activation='relu'),\n",
+    "            gluon.nn.BatchNorm(),\n",
+    "            gluon.nn.Conv2D(channels=16, kernel_size=3, padding=1, strides=(2,2), activation='relu'),\n",
+    "            gluon.nn.BatchNorm(),\n",
+    "            gluon.nn.Conv2D(channels=32, kernel_size=3, padding=0, strides=(2,2),activation='relu'),\n",
+    "            gluon.nn.BatchNorm()\n",
+    "        )\n",
+    "    decoder = gluon.nn.HybridSequential(prefix='decoder_')\n",
+    "    # Decoder 32x1x1 -> 1x28x28\n",
+    "    with decoder.name_scope():\n",
+    "        decoder.add(\n",
+    "            gluon.nn.Conv2D(channels=32, kernel_size=3, padding=2, activation='relu'),\n",
+    "            gluon.nn.HybridLambda(lambda F, x: F.UpSampling(x, scale=2, sample_type='nearest')),\n",
+    "            gluon.nn.BatchNorm(),\n",
+    "            gluon.nn.Conv2D(channels=16, kernel_size=3, padding=1, activation='relu'),\n",
+    "            gluon.nn.HybridLambda(lambda F, x: F.UpSampling(x, scale=2, sample_type='nearest')),\n",
+    "            gluon.nn.BatchNorm(),\n",
+    "            gluon.nn.Conv2D(channels=8, kernel_size=3, padding=2, activation='relu'),\n",
+    "            gluon.nn.HybridLambda(lambda F, x: F.UpSampling(x, scale=2, sample_type='nearest')),\n",
+    "            gluon.nn.BatchNorm(),\n",
+    "            gluon.nn.Conv2D(channels=4, kernel_size=3, padding=1, activation='relu'),\n",
+    "            gluon.nn.Conv2D(channels=1, kernel_size=3, padding=1, activation='sigmoid')\n",
+    "        )\n",
+    "    net.add(\n",
+    "        encoder,\n",
+    "        decoder\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net.initialize(ctx=ctx)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--------------------------------------------------------------------------------\n",
+      "        Layer (type)                                Output Shape         Param #\n",
+      "================================================================================\n",
+      "               Input                              (1, 1, 28, 28)               0\n",
+      "        Activation-1  <Symbol autoencoder_encoder_conv0_relu_fwd>               0\n",
+      "        Activation-2                              (1, 4, 14, 14)               0\n",
+      "            Conv2D-3                              (1, 4, 14, 14)              40\n",
+      "         BatchNorm-4                              (1, 4, 14, 14)              16\n",
+      "        Activation-5  <Symbol autoencoder_encoder_conv1_relu_fwd>               0\n",
+      "        Activation-6                                (1, 8, 7, 7)               0\n",
+      "            Conv2D-7                                (1, 8, 7, 7)             296\n",
+      "         BatchNorm-8                                (1, 8, 7, 7)              32\n",
+      "        Activation-9  <Symbol autoencoder_encoder_conv2_relu_fwd>               0\n",
+      "       Activation-10                               (1, 16, 4, 4)               0\n",
+      "           Conv2D-11                               (1, 16, 4, 4)            1168\n",
+      "        BatchNorm-12                               (1, 16, 4, 4)              64\n",
+      "       Activation-13  <Symbol autoencoder_encoder_conv3_relu_fwd>               0\n",
+      "       Activation-14                               (1, 32, 1, 1)               0\n",
+      "           Conv2D-15                               (1, 32, 1, 1)            4640\n",
+      "        BatchNorm-16                               (1, 32, 1, 1)             128\n",
+      "       Activation-17  <Symbol autoencoder_decoder_conv0_relu_fwd>               0\n",
+      "       Activation-18                               (1, 32, 3, 3)               0\n",
+      "           Conv2D-19                               (1, 32, 3, 3)            9248\n",
+      "     HybridLambda-20                               (1, 32, 6, 6)               0\n",
+      "        BatchNorm-21                               (1, 32, 6, 6)             128\n",
+      "       Activation-22  <Symbol autoencoder_decoder_conv1_relu_fwd>               0\n",
+      "       Activation-23                               (1, 16, 6, 6)               0\n",
+      "           Conv2D-24                               (1, 16, 6, 6)            4624\n",
+      "     HybridLambda-25                             (1, 16, 12, 12)               0\n",
+      "        BatchNorm-26                             (1, 16, 12, 12)              64\n",
+      "       Activation-27  <Symbol autoencoder_decoder_conv2_relu_fwd>               0\n",
+      "       Activation-28                              (1, 8, 14, 14)               0\n",
+      "           Conv2D-29                              (1, 8, 14, 14)            1160\n",
+      "     HybridLambda-30                              (1, 8, 28, 28)               0\n",
+      "        BatchNorm-31                              (1, 8, 28, 28)              32\n",
+      "       Activation-32  <Symbol autoencoder_decoder_conv3_relu_fwd>               0\n",
+      "       Activation-33                              (1, 4, 28, 28)               0\n",
+      "           Conv2D-34                              (1, 4, 28, 28)             292\n",
+      "       Activation-35  <Symbol autoencoder_decoder_conv4_sigmoid_fwd>               0\n",
+      "       Activation-36                              (1, 1, 28, 28)               0\n",
+      "           Conv2D-37                              (1, 1, 28, 28)              37\n",
+      "================================================================================\n",
+      "Parameters in forward computation graph, duplicate included\n",
+      "   Total params: 21969\n",
+      "   Trainable params: 21737\n",
+      "   Non-trainable params: 232\n",
+      "Shared params in forward computation graph: 0\n",
+      "Unique parameters in model: 21969\n",
+      "--------------------------------------------------------------------------------\n"
+     ]
+    }
+   ],
+   "source": [
+    "net.summary(test_dataset_t[0][0].expand_dims(axis=0).as_in_context(ctx))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can see that the original image goes from 28x28 = 784 pixels to a vector of length 32. That is a ~25x information compression rate.\n",
+    "The decoder then expands this compressed representation back to the original 28x28 image shape."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "l2_loss = gluon.loss.L2Loss()\n",
+    "l1_loss = gluon.loss.L1Loss()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.001, 'wd':0.001})\n",
+    "net.hybridize(static_shape=True, static_alloc=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training loop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch [0], Loss 0.2246280246310764\n",
+      "Epoch [1], Loss 0.14493223337026742\n",
+      "Epoch [2], Loss 0.13147933666522688\n",
+      "Epoch [3], Loss 0.12138325943906084\n",
+      "Epoch [4], Loss 0.11291297684367906\n",
+      "Epoch [5], Loss 0.10611823453741559\n",
+      "Epoch [6], Loss 0.09942417470817892\n",
+      "Epoch [7], Loss 0.09408332955124032\n",
+      "Epoch [8], Loss 0.08883619716024807\n",
+      "Epoch [9], Loss 0.08491455795418502\n",
+      "Epoch [10], Loss 0.0809355994402352\n",
+      "Epoch [11], Loss 0.07784551636785524\n",
+      "Epoch [12], Loss 0.07570812029716296\n",
+      "Epoch [13], Loss 0.07417513366438384\n",
+      "Epoch [14], Loss 0.07218785571236895\n",
+      "Epoch [15], Loss 0.07093704352944584\n",
+      "Epoch [16], Loss 0.0700181406787318\n",
+      "Epoch [17], Loss 0.0689836893326197\n",
+      "Epoch [18], Loss 0.06782063459738708\n",
+      "Epoch [19], Loss 0.06713279088338216\n"
+     ]
+    }
+   ],
+   "source": [
+    "epochs = 20\n",
+    "for e in range(epochs):\n",
+    "    curr_loss = 0.\n",
+    "    for i, (data, _) in enumerate(train_data):\n",
+    "        data = data.as_in_context(ctx)\n",
+    "        with autograd.record():\n",
+    "            output = net(data)\n",
+    "            # Compute the L2 and L1 losses between the original and the generated image\n",
+    "            l2 = l2_loss(output.flatten(), data.flatten())\n",
+    "            l1 = l1_loss(output.flatten(), data.flatten())\n",
+    "            l =  l2 + l1 \n",
+    "        l.backward()\n",
+    "        trainer.step(data.shape[0])\n",
+    "        \n",
+    "        curr_loss += l.mean()\n",
+    "\n",
+    "    print(\"Epoch [{}], Loss {}\".format(e, curr_loss.asscalar()/(i+1)))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Testing reconstruction"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We plot 10 images and their reconstructions by the autoencoder. The results are pretty good for a ~25x compression rate!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "
 
 Review comment:
   Nit: Would being more specific and saying 'encoder network' be helpful to readers?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services