You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/13 04:36:13 UTC

[GitHub] asitstands commented on a change in pull request #11268: A binary RBM example

asitstands commented on a change in pull request #11268: A binary RBM example
URL: https://github.com/apache/incubator-mxnet/pull/11268#discussion_r209490146
 
 

 ##########
 File path: example/restricted-boltzmann-machine/binary_rbm_gluon.py
 ##########
 @@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import random as pyrnd
+import argparse
+import numpy as np
+import mxnet as mx
+from matplotlib import pyplot as plt
+from binary_rbm import BinaryRBMBlock
+from binary_rbm import estimate_log_likelihood
+
+
+### Helper function
+
+def get_non_auxiliary_params(rbm):
+    return rbm.collect_params('^(?!.*_aux_.*).*$')
+
+### Command line arguments
+
+parser = argparse.ArgumentParser(description='Restricted Boltzmann machine learning MNIST')
+parser.add_argument('--num-hidden', type=int, default=500, help='number of hidden units')
+parser.add_argument('--k', type=int, default=30, help='number of Gibbs sampling steps used in the PCD algorithm')
+parser.add_argument('--batch-size', type=int, default=80, help='batch size')
+parser.add_argument('--num-epoch', type=int, default=130, help='number of epochs')
+parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate for stochastic gradient descent') # The optimizer rescales this with `1 / batch_size`
+parser.add_argument('--momentum', type=float, default=0.3, help='momentum for the stochastic gradient descent')
+parser.add_argument('--ais-batch-size', type=int, default=100, help='batch size for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-num-batch', type=int, default=10, help='number of batches for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-intermediate-steps', type=int, default=10, help='number of intermediate distributions for AIS to estimate the log-likelihood')
+parser.add_argument('--ais-burn-in-steps', type=int, default=10, help='number of burn in steps for each intermediate distributions of AIS to estimate the log-likelihood')
+parser.add_argument('--cuda', action='store_true', dest='cuda', help='train on GPU with CUDA')
+parser.add_argument('--no-cuda', action='store_false', dest='cuda', help='train on CPU')
+parser.add_argument('--device-id', type=int, default=0, help='GPU device id')
+parser.add_argument('--data-loader-num-worker', type=int, default=4, help='number of multithreading workers for the data loader')
+parser.set_defaults(cuda=True)
+
+args = parser.parse_args()
+print(args)
+
+### Global environment
+
+mx.random.seed(pyrnd.getrandbits(32))
+ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu()
+
+
+### Prepare data
+
+def data_transform(data, label):
+    return data.astype(np.float32) / 255, label.astype(np.float32)
+
+mnist_train_dataset = mx.gluon.data.vision.MNIST(train=True, transform=data_transform)
+mnist_test_dataset = mx.gluon.data.vision.MNIST(train=False, transform=data_transform)
+img_height = mnist_train_dataset[0][0].shape[0]
+img_width = mnist_train_dataset[0][0].shape[1]
+num_visible = img_width * img_height
+
+# This generates arrays with shape (batch_size, height = 28, width = 28, num_channel = 1)
+train_data = mx.gluon.data.DataLoader(mnist_train_dataset, args.batch_size, shuffle=True, num_workers=args.data_loader_num_worker)
+test_data = mx.gluon.data.DataLoader(mnist_test_dataset, args.batch_size, shuffle=True, num_workers=args.data_loader_num_worker)
+
+### Train
+
+rbm = BinaryRBMBlock(num_hidden=args.num_hidden, k=args.k, for_training=True, prefix='rbm_')
+rbm.initialize(mx.init.Normal(sigma=.01), ctx=ctx)
+rbm.hybridize()
+trainer = mx.gluon.Trainer(
+    get_non_auxiliary_params(rbm),
+    'sgd', {'learning_rate': args.learning_rate, 'momentum': args.momentum})
+for epoch in range(args.num_epoch):
+    # Update parameters
+    for batch, _ in train_data:
+        batch = batch.as_in_context(ctx).flatten()
+        with mx.autograd.record():
+            out = rbm(batch)
+        out[0].backward()
+        trainer.step(batch.shape[0])
+    mx.nd.waitall() # To restrict memory usage
+
+    # Monitor the performance of the model
+    params = get_non_auxiliary_params(rbm)
+    param_visible_layer_bias = params['rbm_visible_layer_bias'].data(ctx=ctx)
+    param_hidden_layer_bias = params['rbm_hidden_layer_bias'].data(ctx=ctx)
+    param_interaction_weight = params['rbm_interaction_weight'].data(ctx=ctx)
+    test_log_likelihood, _ = estimate_log_likelihood(
+            param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
+            args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, test_data, ctx)
+    train_log_likelihood, _ = estimate_log_likelihood(
+            param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight,
+            args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, train_data, ctx)
+    print("Epoch %d completed with test log-likelihood %f and train log-likelihood %f" % (epoch, test_log_likelihood, train_log_likelihood))
+
+
+### Show some samples.
+
+# Each sample is obtained by 3000 steps of Gibbs sampling starting from a real sample.
+# Starting from the real data is just for convenience of implementation.
+# There must be no correlation between the initial states and the resulting samples.
 
 Review comment:
   Why irrelevant? I chose to start from real images instead of random images, but starting from random images is also a common practice. Mentioning an alternative choice would be helpful here. No correlation between the initial state and the steady state is also a general property of a Markov chain.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services