You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/08/11 04:52:11 UTC
[incubator-mxnet] branch master updated: [Gluon] Add VAE demo
(#18758)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e101d68 [Gluon] Add VAE demo (#18758)
e101d68 is described below
commit e101d68d8b338aaca4541ba055c83124c8289919
Author: Xi Wang <xi...@gmail.com>
AuthorDate: Tue Aug 11 12:50:26 2020 +0800
[Gluon] Add VAE demo (#18758)
* add VAE demo
* minor changes
* change format to md
* minor changes
* add liscence
* Update VAE.md
* update vae demo
* remove unnecessary files
---
example/probability/VAE/VAE.md | 259 +++++++++++++++++++++++++++++++++++
example/probability/VAE/VAE_11_0.png | Bin 0 -> 9062 bytes
example/probability/VAE/VAE_14_0.png | Bin 0 -> 15863 bytes
3 files changed, 259 insertions(+)
diff --git a/example/probability/VAE/VAE.md b/example/probability/VAE/VAE.md
new file mode 100644
index 0000000..a334d70
--- /dev/null
+++ b/example/probability/VAE/VAE.md
@@ -0,0 +1,259 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+
+# VAE with Gluon.probability
+
+In this example, we will demonstrate how you can implement a Variational Auto-encoder (VAE) with Gluon.probability and MXNet's latest NumPy API.
+
+
+```{.python .input}
+import numpy as np
+import mxnet as mx
+from mxnet import autograd, gluon, np, npx
+from mxnet.gluon import nn
+import mxnet.gluon.probability as mgp
+import matplotlib.pyplot as plt
+
+# Switch numpy-compatible semantics on.
+npx.set_np()
+
+# Set context for model context, here we choose to use GPU.
+model_ctx = mx.gpu(0)
+```
+
+## Dataset
+
+We will use MNIST here for simplicity.
+
+
+```{.python .input}
+def load_data(batch_size):
+ mnist_train = gluon.data.vision.MNIST(train=True)
+ mnist_test = gluon.data.vision.MNIST(train=False)
+ num_worker = 4
+ transformer = gluon.data.vision.transforms.ToTensor()
+ return (gluon.data.DataLoader(mnist_train.transform_first(transformer),
+ batch_size, shuffle=True,
+ num_workers=num_worker),
+ gluon.data.DataLoader(mnist_test.transform_first(transformer),
+ batch_size, shuffle=False,
+ num_workers=num_worker))
+
+```
+
+## Model definition
+
+
+```{.python .input}
+class VAE(gluon.HybridBlock):
+ def __init__(self, n_hidden=256, n_latent=2, n_layers=1, n_output=784, act_type='relu', **kwargs):
+ r"""
+ n_hidden : number of hidden units in each layer
+ n_latent : dimension of the latent space
+ n_layers : number of layers in the encoder and decoder network
+ n_output : dimension of the observed data
+ """
+ self.soft_zero = 1e-10
+ self.n_latent = n_latent
+ self.output = None
+ self.mu = None
+ super(VAE, self).__init__(**kwargs)
+ self.encoder = nn.HybridSequential()
+ for _ in range(n_layers):
+ self.encoder.add(nn.Dense(n_hidden, activation=act_type))
+ self.encoder.add(nn.Dense(n_latent*2, activation=None))
+ self.decoder = nn.HybridSequential()
+ for _ in range(n_layers):
+ self.decoder.add(nn.Dense(n_hidden, activation=act_type))
+ self.decoder.add(nn.Dense(n_output, activation='sigmoid'))
+
+ def encode(self, x):
+ r"""
+ Given a batch of x,
+ return the encoder's output
+ """
+ # [loc_1, ..., loc_n, log(scale_1), ..., log(scale_n)]
+ h = self.encoder(x)
+
+ # Extract loc and log_scale from the encoder output.
+ loc_scale = np.split(h, 2, 1)
+ loc = loc_scale[0]
+ log_scale = loc_scale[1]
+
+ # Convert log_scale back to scale.
+ scale = np.exp(log_scale)
+
+ # Return a Normal object.
+ return mgp.Normal(loc, scale)
+
+ def decode(self, z):
+ r"""
+ Given a batch of samples from z,
+ return the decoder's output
+ """
+ return self.decoder(z)
+
+ def forward(self, x):
+ r"""
+ Given a batch of data x,
+ return the negative of Evidence Lower-bound,
+ i.e. an objective to minimize.
+ """
+ # prior p(z)
+ pz = mgp.Normal(0, 1)
+
+ # posterior q(z|x)
+ qz_x = self.encode(x)
+
+ # Sampling operation qz_x.sample() is automatically reparameterized.
+ z = qz_x.sample()
+
+ # Reconstruction result
+ y = self.decode(z)
+
+ # Gluon.probability can help you calculate the analytical kl-divergence
+ # between two distribution objects.
+ KL = mgp.kl_divergence(qz_x, pz).sum(1)
+
+ # We assume p(x|z) ~ Bernoulli, therefore we compute the reconstruction
+ # loss with binary cross entropy.
+ logloss = np.sum(x * np.log(y + self.soft_zero) + (1 - x)
+ * np.log(1 - y + self.soft_zero), axis=1)
+ loss = -logloss + KL
+ return loss
+```
+
+## Training
+
+
+```{.python .input}
+def train(net, n_epoch, print_period, train_iter, test_iter):
+ net.initialize(mx.init.Xavier(), ctx=model_ctx)
+ net.hybridize()
+ trainer = gluon.Trainer(net.collect_params(), 'adam',
+ {'learning_rate': .001})
+ training_loss = []
+ validation_loss = []
+ for epoch in range(n_epoch):
+ epoch_loss = 0
+ epoch_val_loss = 0
+
+ n_batch_train = 0
+ for batch in train_iter:
+ n_batch_train += 1
+ data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28)
+ with autograd.record():
+ loss = net(data)
+ loss.backward()
+ trainer.step(data.shape[0])
+ epoch_loss += np.mean(loss)
+
+ n_batch_val = 0
+ for batch in test_iter:
+ n_batch_val += 1
+ data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28)
+ loss = net(data)
+ epoch_val_loss += np.mean(loss)
+
+ epoch_loss /= n_batch_train
+ epoch_val_loss /= n_batch_val
+
+ training_loss.append(epoch_loss)
+ validation_loss.append(epoch_val_loss)
+
+ if epoch % max(print_period, 1) == 0:
+ print('Epoch{}, Training loss {:.2f}, Validation loss {:.2f}'.format(
+ epoch, float(epoch_loss), float(epoch_val_loss)))
+```
+
+
+```{.python .input}
+n_hidden = 128
+n_latent = 40
+n_layers = 3
+n_output = 784
+batch_size = 128
+model_prefix = 'vae_gluon_{}d{}l{}h.params'.format(
+ n_latent, n_layers, n_hidden)
+net = VAE(n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers,
+ n_output=n_output)
+net.hybridize()
+n_epoch = 50
+print_period = n_epoch // 10
+train_set, test_set = load_data(batch_size)
+train(net, n_epoch, print_period, train_set, test_set)
+```
+
+
+## Reconstruction visualization
+
+To verify the effectiveness of our model, we first take a look at how well our model can reconstruct the data.
+
+
+```{.python .input}
+# Grab a batch from the test set
+qz_x = None
+for batch in test_set:
+ data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28)
+ qz_x = net.encode(data)
+ break
+```
+
+
+```{.python .input}
+num_samples = 4
+fig, axes = plt.subplots(nrows=num_samples, ncols=2, figsize=(4, 6), subplot_kw={'xticks': [], 'yticks': []})
+axes[0, 0].set_title('Original image')
+axes[0, 1].set_title('reconstruction')
+for i in range(num_samples):
+ axes[i, 0].imshow(data[i].squeeze().reshape(28, 28).asnumpy(), cmap='gray')
+ axes[i, 1].imshow(net.decode(qz_x.sample())[i].reshape(28, 28).asnumpy(), cmap='gray')
+```
+
+
+![png](./VAE_11_0.png)
+
+
+## Sample generation
+
+One of the most important differences between a Variational Auto-encoder and a plain Auto-encoder is the VAE's capability of generating new samples.
+
+To achieve that, one simply needs to feed a random sample $z \sim \mathcal{N}(0,1)$ drawn from the prior $p(z)$ to the decoder network.
+
+
+```{.python .input}
+def plot_samples(samples, h=5, w=10):
+ fig, axes = plt.subplots(nrows=h,
+ ncols=w,
+ figsize=(int(1.4 * w), int(1.4 * h)),
+ subplot_kw={'xticks': [], 'yticks': []})
+ for i, ax in enumerate(axes.flatten()):
+ ax.imshow(samples[i], cmap='gray')
+```
+
+
+```{.python .input}
+# Number of images to generate (fills the 4x5 grid below).
+n_samples = 20
+# Draw latent codes from the standard-Normal prior on the model's context.
+noise = np.random.randn(n_samples, n_latent).as_in_context(model_ctx)
+# Decode to flat pixel vectors, reshape to 28x28 images, copy to host memory.
+dec_output = net.decode(noise).reshape(-1, 28, 28).asnumpy()
+plot_samples(dec_output, 4, 5)
+```
+
+
+![png](./VAE_14_0.png)
+
diff --git a/example/probability/VAE/VAE_11_0.png b/example/probability/VAE/VAE_11_0.png
new file mode 100644
index 0000000..455f09e
Binary files /dev/null and b/example/probability/VAE/VAE_11_0.png differ
diff --git a/example/probability/VAE/VAE_14_0.png b/example/probability/VAE/VAE_14_0.png
new file mode 100644
index 0000000..d26173b
Binary files /dev/null and b/example/probability/VAE/VAE_14_0.png differ