Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/23 01:17:40 UTC

[incubator-mxnet] branch master updated: A binary RBM example (#11268)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 67ba3c5  A binary RBM example (#11268)
67ba3c5 is described below

commit 67ba3c508b1cad6ee222ffd378210a77db3195fa
Author: Deokjae Lee <36...@users.noreply.github.com>
AuthorDate: Thu Aug 23 10:17:31 2018 +0900

    A binary RBM example (#11268)
    
    * A binary RBM example
    
    * Retrigger CI
    
    * Rename the parameter `interaction` as `interaction_weight`
    
    * Improved Bernoulli sampling
    
    * Cosmetic changes
    
    * Implement likelihood estimation using AIS
    
    * Add momentum option
    
    * Replace underbars in the command line options with hyphens
    
    * Adjust default values of the hyperparameters and add command line options to set device
    
    * Update README
    
    * Minor updates
    
    * Setting num_workers for the dataloader
    
    * Remove unnecessary `enumerate` call
    
    * Fix a bug on `--cuda` option
    
    * Show the initial real images also
    
    * Minor change in README
    
    * Trigger CI
    
    * Trigger CI
---
 example/README.md                                  |   1 +
 example/restricted-boltzmann-machine/README.md     |  13 ++
 example/restricted-boltzmann-machine/binary_rbm.py | 253 +++++++++++++++++++++
 .../binary_rbm_gluon.py                            | 142 ++++++++++++
 .../binary_rbm_module.py                           | 171 ++++++++++++++
 example/restricted-boltzmann-machine/samples.png   | Bin 0 -> 191570 bytes
 6 files changed, 580 insertions(+)
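
As a reading aid for the `BinaryRBM` operator in the diff that follows, here is a minimal NumPy sketch of one Gibbs sampling step and of the gradient estimate computed in `backward`. It is an illustration under assumptions, not the committed MXNet code: the helper names (`sigmoid`, `sample_bernoulli`, `gibbs_step`, `pcd_gradients`) and the use of `numpy.random.Generator` are hypothetical choices made for this sketch.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_bernoulli(prob_1, rng):
    # Draw 1 with probability prob_1, element-wise, by thresholding uniform noise.
    return (rng.uniform(size=prob_1.shape) < prob_1).astype(prob_1.dtype)

def gibbs_step(hidden_sample, visible_bias, hidden_bias, weight, rng):
    # weight has shape (num_visible, num_hidden), as in the operator.
    visible_prob_1 = sigmoid(hidden_sample @ weight.T + visible_bias)
    visible_sample = sample_bernoulli(visible_prob_1, rng)
    hidden_prob_1 = sigmoid(visible_sample @ weight + hidden_bias)
    hidden_sample = sample_bernoulli(hidden_prob_1, rng)
    return visible_sample, hidden_sample, visible_prob_1

def pcd_gradients(data, data_hidden_prob_1, visible_sample, hidden_sample):
    # Model statistics minus data statistics, averaged over the batch.
    num_batch = data.shape[0]
    grad_visible_bias = (visible_sample - data).mean(axis=0)
    grad_hidden_bias = (hidden_sample - data_hidden_prob_1).mean(axis=0)
    grad_weight = (visible_sample.T @ hidden_sample - data.T @ data_hidden_prob_1) / num_batch
    return grad_visible_bias, grad_hidden_bias, grad_weight
```

In this sketch the batch-averaged outer product is written as a single matrix product divided by the batch size, which is equivalent to the per-sample `gemm2` followed by `mean(axis=0)` used in the operator.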

diff --git a/example/README.md b/example/README.md
index ff071df..6b9a086 100644
--- a/example/README.md
+++ b/example/README.md
@@ -117,6 +117,7 @@ If your tutorial depends on specific packages, simply add them to this provision
     * [DDPG](reinforcement-learning/ddpg) - example of training DDPG for CartPole
     * [DQN](reinforcement-learning/dqn) - examples of training DQN and Double DQN to play Atari Games
     * [Parallel Advantage-Actor Critic](reinforcement-learning/parallel_actor_critic)
+* [Restricted Boltzmann Machine](restricted-boltzmann-machine) - an example of the binary restricted Boltzmann machine learning MNIST
 * [RNN Time Major](rnn-time-major) - RNN implementation with Time-major layout
 * [Recurrent Neural Net](rnn) - creating recurrent neural networks models using high level `mxnet.rnn` interface
 * [Sparse](sparse) - a variety of sparse examples
diff --git a/example/restricted-boltzmann-machine/README.md b/example/restricted-boltzmann-machine/README.md
new file mode 100644
index 0000000..129120b
--- /dev/null
+++ b/example/restricted-boltzmann-machine/README.md
@@ -0,0 +1,13 @@
+# Restricted Boltzmann machine (RBM)
+
+An example of a binary RBM [1] learning the MNIST data. The RBM is implemented as a custom operator, and a gluon block is also provided. `binary_rbm.py` contains the implementation of the RBM. `binary_rbm_module.py` and `binary_rbm_gluon.py` train on the MNIST data using the module interface and the gluon interface, respectively. The MNIST data is downloaded automatically.
+
+The progress of the learning is monitored by estimating the log-likelihood using annealed importance sampling [2,3]. Training with the default hyperparameters takes about 25 minutes on a GTX 1080 Ti, and the resulting log-likelihood is around -70 for both the test and training datasets.
+
+Here are some samples generated by the RBM with the default hyperparameters. The samples (right) are obtained by 3000 steps of Gibbs sampling starting from randomly chosen real images (left).
+
+<p style="text-align:center"><img src="samples.png"/></p>
+
+[1] G E Hinton & R R Salakhutdinov, Reducing the Dimensionality of Data with Neural Networks. Science **313**, 5786 (2006)<br/>
+[2] R M Neal, Annealed importance sampling. Stat Comput **11** 2 (2001)<br/>
+[3] R Salakhutdinov & I Murray, On the quantitative analysis of deep belief networks. In Proc. ICML '08 **25** (2008)
\ No newline at end of file
diff --git a/example/restricted-boltzmann-machine/binary_rbm.py b/example/restricted-boltzmann-machine/binary_rbm.py
new file mode 100644
index 0000000..115e9d1
--- /dev/null
+++ b/example/restricted-boltzmann-machine/binary_rbm.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ast
+import numpy as np
+import mxnet as mx
+
+class BinaryRBM(mx.operator.CustomOp):
+
+    def __init__(self, k):
+        self.k = k # Persistent contrastive divergence k
+
+    def forward(self, is_train, req, in_data, out_data, aux):
+        visible_layer_data = in_data[0] # (num_batch, num_visible)
+        visible_layer_bias = in_data[1] # (num_visible,)
+        hidden_layer_bias = in_data[2]  # (num_hidden,)
+        interaction_weight = in_data[3] # (num_visible, num_hidden)
+
+        if is_train:
+            _, hidden_layer_prob_1 = self.sample_hidden_layer(visible_layer_data, hidden_layer_bias, interaction_weight)
+            hidden_layer_sample = aux[1] # The initial state of the Gibbs sampling for persistent CD
+        else:
+            hidden_layer_sample, hidden_layer_prob_1 = self.sample_hidden_layer(visible_layer_data, hidden_layer_bias, interaction_weight)
+
+        # k-step Gibbs sampling
+        for _ in range(self.k):
+            visible_layer_sample, visible_layer_prob_1 = self.sample_visible_layer(hidden_layer_sample, visible_layer_bias, interaction_weight)
+            hidden_layer_sample, _ = self.sample_hidden_layer(visible_layer_sample, hidden_layer_bias, interaction_weight)
+
+        if is_train:
+            # Used in backward and next forward
+            aux[0][:] = visible_layer_sample
+            aux[1][:] = hidden_layer_sample
+
+        self.assign(out_data[0], req[0], visible_layer_prob_1)
+        self.assign(out_data[1], req[1], hidden_layer_prob_1)
+
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        visible_layer_data = in_data[0]    # (num_batch, num_visible)
+        visible_layer_sample = aux[0]      # (num_batch, num_visible)
+        hidden_layer_prob_1 = out_data[1]  # (num_batch, num_hidden)
+        hidden_layer_sample = aux[1]       # (num_batch, num_hidden)
+
+        grad_visible_layer_bias = (visible_layer_sample - visible_layer_data).mean(axis=0)
+        grad_hidden_layer_bias = (hidden_layer_sample - hidden_layer_prob_1).mean(axis=0)
+        grad_interaction_weight = (mx.nd.linalg.gemm2(visible_layer_sample.expand_dims(2), hidden_layer_sample.expand_dims(1)) -
+                                   mx.nd.linalg.gemm2(visible_layer_data.expand_dims(2), hidden_layer_prob_1.expand_dims(1))
+                                  ).mean(axis=0)
+
+        # We don't need the gradient on the visible layer input
+        self.assign(in_grad[1], req[1], grad_visible_layer_bias)
+        self.assign(in_grad[2], req[2], grad_hidden_layer_bias)
+        self.assign(in_grad[3], req[3], grad_interaction_weight)
+
+    def sample_hidden_layer(self, visible_layer_batch, hidden_layer_bias, interaction_weight):
+        return self.sample_layer(visible_layer_batch, hidden_layer_bias, interaction_weight, False)
+
+    def sample_visible_layer(self, hidden_layer_batch, visible_layer_bias, interaction_weight):
+        return self.sample_layer(hidden_layer_batch, visible_layer_bias, interaction_weight, True)
+
+    def sample_layer(self, other_layer_sample, layer_bias, interaction_weight, interaction_transpose):
+        prob_1 = mx.nd.linalg.gemm(
+            other_layer_sample,
+            interaction_weight,
+            layer_bias.tile(reps=(other_layer_sample.shape[0], 1)),
+            transpose_b=interaction_transpose) # (num_batch, num_units_in_layer)
+        prob_1.sigmoid(out=prob_1)
+        return mx.nd.random.uniform(shape=prob_1.shape) < prob_1, prob_1
+
+@mx.operator.register('BinaryRBM')
+class BinaryRBMProp(mx.operator.CustomOpProp):
+
+    # Auxiliary states are requested only if `for_training` is true.
+    def __init__(self, num_hidden, k, for_training):
+        super(BinaryRBMProp, self).__init__(False)
+        self.num_hidden = int(num_hidden)
+        self.k = int(k)
+        self.for_training = ast.literal_eval(for_training)
+
+    def list_arguments(self):
+        # 0: (batch size, the number of visible units)
+        # 1: (the number of visible units,)
+        # 2: (the number of hidden units,)
+        # 3: (the number of visible units, the number of hidden units)
+        return ['data', 'visible_layer_bias', 'hidden_layer_bias', 'interaction_weight']
+
+    def list_outputs(self):
+        # 0: The probabilities that each visible unit is 1 after `k` steps of Gibbs sampling starting from the given `data`.
+        #    (batch size, the number of visible units)
+        # 1: The probabilities that each hidden unit is 1 conditional on the given `data`.
+        #    (batch size, the number of hidden units)
+        return ['visible_layer_prob_1', 'hidden_layer_prob_1']
+
+    def list_auxiliary_states(self):
+        # Used only if `self.for_training` is true.
+        # 0: Store the visible layer samples obtained in the forward pass, used in the backward pass.
+        #    (batch size, the number of visible units)
+        # 1: Store the hidden layer samples obtained in the forward pass, used in the backward and next forward pass.
+        #    (batch size, the number of hidden units)
+        return ['aux_visible_layer_sample', 'aux_hidden_layer_sample'] if self.for_training else []
+
+    def infer_shape(self, in_shapes):
+        visible_layer_data_shape = in_shapes[0] # The input data
+        visible_layer_bias_shape = (visible_layer_data_shape[1],)
+        hidden_layer_bias_shape = (self.num_hidden,)
+        interaction_shape = (visible_layer_data_shape[1], self.num_hidden)
+        visible_layer_sample_shape = visible_layer_data_shape
+        visible_layer_prob_1_shape = visible_layer_sample_shape
+        hidden_layer_sample_shape = (visible_layer_data_shape[0], self.num_hidden)
+        hidden_layer_prob_1_shape = hidden_layer_sample_shape
+        return [visible_layer_data_shape, visible_layer_bias_shape, hidden_layer_bias_shape, interaction_shape], \
+               [visible_layer_prob_1_shape, hidden_layer_prob_1_shape], \
+               [visible_layer_sample_shape, hidden_layer_sample_shape] if self.for_training else []
+
+    def infer_type(self, in_type):
+        return [in_type[0], in_type[0], in_type[0], in_type[0]], \
+               [in_type[0], in_type[0]], \
+               [in_type[0], in_type[0]] if self.for_training else []
+
+    def create_operator(self, ctx, in_shapes, in_dtypes):
+        return BinaryRBM(self.k)
+
+# For gluon API
+class BinaryRBMBlock(mx.gluon.HybridBlock):
+
+    def __init__(self, num_hidden, k, for_training, **kwargs):
+        super(BinaryRBMBlock, self).__init__(**kwargs)
+        with self.name_scope():
+            self.num_hidden = num_hidden
+            self.k = k
+            self.for_training = for_training
+            self.visible_layer_bias = self.params.get('visible_layer_bias', shape=(0,), allow_deferred_init=True)
+            self.hidden_layer_bias = self.params.get('hidden_layer_bias', shape=(0,), allow_deferred_init=True)
+            self.interaction_weight = self.params.get('interaction_weight', shape=(0, 0), allow_deferred_init=True)
+            if for_training:
+                self.aux_visible_layer_sample = self.params.get('aux_visible_layer_sample', shape=(0, 0), allow_deferred_init=True)
+                self.aux_hidden_layer_sample = self.params.get('aux_hidden_layer_sample', shape=(0, 0), allow_deferred_init=True)
+
+    def hybrid_forward(self, F, data, visible_layer_bias, hidden_layer_bias, interaction_weight, aux_visible_layer_sample=None, aux_hidden_layer_sample=None):
+        # As long as `for_training` is kept constant, this conditional statement does not prevent hybridization.
+        if self.for_training:
+            return F.Custom(
+                data,
+                visible_layer_bias,
+                hidden_layer_bias,
+                interaction_weight,
+                aux_visible_layer_sample,
+                aux_hidden_layer_sample,
+                num_hidden=self.num_hidden,
+                k=self.k,
+                for_training=self.for_training,
+                op_type='BinaryRBM')
+        else:
+            return F.Custom(
+                data,
+                visible_layer_bias,
+                hidden_layer_bias,
+                interaction_weight,
+                num_hidden=self.num_hidden,
+                k=self.k,
+                for_training=self.for_training,
+                op_type='BinaryRBM')
+
+def estimate_log_likelihood(visible_layer_bias, hidden_layer_bias, interaction_weight, ais_batch_size, ais_num_batch, ais_intermediate_steps, ais_burn_in_steps, data, ctx):
+    # The base-rate RBM has no hidden layer. Its visible layer bias is set to the same value as in the given RBM.
+    # This is not the only possible choice but simple and works well.
+    base_rate_visible_layer_bias = visible_layer_bias
+    base_rate_visible_prob_1 = base_rate_visible_layer_bias.sigmoid()
+    log_base_rate_z = base_rate_visible_layer_bias.exp().log1p().sum()
+
+    def log_intermediate_unnormalized_prob(visible_layer_sample, beta):
+        p = mx.nd.dot(
+                visible_layer_sample,
+                (1 - beta) * base_rate_visible_layer_bias + beta * visible_layer_bias)
+        if beta != 0:
+            p += mx.nd.linalg.gemm(
+                    visible_layer_sample,
+                    interaction_weight,
+                    hidden_layer_bias.tile(reps=(visible_layer_sample.shape[0], 1)),
+                    transpose_b=False,
+                    alpha=beta,
+                    beta=beta).exp().log1p().sum(axis=1)
+        return p
+
+    def sample_base_rbm():
+        rands = mx.nd.random.uniform(shape=(ais_batch_size, base_rate_visible_prob_1.shape[0]), ctx=ctx)
+        return rands < base_rate_visible_prob_1.tile(reps=(ais_batch_size, 1))
+
+    def sample_intermediate_visible_layer(visible_layer_sample, beta):
+        for _ in range(ais_burn_in_steps):
+            hidden_prob_1 = mx.nd.linalg.gemm(
+                visible_layer_sample,
+                interaction_weight,
+                hidden_layer_bias.tile(reps=(visible_layer_sample.shape[0], 1)),
+                transpose_b=False,
+                alpha=beta,
+                beta=beta)
+            hidden_prob_1.sigmoid(out=hidden_prob_1)
+            hidden_layer_sample = mx.nd.random.uniform(shape=hidden_prob_1.shape, ctx=ctx) < hidden_prob_1
+            visible_prob_1 = mx.nd.linalg.gemm(
+                hidden_layer_sample,
+                interaction_weight,
+                visible_layer_bias.tile(reps=(hidden_layer_sample.shape[0], 1)),
+                transpose_b=True,
+                alpha=beta,
+                beta=beta) + (1 - beta) * base_rate_visible_layer_bias
+            visible_prob_1.sigmoid(out=visible_prob_1)
+            visible_layer_sample = mx.nd.random.uniform(shape=visible_prob_1.shape, ctx=ctx) < visible_prob_1
+        return visible_layer_sample
+
+    def array_from_batch(batch):
+        if isinstance(batch, mx.io.DataBatch):
+            return batch.data[0].as_in_context(ctx).flatten()
+        else: # batch is an instance of list in the case of gluon DataLoader
+            return batch[0].as_in_context(ctx).flatten()
+
+    importance_weight_sum = 0
+    num_ais_samples = ais_num_batch * ais_batch_size
+    for _ in range(ais_num_batch):
+        log_importance_weight = 0
+        visible_layer_sample = sample_base_rbm()
+        for n in range(1, ais_intermediate_steps + 1):
+            beta = 1. * n / ais_intermediate_steps
+            log_importance_weight += \
+                log_intermediate_unnormalized_prob(visible_layer_sample, beta) - \
+                log_intermediate_unnormalized_prob(visible_layer_sample, (n - 1.) / ais_intermediate_steps)
+            visible_layer_sample = sample_intermediate_visible_layer(visible_layer_sample, beta)
+        importance_weight_sum += log_importance_weight.exp().sum()
+    log_z = (importance_weight_sum / num_ais_samples).log() + log_base_rate_z
+
+    log_likelihood = 0
+    num_data = 0
+    for batch in data:
+        batch_array = array_from_batch(batch)
+        log_likelihood += log_intermediate_unnormalized_prob(batch_array, 1) - log_z
+        num_data += batch_array.shape[0]
+    log_likelihood = log_likelihood.sum() / num_data
+
+    return log_likelihood.asscalar(), log_z.asscalar()
diff --git a/example/restricted-boltzmann-machine/binary_rbm_gluon.py b/example/restricted-boltzmann-machine/binary_rbm_gluon.py
new file mode 100644
index 0000000..cdce2e6
--- /dev/null
+++ b/example/restricted-boltzmann-machine/binary_rbm_gluon.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import random as pyrnd
+import argparse
+import numpy as np
+import mxnet as mx
+from matplotlib import pyplot as plt
+from binary_rbm import BinaryRBMBlock
+from binary_rbm import estimate_log_likelihood
+
+
+### Helper function
+
+def get_non_auxiliary_params(rbm):
+    return rbm.collect_params('^(?!.*_aux_.*).*$')
+
+### Command line arguments
+
+parser = argparse.ArgumentParser(description='Restricted Boltzmann machine learning MNIST')
+parser.add_argument('--num-hidden', type=int, default=500, help='number of hidden units')
+parser.add_argument('--k', type=int, default=30, help='number of Gibbs sampling steps used in the PCD algorithm')
+parser.add_argument('--batch-size', type=int, default=80, help='batch size')
+parser.add_argument('--num-epoch', type=int, default=130, help='number of epochs')
+parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate for stochastic gradient descent') # The optimizer rescales this with `1 / batch_size`
+parser.add_argument('--momentum', type=float, default=0.3, help='momentum for the stochastic gradient descent')
+parser.add_argument('--ais-batch-size', type=int, default=100, help='batch size for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-num-batch', type=int, default=10, help='number of batches for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-intermediate-steps', type=int, default=10, help='number of intermediate distributions for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-burn-in-steps', type=int, default=10, help='number of burn-in steps for each intermediate distribution of AIS to estimate the log-likelihood')
+parser.add_argument('--cuda', action='store_true', dest='cuda', help='train on GPU with CUDA')
+parser.add_argument('--no-cuda', action='store_false', dest='cuda', help='train on CPU')
+parser.add_argument('--device-id', type=int, default=0, help='GPU device id')
+parser.add_argument('--data-loader-num-worker', type=int, default=4, help='number of worker processes for the data loader')
+parser.set_defaults(cuda=True)
+
+args = parser.parse_args()
+print(args)
+
+### Global environment
+
+mx.random.seed(pyrnd.getrandbits(32))
+ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu()
+
+
+### Prepare data
+
+def data_transform(data, label):
+    return data.astype(np.float32) / 255, label.astype(np.float32)
+
+mnist_train_dataset = mx.gluon.data.vision.MNIST(train=True, transform=data_transform)
+mnist_test_dataset = mx.gluon.data.vision.MNIST(train=False, transform=data_transform)
+img_height = mnist_train_dataset[0][0].shape[0]
+img_width = mnist_train_dataset[0][0].shape[1]
+num_visible = img_width * img_height
+
+# This generates arrays with shape (batch_size, height = 28, width = 28, num_channel = 1)
+train_data = mx.gluon.data.DataLoader(mnist_train_dataset, args.batch_size, shuffle=True, num_workers=args.data_loader_num_worker)
+test_data = mx.gluon.data.DataLoader(mnist_test_dataset, args.batch_size, shuffle=True, num_workers=args.data_loader_num_worker)
+
+### Train
+
+rbm = BinaryRBMBlock(num_hidden=args.num_hidden, k=args.k, for_training=True, prefix='rbm_')
+rbm.initialize(mx.init.Normal(sigma=.01), ctx=ctx)
+rbm.hybridize()
+trainer = mx.gluon.Trainer(
+    get_non_auxiliary_params(rbm),
+    'sgd', {'learning_rate': args.learning_rate, 'momentum': args.momentum})
+for epoch in range(args.num_epoch):
+    # Update parameters
+    for batch, _ in train_data:
+        batch = batch.as_in_context(ctx).flatten()
+        with mx.autograd.record():
+            out = rbm(batch)
+        out[0].backward()
+        trainer.step(batch.shape[0])
+    mx.nd.waitall() # To restrict memory usage
+
+    # Monitor the performance of the model
+    params = get_non_auxiliary_params(rbm)
+    param_visible_layer_bias = params['rbm_visible_layer_bias'].data(ctx=ctx)
+    param_hidden_layer_bias = params['rbm_hidden_layer_bias'].data(ctx=ctx)
+    param_interaction_weight = params['rbm_interaction_weight'].data(ctx=ctx)
+    test_log_likelihood, _ = estimate_log_likelihood(
+            param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
+            args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, test_data, ctx)
+    train_log_likelihood, _ = estimate_log_likelihood(
+            param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
+            args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, train_data, ctx)
+    print("Epoch %d completed with test log-likelihood %f and train log-likelihood %f" % (epoch, test_log_likelihood, train_log_likelihood))
+
+
+### Show some samples.
+
+# Each sample is obtained by 3000 steps of Gibbs sampling starting from a real sample.
+# Starting from the real data is just for convenience of implementation.
+# There must be no correlation between the initial states and the resulting samples.
+# You can start from random states and run the Gibbs chain for a sufficiently long time.
+
+print("Preparing showcase")
+
+showcase_gibbs_sampling_steps = 3000
+showcase_num_samples_w = 15
+showcase_num_samples_h = 15
+showcase_num_samples = showcase_num_samples_w * showcase_num_samples_h
+showcase_img_shape = (showcase_num_samples_h * img_height, 2 * showcase_num_samples_w * img_width)
+showcase_img_column_shape = (showcase_num_samples_h * img_height, img_width)
+
+showcase_rbm = BinaryRBMBlock(
+    num_hidden=args.num_hidden,
+    k=showcase_gibbs_sampling_steps,
+    for_training=False,
+    params=get_non_auxiliary_params(rbm))
+showcase_iter = iter(mx.gluon.data.DataLoader(mnist_train_dataset, showcase_num_samples_h, shuffle=True))
+showcase_img = np.zeros(showcase_img_shape)
+for i in range(showcase_num_samples_w):
+    data_batch = next(showcase_iter)[0].as_in_context(ctx).flatten()
+    sample_batch = showcase_rbm(data_batch)
+    # Each pixel is the probability that the unit is 1.
+    showcase_img[:, i * img_width : (i + 1) * img_width] = data_batch.reshape(showcase_img_column_shape).asnumpy()
+    showcase_img[:, (showcase_num_samples_w + i) * img_width : (showcase_num_samples_w + i + 1) * img_width
+                ] = sample_batch[0].reshape(showcase_img_column_shape).asnumpy()
+s = plt.imshow(showcase_img, cmap='gray')
+plt.axis('off')
+plt.axvline(showcase_num_samples_w * img_width, color='y')
+plt.show()
+
+print("Done")
\ No newline at end of file
diff --git a/example/restricted-boltzmann-machine/binary_rbm_module.py b/example/restricted-boltzmann-machine/binary_rbm_module.py
new file mode 100644
index 0000000..e1a3653b
--- /dev/null
+++ b/example/restricted-boltzmann-machine/binary_rbm_module.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import random as pyrnd
+import argparse
+import numpy as np
+import mxnet as mx
+from matplotlib import pyplot as plt
+import binary_rbm
+
+
+### Command line arguments
+
+parser = argparse.ArgumentParser(description='Restricted Boltzmann machine learning MNIST')
+parser.add_argument('--num-hidden', type=int, default=500, help='number of hidden units')
+parser.add_argument('--k', type=int, default=30, help='number of Gibbs sampling steps used in the PCD algorithm')
+parser.add_argument('--batch-size', type=int, default=80, help='batch size')
+parser.add_argument('--num-epoch', type=int, default=130, help='number of epochs')
+parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate for stochastic gradient descent') # The optimizer rescales this with `1 / batch_size`
+parser.add_argument('--momentum', type=float, default=0.3, help='momentum for the stochastic gradient descent')
+parser.add_argument('--ais-batch-size', type=int, default=100, help='batch size for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-num-batch', type=int, default=10, help='number of batches for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-intermediate-steps', type=int, default=10, help='number of intermediate distributions for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-burn-in-steps', type=int, default=10, help='number of burn-in steps for each intermediate distribution of AIS to estimate the log-likelihood')
+parser.add_argument('--cuda', action='store_true', dest='cuda', help='train on GPU with CUDA')
+parser.add_argument('--no-cuda', action='store_false', dest='cuda', help='train on CPU')
+parser.add_argument('--device-id', type=int, default=0, help='GPU device id')
+parser.set_defaults(cuda=True)
+
+args = parser.parse_args()
+print(args)
+
+### Global environment
+
+mx.random.seed(pyrnd.getrandbits(32))
+ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu()
+
+### Prepare data
+
+mnist = mx.test_utils.get_mnist() # Each pixel has a value in [0, 1].
+mnist_train_data = mnist['train_data']
+mnist_test_data = mnist['test_data']
+img_height = mnist_train_data.shape[2]
+img_width = mnist_train_data.shape[3]
+num_visible = img_width * img_height
+
+# The iterators generate arrays with shape (batch_size, num_channel = 1, height = 28, width = 28)
+train_iter = mx.io.NDArrayIter(
+    data={'data': mnist_train_data},
+    batch_size=args.batch_size,
+    shuffle=True)
+test_iter = mx.io.NDArrayIter(
+    data={'data': mnist_test_data},
+    batch_size=args.batch_size,
+    shuffle=True)
+
+
+### Define symbols
+
+data = mx.sym.Variable('data') # (batch_size, num_channel = 1, height, width)
+flattened_data = mx.sym.flatten(data=data) # (batch_size, num_channel * height * width)
+visible_layer_bias = mx.sym.Variable('visible_layer_bias', init=mx.init.Normal(sigma=.01))
+hidden_layer_bias = mx.sym.Variable('hidden_layer_bias', init=mx.init.Normal(sigma=.01))
+interaction_weight = mx.sym.Variable('interaction_weight', init=mx.init.Normal(sigma=.01))
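+aux_hidden_layer_sample = mx.sym.Variable('aux_hidden_layer_sample', init=mx.init.Normal(sigma=.01)) # Passed below as the first auxiliary state, which stores the visible layer samples
+aux_hidden_layer_prob_1 = mx.sym.Variable('aux_hidden_layer_prob_1', init=mx.init.Constant(0)) # Passed below as the second auxiliary state, which stores the hidden layer samples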
+
+
+### Train
+
+rbm = mx.sym.Custom(
+    flattened_data,
+    visible_layer_bias,
+    hidden_layer_bias,
+    interaction_weight,
+    aux_hidden_layer_sample,
+    aux_hidden_layer_prob_1,
+    num_hidden=args.num_hidden,
+    k=args.k,
+    for_training=True,
+    op_type='BinaryRBM',
+    name='rbm')
+model = mx.mod.Module(symbol=rbm, context=ctx, data_names=['data'], label_names=None)
+model.bind(data_shapes=train_iter.provide_data)
+model.init_params()
+model.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': args.learning_rate, 'momentum': args.momentum})
+
+for epoch in range(args.num_epoch):
+    # Update parameters
+    train_iter.reset()
+    for batch in train_iter:
+        model.forward(batch)
+        model.backward()
+        model.update()
+    mx.nd.waitall()
+
+    # Monitor the performance of the model
+    params = model.get_params()[0]
+    param_visible_layer_bias = params['visible_layer_bias'].as_in_context(ctx)
+    param_hidden_layer_bias = params['hidden_layer_bias'].as_in_context(ctx)
+    param_interaction_weight = params['interaction_weight'].as_in_context(ctx)
+    test_iter.reset()
+    test_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
+        param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
+        args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, test_iter, ctx)
+    train_iter.reset()
+    train_log_likelihood, _ = binary_rbm.estimate_log_likelihood(
+        param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
+        args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, train_iter, ctx)
+    print("Epoch %d completed with test log-likelihood %f and train log-likelihood %f" % (epoch, test_log_likelihood, train_log_likelihood))
+
+### Show some samples.
+
+# Each sample is obtained by 3000 steps of Gibbs sampling starting from a real sample.
+# Starting from the real data is just for convenience of implementation.
+# There must be no correlation between the initial states and the resulting samples.
+# You can start from random states and run the Gibbs chain for a sufficiently long time.
+
+print("Preparing showcase")
+
+showcase_gibbs_sampling_steps = 3000
+showcase_num_samples_w = 15
+showcase_num_samples_h = 15
+showcase_num_samples = showcase_num_samples_w * showcase_num_samples_h
+showcase_img_shape = (showcase_num_samples_h * img_height, 2 * showcase_num_samples_w * img_width)
+showcase_img_column_shape = (showcase_num_samples_h * img_height, img_width)
+
+params = model.get_params()[0] # We don't need aux states here
+showcase_rbm = mx.sym.Custom(
+    flattened_data,
+    visible_layer_bias,
+    hidden_layer_bias,
+    interaction_weight,
+    num_hidden=args.num_hidden,
+    k=showcase_gibbs_sampling_steps,
+    for_training=False,
+    op_type='BinaryRBM',
+    name='showcase_rbm')
+showcase_iter = mx.io.NDArrayIter(
+    data={'data': mnist['train_data']},
+    batch_size=showcase_num_samples_h,
+    shuffle=True)
+showcase_model = mx.mod.Module(symbol=showcase_rbm, context=ctx, data_names=['data'], label_names=None)
+showcase_model.bind(data_shapes=showcase_iter.provide_data, for_training=False)
+showcase_model.set_params(params, aux_params=None)
+showcase_img = np.zeros(showcase_img_shape)
+for sample_batch, i, data_batch in showcase_model.iter_predict(eval_data=showcase_iter, num_batch=showcase_num_samples_w):
+    # Each pixel is the probability that the unit is 1.
+    showcase_img[:, i * img_width : (i + 1) * img_width] = data_batch.data[0].reshape(showcase_img_column_shape).asnumpy()
+    showcase_img[:, (showcase_num_samples_w + i) * img_width : (showcase_num_samples_w + i + 1) * img_width
+                ] = sample_batch[0].reshape(showcase_img_column_shape).asnumpy()
+s = plt.imshow(showcase_img, cmap='gray')
+plt.axis('off')
+plt.axvline(showcase_num_samples_w * img_width, color='y')
+plt.show()
+
+print("Done")
diff --git a/example/restricted-boltzmann-machine/samples.png b/example/restricted-boltzmann-machine/samples.png
new file mode 100644
index 0000000..b266f8e
Binary files /dev/null and b/example/restricted-boltzmann-machine/samples.png differ
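
The `estimate_log_likelihood` function in `binary_rbm.py` above estimates the partition function with AIS because the exact sum over all visible states is intractable for MNIST-sized models. For a very small RBM the sum can be enumerated directly, which gives a reference value for sanity-checking an estimator. The following is a hedged NumPy sketch of that check; the function names are hypothetical, and brute-force enumeration is swapped in for AIS here, so it is not part of the commit.

```python
import itertools
import numpy as np

def log_unnormalized_prob(visible, visible_bias, hidden_bias, weight):
    # log p*(v) = v.a + sum_j log(1 + exp(b_j + (vW)_j)), with the hidden units summed out.
    # This mirrors log_intermediate_unnormalized_prob at beta = 1.
    return visible @ visible_bias + np.logaddexp(0.0, visible @ weight + hidden_bias).sum(axis=1)

def exact_log_z(visible_bias, hidden_bias, weight):
    # Enumerate all 2**num_visible visible states; only feasible for tiny models.
    num_visible = visible_bias.shape[0]
    states = np.array(list(itertools.product([0.0, 1.0], repeat=num_visible)))
    return np.logaddexp.reduce(log_unnormalized_prob(states, visible_bias, hidden_bias, weight))

def mean_log_likelihood(data, visible_bias, hidden_bias, weight):
    log_z = exact_log_z(visible_bias, hidden_bias, weight)
    return log_unnormalized_prob(data, visible_bias, hidden_bias, weight).mean() - log_z
```

With an empty hidden layer, `exact_log_z` reduces to the sum of `log(1 + exp(a_i))` over the visible biases, which is the closed form used for `log_base_rate_z` in the commit.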