asitstands commented on a change in pull request #11268: A binary RBM example URL: https://github.com/apache/incubator-mxnet/pull/11268#discussion_r209496649
########## File path: example/restricted-boltzmann-machine/binary_rbm_gluon.py ########## @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import random as pyrnd +import argparse +import numpy as np +import mxnet as mx +from matplotlib import pyplot as plt +from binary_rbm import BinaryRBMBlock +from binary_rbm import estimate_log_likelihood + + +### Helper function + +def get_non_auxiliary_params(rbm): + return rbm.collect_params('^(?!.*_aux_.*).*$') + +### Command line arguments + +parser = argparse.ArgumentParser(description='Restricted Boltzmann machine learning MNIST') +parser.add_argument('--num-hidden', type=int, default=500, help='number of hidden units') +parser.add_argument('--k', type=int, default=30, help='number of Gibbs sampling steps used in the PCD algorithm') +parser.add_argument('--batch-size', type=int, default=80, help='batch size') +parser.add_argument('--num-epoch', type=int, default=130, help='number of epochs') +parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate for stochastic gradient descent') # The optimizer rescales this with `1 / batch_size` +parser.add_argument('--momentum', type=float, default=0.3, help='momentum for the stochastic gradient descent') 
+parser.add_argument('--ais-batch-size', type=int, default=100, help='batch size for AIS to estimate the log-likelihood') +parser.add_argument('--ais-num-batch', type=int, default=10, help='number of batches for AIS to estimate the log-likelihood') +parser.add_argument('--ais-intermediate-steps', type=int, default=10, help='number of intermediate distributions for AIS to estimate the log-likelihood') +parser.add_argument('--ais-burn-in-steps', type=int, default=10, help='number of burn in steps for each intermediate distributions of AIS to estimate the log-likelihood') +parser.add_argument('--cuda', action='store_true', dest='cuda', help='train on GPU with CUDA') +parser.add_argument('--no-cuda', action='store_false', dest='cuda', help='train on CPU') +parser.add_argument('--device-id', type=int, default=0, help='GPU device id') +parser.add_argument('--data-loader-num-worker', type=int, default=4, help='number of multithreading workers for the data loader') +parser.set_defaults(cuda=True) + +args = parser.parse_args() +print(args) + +### Global environment + +mx.random.seed(pyrnd.getrandbits(32)) +ctx = mx.gpu(args.device_id) if args.cuda else mx.cpu() + + +### Prepare data + +def data_transform(data, label): + return data.astype(np.float32) / 255, label.astype(np.float32) + +mnist_train_dataset = mx.gluon.data.vision.MNIST(train=True, transform=data_transform) +mnist_test_dataset = mx.gluon.data.vision.MNIST(train=False, transform=data_transform) +img_height = mnist_train_dataset[0][0].shape[0] +img_width = mnist_train_dataset[0][0].shape[1] +num_visible = img_width * img_height + +# This generates arrays with shape (batch_size, height = 28, width = 28, num_channel = 1) +train_data = mx.gluon.data.DataLoader(mnist_train_dataset, args.batch_size, shuffle=True, num_workers=args.data_loader_num_worker) +test_data = mx.gluon.data.DataLoader(mnist_test_dataset, args.batch_size, shuffle=True, num_workers=args.data_loader_num_worker) + +### Train + +rbm = 
BinaryRBMBlock(num_hidden=args.num_hidden, k=args.k, for_training=True, prefix='rbm_') +rbm.initialize(mx.init.Normal(sigma=.01), ctx=ctx) +rbm.hybridize() +trainer = mx.gluon.Trainer( + get_non_auxiliary_params(rbm), + 'sgd', {'learning_rate': args.learning_rate, 'momentum': args.momentum}) +for epoch in range(args.num_epoch): + # Update parameters + for batch, _ in train_data: + batch = batch.as_in_context(ctx).flatten() + with mx.autograd.record(): + out = rbm(batch) + out[0].backward() + trainer.step(batch.shape[0]) + mx.nd.waitall() # To restrict memory usage + + # Monitor the performace of the model + params = get_non_auxiliary_params(rbm) + param_visible_layer_bias = params['rbm_visible_layer_bias'].data(ctx=ctx) + param_hidden_layer_bias = params['rbm_hidden_layer_bias'].data(ctx=ctx) + param_interaction_weight = params['rbm_interaction_weight'].data(ctx=ctx) + test_log_likelihood, _ = estimate_log_likelihood( + param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight, + args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, test_data, ctx) + train_log_likelihood, _ = estimate_log_likelihood( + param_visible_layer_bias, param_hidden_layer_bias, param_interaction_weight, + args.ais_batch_size, args.ais_num_batch, args.ais_intermediate_steps, args.ais_burn_in_steps, train_data, ctx) + print("Epoch %d completed with test log-likelihood %f and train log-likelihood %f" % (epoch, test_log_likelihood, train_log_likelihood)) + + +### Show some samples. + +# Each sample is obtained by 3000 steps of Gibbs sampling starting from a real sample. +# Starting from the real data is just for convenience of implmentation. +# There must be no correlation between the initial states and the resulting samples. Review comment: I'm not sure how general the result in the paper you mentioned is across various generative models (though the paper itself looks interesting).
I think a trained RBM is unlikely to define a reducible Gibbs chain. If it did, training an RBM with Gibbs sampling would be a very difficult problem. As we increase the running time of the Gibbs sampling during training, the quality of the resulting model improves. This is a common observation, but it would be unexpected if the chain were reducible. Thus the Gibbs chain is likely irreducible, or has only a very limited number of isolated states. In that case, the initial state and the steady state of the chain have no direct relation. If they appear related, that means the chain has not been run long enough. The relaxation to the steady state is exponential except in some special cases. To reach the steady state, the chain does not need to visit every possible image with uniform probability, nor does it need to visit a large fraction of all possible images. Although the relaxation is exponential, the constant factor can matter in practice, so we start the chain from real images. Starting from real images is just a trick to shorten the long relaxation time from a random initial state. It only affects the constant factor; it cannot change the exponential nature of the relaxation. If samples from a *long* Gibbs chain started from random images include images far from digits, such images must eventually appear in samples from chains started from real images as well. So the only question is how long we can run the chain. By the same argument, I think comparing the initial real images with the resulting samples is not appropriate either; it could be misleading. To judge the quality of the trained RBM, a more appropriate approach is to compare some random samples from the test dataset with samples generated by the model. By the way, many papers show good RBM samples generated from random initial states. For example, the AIS paper you linked also shows RBM samples generated from randomly initialized MNIST images.
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
