vishaalkapoor commented on a change in pull request #13144: [MXNET-1203] Tutorial infogan
URL: https://github.com/apache/incubator-mxnet/pull/13144#discussion_r231632468
########## File path: docs/tutorials/gluon/info_gan.md ##########
@@ -0,0 +1,438 @@
+
+# Image similarity search with InfoGAN
+
+This notebook shows how to implement an InfoGAN based on Gluon. InfoGAN is an extension of GANs, where the generator input is split in 2 parts: random noise and a latent code c (see [InfoGAN Paper](https://arxiv.org/pdf/1606.03657.pdf)).
+The codes are made meaningful by maximizing the mutual information between code and generator output. InfoGAN learns a disentangled representation in a completely unsupervised manner. It can be used for many applications such as image similarity search. This notebook uses the DCGAN example from the [Straight Dope Book](https://gluon.mxnet.io/chapter14_generative-adversarial-networks/dcgan.html) and extends it to create an InfoGAN.
+
+
+```python
+from __future__ import print_function
+from datetime import datetime
+import sys
+import os
+import logging
+import time
+import tarfile
+
+from matplotlib import pyplot as plt
+import mxnet as mx
+from mxnet import gluon
+from mxnet import ndarray as nd
+from mxnet.gluon import nn, utils
+from mxnet import autograd
+from mxboard import SummaryWriter
+import numpy as np
+
+```
+
+The latent code vector c can contain several variables, which can be categorical and/or continuous. We set `n_continuous` to 2 and `n_categories` to 10.
+
+
+```python
+batch_size = 64
+z_dim = 100
+n_continuous = 2
+n_categories = 10
+ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
+```
+
+Some functions to load and normalize images.
+
+
+```python
+lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'
+data_path = 'lfw_dataset'
+if not os.path.exists(data_path):
+    os.makedirs(data_path)
+    data_file = utils.download(lfw_url)
+    with tarfile.open(data_file) as tar:
+        tar.extractall(path=data_path)
+
+```
+
+
+```python
+def transform(data, width=64, height=64):
+    data = mx.image.imresize(data, width, height)
+    data = nd.transpose(data, (2,0,1))
+    data = data.astype(np.float32)/127.5 - 1
+    if data.shape[0] == 1:
+        data = nd.tile(data, (3, 1, 1))
+    return data.reshape((1,) + data.shape)
+```
+
+
+```python
+def get_files(data_dir):
+    images = []
+    filenames = []
+    for path, _, fnames in os.walk(data_dir):
+        for fname in fnames:
+            if not fname.endswith('.jpg'):
+                continue
+            img = os.path.join(path, fname)
+            img_arr = mx.image.imread(img)
+            img_arr = transform(img_arr)
+            images.append(img_arr)
+            filenames.append(path + "/" + fname)

Review comment:
nit '/'
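The nit most likely refers to the hard-coded `'/'` separator in `path + "/" + fname`. Below is a minimal sketch of the portable alternative, assuming the intent is to let `os.path.join` choose the platform separator; `list_jpegs` is a hypothetical helper that mirrors only the file-walking part of `get_files`. Since `img` already holds `os.path.join(path, fname)` in the diff, `filenames.append(img)` would be an equivalent one-line fix.

```python
import os


def list_jpegs(data_dir):
    """Hypothetical helper: collect .jpg paths under data_dir without a hard-coded '/'."""
    filenames = []
    for path, _, fnames in os.walk(data_dir):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            # os.path.join inserts the platform's separator (e.g. '\\' on Windows)
            filenames.append(os.path.join(path, fname))
    return filenames
```

Reusing the already-joined path also keeps the stored filename identical to the path actually passed to `mx.image.imread`.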
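The diff describes the generator input as random noise plus a latent code c made of `n_categories` categorical and `n_continuous` continuous variables. A minimal NumPy sketch of sampling one such batch follows, assuming z ~ N(0, 1), a uniformly drawn one-hot categorical code, and continuous codes drawn from U(-1, 1) (the choice the InfoGAN paper reports for its MNIST experiments); `sample_generator_input` is hypothetical and not part of the tutorial, which may use different distributions.

```python
import numpy as np

# Sizes taken from the tutorial diff
batch_size = 64
z_dim = 100
n_continuous = 2
n_categories = 10


def sample_generator_input(rng, batch_size=batch_size):
    """Hypothetical helper: build one batch of InfoGAN generator inputs."""
    z = rng.standard_normal((batch_size, z_dim))                       # random noise
    labels = rng.integers(0, n_categories, size=batch_size)            # categorical code indices
    c_cat = np.eye(n_categories)[labels]                               # one-hot categorical code
    c_cont = rng.uniform(-1.0, 1.0, size=(batch_size, n_continuous))   # continuous codes
    # Generator input = noise concatenated with the latent code c
    g_input = np.concatenate([z, c_cat, c_cont], axis=1)
    return g_input, labels, c_cont


rng = np.random.default_rng(0)
g_input, labels, c_cont = sample_generator_input(rng)
print(g_input.shape)  # (64, 112)
```

Returning `labels` and `c_cont` alongside the input keeps the sampled codes available as targets for the mutual-information term.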
