sandeep-krishnamurthy closed pull request #12790: Extending the DCGAN example 
implemented by gluon API to provide a more straight-forward evaluation on the 
generated image
URL: https://github.com/apache/incubator-mxnet/pull/12790
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/example/gluon/DCGAN/README.md b/example/gluon/DCGAN/README.md
new file mode 100644
index 00000000000..5aacd78a3ed
--- /dev/null
+++ b/example/gluon/DCGAN/README.md
@@ -0,0 +1,52 @@
+# DCGAN in MXNet
+
+[Deep Convolutional Generative Adversarial 
Networks(DCGAN)](https://arxiv.org/abs/1511.06434) implementation with Apache 
MXNet GLUON.
+This implementation uses 
[inception_score](https://github.com/openai/improved-gan) to evaluate the model.
+
+You can use this reference implementation on the MNIST and CIFAR-10 datasets.
+
+
+#### Generated image output examples from the CIFAR-10 dataset
+![Generated image output examples from the CIFAR-10 
dataset](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/fake_img_iter_13900.png)
+
+#### Generated image output examples from the MNIST dataset
+![Generated image output examples from the MNIST 
dataset](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/fake_img_iter_21700.png)
+
+#### inception_score on CPU and GPU (the real image's score is around 3.3)
+CPU & GPU
+
+![inception score with 
CPU](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/inception_score_cifar10_cpu.png)
+![inception score with 
GPU](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/example/gluon/DCGAN/inception_score_cifar10.png)
+
+## Quick start
+Use the following code to see the configurations you can set:
+```bash
+python dcgan.py -h
+```
+    
+
+    optional arguments:
+      -h, --help            show this help message and exit
+      --dataset DATASET     dataset to use. options are cifar10 and mnist.
+      --batch-size BATCH_SIZE  input batch size, default is 64
+      --nz NZ               size of the latent z vector, default is 100
+      --ngf NGF             the channel of each generator filter layer, 
default is 64.
+      --ndf NDF             the channel of each discriminator filter layer, 
default is 64.
+      --nepoch NEPOCH       number of epochs to train for, default is 25.
+      --niter NITER         save generated images and inception_score per 
niter iters, default is 100.
+      --lr LR               learning rate, default=0.0002
+      --beta1 BETA1         beta1 for adam. default=0.5
+      --cuda                enables cuda
+      --netG NETG           path to netG (to continue training)
+      --netD NETD           path to netD (to continue training)
+      --outf OUTF           folder to output images and model checkpoints
+      --check-point CHECK_POINT
+                            save results at each epoch or not
+      --inception_score INCEPTION_SCORE
+                            To record the inception_score, default is True.
+
+
+Use the following Python script to train a DCGAN model with default 
configurations using the CIFAR-10 dataset and record metrics with 
`inception_score`:
+```bash
+python dcgan.py
+```
diff --git a/example/gluon/DCGAN/__init__.py b/example/gluon/DCGAN/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/example/gluon/DCGAN/dcgan.py b/example/gluon/DCGAN/dcgan.py
new file mode 100644
index 00000000000..970c35d54df
--- /dev/null
+++ b/example/gluon/DCGAN/dcgan.py
@@ -0,0 +1,340 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import matplotlib as mpl
+mpl.use('Agg')
+from matplotlib import pyplot as plt
+
+import argparse
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet import autograd
+import numpy as np
+import logging
+from datetime import datetime
+import os
+import time
+
+from inception_score import get_inception_score
+
+
+def fill_buf(buf, i, img, shape):
+    """
+    Reposition the images generated by the generator so that they can be saved 
as a picture matrix.
+    :param buf: the image buffer that the picture matrix is written into
+    :param i: index of each image
+    :param img: one image generated by the generator
+    :param shape: each image's shape
+    :return: None; the repositioned image is written into buf
+    """
+    n = buf.shape[0]//shape[1]
+    m = buf.shape[1]//shape[0]
+
+    sx = (i%m)*shape[0]
+    sy = (i//m)*shape[1]
+    buf[sy:sy+shape[1], sx:sx+shape[0], :] = img
+    return None
+
+
+def visual(title, X, name):
+    """
+    Visualize the images and save them as a picture
+    :param title: title of the saved picture
+    :param X: images to be visualized
+    :param name: saved picture's name
+    :return:
+    """
+    assert len(X.shape) == 4
+    X = X.transpose((0, 2, 3, 1))
+    X = np.clip((X - np.min(X))*(255.0/(np.max(X) - np.min(X))), 0, 
255).astype(np.uint8)
+    n = np.ceil(np.sqrt(X.shape[0]))
+    buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), 
dtype=np.uint8)
+    for i, img in enumerate(X):
+        fill_buf(buff, i, img, X.shape[1:3])
+    buff = buff[:, :, ::-1]
+    plt.imshow(buff)
+    plt.title(title)
+    plt.savefig(name)
+
+
+parser = argparse.ArgumentParser()
+parser = argparse.ArgumentParser(description='Train a DCgan model for image 
generation '
+                                             'and then use inception_score to 
metric the result.')
+parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to 
use. options are cifar10 and mnist.')
+parser.add_argument('--batch-size', type=int, default=64, help='input batch 
size, default is 64')
+parser.add_argument('--nz', type=int, default=100, help='size of the latent z 
vector, default is 100')
+parser.add_argument('--ngf', type=int, default=64, help='the channel of each 
generator filter layer, default is 64.')
+parser.add_argument('--ndf', type=int, default=64, help='the channel of each 
discriminator filter layer, default is 64.')
+parser.add_argument('--nepoch', type=int, default=25, help='number of epochs 
to train for, default is 25.')
+parser.add_argument('--niter', type=int, default=10, help='save generated 
images and inception_score per niter iters, default is 100.')
+parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, 
default=0.0002')
+parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5')
+parser.add_argument('--cuda', action='store_true', help='enables cuda')
+parser.add_argument('--netG', default='', help="path to netG (to continue 
training)")
+parser.add_argument('--netD', default='', help="path to netD (to continue 
training)")
+parser.add_argument('--outf', default='./results', help='folder to output 
images and model checkpoints')
+parser.add_argument('--check-point', default=True, help="save results at each 
epoch or not")
+parser.add_argument('--inception_score', type=bool, default=True, help='To 
record the inception_score, default is True.')
+
+opt = parser.parse_args()
+print(opt)
+
+logging.basicConfig(level=logging.DEBUG)
+
+nz = int(opt.nz)
+ngf = int(opt.ngf)
+ndf = int(opt.ndf)
+niter = opt.niter
+nc = 3
+if opt.cuda:
+    ctx = mx.gpu(0)
+else:
+    ctx = mx.cpu()
+batch_size = opt.batch_size
+check_point = bool(opt.check_point)
+outf = opt.outf
+dataset = opt.dataset
+
+if not os.path.exists(outf):
+    os.makedirs(outf)
+
+
+def transformer(data, label):
+    # resize to 64x64
+    data = mx.image.imresize(data, 64, 64)
+    # transpose from (64, 64, 3) to (3, 64, 64)
+    data = mx.nd.transpose(data, (2, 0, 1))
+    # normalize to [-1, 1]
+    data = data.astype(np.float32)/128 - 1
+    # if image is greyscale, repeat 3 times to get RGB image.
+    if data.shape[0] == 1:
+        data = mx.nd.tile(data, (3, 1, 1))
+    return data, label
+
+
+# get dataset with the batch_size num each time
+def get_dataset(dataset):
+    # mnist
+    if dataset == "mnist":
+        train_data = gluon.data.DataLoader(
+            gluon.data.vision.MNIST('./data', train=True, 
transform=transformer),
+            batch_size, shuffle=True, last_batch='discard')
+
+        val_data = gluon.data.DataLoader(
+            gluon.data.vision.MNIST('./data', train=False, 
transform=transformer),
+            batch_size, shuffle=False)
+    # cifar10
+    elif dataset == "cifar10":
+        train_data = gluon.data.DataLoader(
+            gluon.data.vision.CIFAR10('./data', train=True, 
transform=transformer),
+            batch_size, shuffle=True, last_batch='discard')
+
+        val_data = gluon.data.DataLoader(
+            gluon.data.vision.CIFAR10('./data', train=False, 
transform=transformer),
+            batch_size, shuffle=False)
+
+    return train_data, val_data
+
+
+def get_netG():
+    # build the generator
+    netG = nn.Sequential()
+    with netG.name_scope():
+        # input is Z, going into a convolution
+        netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf*8) x 4 x 4
+        netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf*4) x 8 x 8
+        netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf*2) x 16 x 16
+        netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
+        netG.add(nn.BatchNorm())
+        netG.add(nn.Activation('relu'))
+        # state size. (ngf) x 32 x 32
+        netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
+        netG.add(nn.Activation('tanh'))
+        # state size. (nc) x 64 x 64
+
+    return netG
+
+
+def get_netD():
+    # build the discriminator
+    netD = nn.Sequential()
+    with netD.name_scope():
+        # input is (nc) x 64 x 64
+        netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf) x 32 x 32
+        netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
+        netD.add(nn.BatchNorm())
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf*2) x 16 x 16
+        netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
+        netD.add(nn.BatchNorm())
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf*4) x 8 x 8
+        netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
+        netD.add(nn.BatchNorm())
+        netD.add(nn.LeakyReLU(0.2))
+        # state size. (ndf*8) x 4 x 4
+        netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))
+        # state size. 2 x 1 x 1
+
+    return netD
+
+
+def get_configurations(netG, netD):
+    # loss
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+
+    # initialize the generator and the discriminator
+    netG.initialize(mx.init.Normal(0.02), ctx=ctx)
+    netD.initialize(mx.init.Normal(0.02), ctx=ctx)
+
+    # trainer for the generator and the discriminator
+    trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': 
opt.lr, 'beta1': opt.beta1})
+    trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': 
opt.lr, 'beta1': opt.beta1})
+
+    return loss, trainerG, trainerD
+
+
+def ins_save(inception_score):
+    # draw the inception_score curve
+    length = len(inception_score)
+    x = np.arange(0, length)
+    plt.figure(figsize=(8.0, 6.0))
+    plt.plot(x, inception_score)
+    plt.xlabel("iter/100")
+    plt.ylabel("inception_score")
+    plt.savefig("inception_score.png")
+
+
+# main function
+def main():
+    print("|------- new changes!!!!!!!!!")
+    # to get the dataset and net configuration
+    train_data, val_data = get_dataset(dataset)
+    netG = get_netG()
+    netD = get_netD()
+    loss, trainerG, trainerD = get_configurations(netG, netD)
+
+    # set labels
+    real_label = mx.nd.ones((opt.batch_size,), ctx=ctx)
+    fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx)
+
+    metric = mx.metric.Accuracy()
+    print('Training... ')
+    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
+
+    iter = 0
+
+    # to metric the network
+    loss_d = []
+    loss_g = []
+    inception_score = []
+
+    for epoch in range(opt.nepoch):
+        tic = time.time()
+        btic = time.time()
+        for data, _ in train_data:
+            ############################
+            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
+            ###########################
+            # train with real_t
+            data = data.as_in_context(ctx)
+            noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 
1), ctx=ctx)
+
+            with autograd.record():
+                output = netD(data)
+                # reshape output from (opt.batch_size, 2, 1, 1) to 
(opt.batch_size, 2)
+                output = output.reshape((opt.batch_size, 2))
+                errD_real = loss(output, real_label)
+
+            metric.update([real_label, ], [output, ])
+
+            with autograd.record():
+                fake = netG(noise)
+                output = netD(fake.detach())
+                output = output.reshape((opt.batch_size, 2))
+                errD_fake = loss(output, fake_label)
+                errD = errD_real + errD_fake
+
+            errD.backward()
+            metric.update([fake_label,], [output,])
+
+            trainerD.step(opt.batch_size)
+
+            ############################
+            # (2) Update G network: maximize log(D(G(z)))
+            ###########################
+            with autograd.record():
+                output = netD(fake)
+                output = output.reshape((-1, 2))
+                errG = loss(output, real_label)
+
+            errG.backward()
+
+            trainerG.step(opt.batch_size)
+
+            name, acc = metric.get()
+            logging.info('discriminator loss = %f, generator loss = %f, binary 
training acc = %f at iter %d epoch %d'
+                         % (mx.nd.mean(errD).asscalar(), 
mx.nd.mean(errG).asscalar(), acc, iter, epoch))
+            if iter % niter == 0:
+                visual('gout', fake.asnumpy(), name=os.path.join(outf, 
'fake_img_iter_%d.png' % iter))
+                visual('data', data.asnumpy(), name=os.path.join(outf, 
'real_img_iter_%d.png' % iter))
+                # record the metric data
+                loss_d.append(errD)
+                loss_g.append(errG)
+                if opt.inception_score:
+                    score, _ = get_inception_score(fake)
+                    inception_score.append(score)
+
+            iter = iter + 1
+            btic = time.time()
+
+        name, acc = metric.get()
+        metric.reset()
+        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, 
name, acc))
+        logging.info('time: %f' % (time.time() - tic))
+
+        # save check_point
+        if check_point:
+            netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' 
%epoch))
+            
netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
+
+    # save parameter
+    netG.save_parameters(os.path.join(outf, 'generator.params'))
+    netD.save_parameters(os.path.join(outf, 'discriminator.params'))
+
+    # visualization the inception_score as a picture
+    if opt.inception_score:
+        ins_save(inception_score)
+
+
+if __name__ == '__main__':
+    if opt.inception_score:
+        print("Use inception_score to metric this DCGAN model, the result is 
saved as a picture named \"inception_score.png\"!")
+    main()
+
diff --git a/example/gluon/DCGAN/inception_score.py 
b/example/gluon/DCGAN/inception_score.py
new file mode 100644
index 00000000000..e23513f5055
--- /dev/null
+++ b/example/gluon/DCGAN/inception_score.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from mxnet.gluon.model_zoo import vision as models
+import mxnet as mx
+from mxnet import nd
+import numpy as np
+import math
+import sys
+
+import cv2
+
+
+inception_model = None
+
+
+def get_inception_score(images, splits=10):
+    """
+    Inception_score function.
+        The images will be divided into 'splits' parts, and the 
inception_score is calculated for each part separately,
+        then the mean and std of the inception_scores of these parts are 
returned.
+    :param images: Images (num x c x w x h) for which to calculate the 
inception_score.
+    :param splits: number of parts the images are divided into
+    :return: mean and std of inception_score
+    """
+    assert (images.shape[1] == 3)
+
+    # load inception model
+    if inception_model is None:
+        _init_inception()
+
+    # resize images to adapt inception model(inceptionV3)
+    if images.shape[2] != 299:
+        images = resize(images, 299, 299)
+
+    preds = []
+    bs = 4
+    n_batches = int(math.ceil(float(images.shape[0])/float(bs)))
+
+    # to get the predictions/picture of inception model
+    for i in range(n_batches):
+        sys.stdout.write(".")
+        sys.stdout.flush()
+        inps = images[(i * bs):min((i + 1) * bs, len(images))]
+        # inps size. bs x 3 x 299 x 299
+        pred = nd.softmax(inception_model(inps))
+        # pred size. bs x 1000
+        preds.append(pred.asnumpy())
+
+    # list to array
+    preds = np.concatenate(preds, 0)
+    scores = []
+
+    # to calculate the inception_score each split.
+    for i in range(splits):
+        # extract per split image pred
+        part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] 
// splits), :]
+        kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 
0)))
+        kl = np.mean(np.sum(kl, 1))
+        scores.append(np.exp(kl))
+
+    return np.mean(scores), np.std(scores)
+
+
+def _init_inception():
+    global inception_model
+    inception_model = models.inception_v3(pretrained=True)
+    print("success import inception model, and the model is inception_v3!")
+
+
+def resize(images, w, h):
+    nums = images.shape[0]
+    res = nd.random.uniform(0, 255, (nums, 3, w, h))
+    for i in range(nums):
+        img = images[i, :, :, :]
+        img = mx.nd.transpose(img, (1, 2, 0))
+        # Replace 'mx.image.imresize()' with 'cv2.resize()' because : Operator 
_cvimresize is not implemented for GPU.
+        # img = mx.image.imresize(img, w, h)
+        img = cv2.resize(img.asnumpy(), (299, 299))
+        img = nd.array(img)
+        img = mx.nd.transpose(img, (2, 0, 1))
+        res[i, :, :, :] = img
+
+    return res
+
+
+if __name__ == '__main__':
+    if inception_model is None:
+        _init_inception()
+    # dummy data
+    images = nd.random.uniform(0, 255, (64, 3, 64, 64))
+    print(images.shape[0])
+    # resize(images,299,299)
+
+    score = get_inception_score(images)
+    print(score)
diff --git a/example/gluon/dcgan.py b/example/gluon/dcgan.py
deleted file mode 100644
index 8ac9c522cf5..00000000000
--- a/example/gluon/dcgan.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import matplotlib as mpl
-mpl.use('Agg')
-from matplotlib import pyplot as plt
-
-import argparse
-import mxnet as mx
-from mxnet import gluon
-from mxnet.gluon import nn
-from mxnet import autograd
-import numpy as np
-import logging
-from datetime import datetime
-import os
-import time
-
-def fill_buf(buf, i, img, shape):
-    n = buf.shape[0]//shape[1]
-    m = buf.shape[1]//shape[0]
-
-    sx = (i%m)*shape[0]
-    sy = (i//m)*shape[1]
-    buf[sy:sy+shape[1], sx:sx+shape[0], :] = img
-    return None
-
-def visual(title, X, name):
-    assert len(X.shape) == 4
-    X = X.transpose((0, 2, 3, 1))
-    X = np.clip((X - np.min(X))*(255.0/(np.max(X) - np.min(X))), 0, 
255).astype(np.uint8)
-    n = np.ceil(np.sqrt(X.shape[0]))
-    buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), 
dtype=np.uint8)
-    for i, img in enumerate(X):
-        fill_buf(buff, i, img, X.shape[1:3])
-    buff = buff[:,:,::-1]
-    plt.imshow(buff)
-    plt.title(title)
-    plt.savefig(name)
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--dataset', type=str, default='cifar10', help='dataset to 
use. options are cifar10 and imagenet.')
-parser.add_argument('--batch-size', type=int, default=64, help='input batch 
size')
-parser.add_argument('--nz', type=int, default=100, help='size of the latent z 
vector')
-parser.add_argument('--ngf', type=int, default=64)
-parser.add_argument('--ndf', type=int, default=64)
-parser.add_argument('--nepoch', type=int, default=25, help='number of epochs 
to train for')
-parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, 
default=0.0002')
-parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5')
-parser.add_argument('--cuda', action='store_true', help='enables cuda')
-parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to 
use')
-parser.add_argument('--netG', default='', help="path to netG (to continue 
training)")
-parser.add_argument('--netD', default='', help="path to netD (to continue 
training)")
-parser.add_argument('--outf', default='./results', help='folder to output 
images and model checkpoints')
-parser.add_argument('--check-point', default=True, help="save results at each 
epoch or not")
-
-opt = parser.parse_args()
-print(opt)
-
-logging.basicConfig(level=logging.DEBUG)
-ngpu = int(opt.ngpu)
-nz = int(opt.nz)
-ngf = int(opt.ngf)
-ndf = int(opt.ndf)
-nc = 3
-if opt.cuda:
-    ctx = mx.gpu(0)
-else:
-    ctx = mx.cpu()
-check_point = bool(opt.check_point)
-outf = opt.outf
-
-if not os.path.exists(outf):
-    os.makedirs(outf)
-
-
-def transformer(data, label):
-    # resize to 64x64
-    data = mx.image.imresize(data, 64, 64)
-    # transpose from (64, 64, 3) to (3, 64, 64)
-    data = mx.nd.transpose(data, (2,0,1))
-    # normalize to [-1, 1]
-    data = data.astype(np.float32)/128 - 1
-    # if image is greyscale, repeat 3 times to get RGB image.
-    if data.shape[0] == 1:
-        data = mx.nd.tile(data, (3, 1, 1))
-    return data, label
-
-train_data = gluon.data.DataLoader(
-    gluon.data.vision.MNIST('./data', train=True, transform=transformer),
-    batch_size=opt.batch_size, shuffle=True, last_batch='discard')
-
-val_data = gluon.data.DataLoader(
-    gluon.data.vision.MNIST('./data', train=False, transform=transformer),
-    batch_size=opt.batch_size, shuffle=False)
-
-
-# build the generator
-netG = nn.Sequential()
-with netG.name_scope():
-    # input is Z, going into a convolution
-    netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 4 x 4
-    netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 8 x 8
-    netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 16 x 16
-    netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
-    netG.add(nn.BatchNorm())
-    netG.add(nn.Activation('relu'))
-    # state size. (ngf*8) x 32 x 32
-    netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
-    netG.add(nn.Activation('tanh'))
-    # state size. (nc) x 64 x 64
-
-# build the discriminator
-netD = nn.Sequential()
-with netD.name_scope():
-    # input is (nc) x 64 x 64
-    netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 32 x 32
-    netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
-    netD.add(nn.BatchNorm())
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 16 x 16
-    netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
-    netD.add(nn.BatchNorm())
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 8 x 8
-    netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
-    netD.add(nn.BatchNorm())
-    netD.add(nn.LeakyReLU(0.2))
-    # state size. (ndf) x 4 x 4
-    netD.add(nn.Conv2D(2, 4, 1, 0, use_bias=False))
-
-# loss
-loss = gluon.loss.SoftmaxCrossEntropyLoss()
-
-# initialize the generator and the discriminator
-netG.initialize(mx.init.Normal(0.02), ctx=ctx)
-netD.initialize(mx.init.Normal(0.02), ctx=ctx)
-
-# trainer for the generator and the discriminator
-trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': 
opt.lr, 'beta1': opt.beta1})
-trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': 
opt.lr, 'beta1': opt.beta1})
-
-# ============printing==============
-real_label = mx.nd.ones((opt.batch_size,), ctx=ctx)
-fake_label = mx.nd.zeros((opt.batch_size,), ctx=ctx)
-
-metric = mx.metric.Accuracy()
-print('Training... ')
-stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')
-
-iter = 0
-for epoch in range(opt.nepoch):
-    tic = time.time()
-    btic = time.time()
-    for data, _ in train_data:
-        ############################
-        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
-        ###########################
-        # train with real_t
-        data = data.as_in_context(ctx)
-        noise = mx.nd.random.normal(0, 1, shape=(opt.batch_size, nz, 1, 1), 
ctx=ctx)
-
-        with autograd.record():
-            output = netD(data)
-            output = output.reshape((opt.batch_size, 2))
-            errD_real = loss(output, real_label)
-            metric.update([real_label,], [output,])
-
-            fake = netG(noise)
-            output = netD(fake.detach())
-            output = output.reshape((opt.batch_size, 2))
-            errD_fake = loss(output, fake_label)
-            errD = errD_real + errD_fake
-            errD.backward()
-            metric.update([fake_label,], [output,])
-
-        trainerD.step(opt.batch_size)
-
-        ############################
-        # (2) Update G network: maximize log(D(G(z)))
-        ###########################
-        with autograd.record():
-            output = netD(fake)
-            output = output.reshape((-1, 2))
-            errG = loss(output, real_label)
-            errG.backward()
-
-        trainerG.step(opt.batch_size)
-
-        name, acc = metric.get()
-        # logging.info('speed: {} samples/s'.format(opt.batch_size / 
(time.time() - btic)))
-        logging.info('discriminator loss = %f, generator loss = %f, binary 
training acc = %f at iter %d epoch %d' %(mx.nd.mean(errD).asscalar(), 
mx.nd.mean(errG).asscalar(), acc, iter, epoch))
-        if iter % 1 == 0:
-            visual('gout', fake.asnumpy(), 
name=os.path.join(outf,'fake_img_iter_%d.png' %iter))
-            visual('data', data.asnumpy(), 
name=os.path.join(outf,'real_img_iter_%d.png' %iter))
-
-        iter = iter + 1
-        btic = time.time()
-
-    name, acc = metric.get()
-    metric.reset()
-    logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, 
acc))
-    logging.info('time: %f' % (time.time() - tic))
-
-    if check_point:
-        netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' 
%epoch))
-        netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' 
% epoch))
-
-netG.save_parameters(os.path.join(outf, 'generator.params'))
-netD.save_parameters(os.path.join(outf, 'discriminator.params'))
diff --git a/example/gluon/sn_gan/data.py b/example/gluon/sn_gan/data.py
index 333125dbe9f..7ed4c38a3b3 100644
--- a/example/gluon/sn_gan/data.py
+++ b/example/gluon/sn_gan/data.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# 
https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 import numpy as np
 
diff --git a/example/gluon/sn_gan/model.py b/example/gluon/sn_gan/model.py
index 38f87ebddc8..b714c758788 100644
--- a/example/gluon/sn_gan/model.py
+++ b/example/gluon/sn_gan/model.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# 
https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 import mxnet as mx
 from mxnet import nd
diff --git a/example/gluon/sn_gan/train.py b/example/gluon/sn_gan/train.py
index 1cba1f57d0a..f4b9884810c 100644
--- a/example/gluon/sn_gan/train.py
+++ b/example/gluon/sn_gan/train.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# 
https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 
 import os
diff --git a/example/gluon/sn_gan/utils.py b/example/gluon/sn_gan/utils.py
index d3f1b8626a1..06c02300bc3 100644
--- a/example/gluon/sn_gan/utils.py
+++ b/example/gluon/sn_gan/utils.py
@@ -17,7 +17,7 @@
 
 # This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
 # https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
-# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
+# 
https://github.com/apache/incubator-mxnet/blob/master/example/gluon/DCGAN/dcgan.py
 
 import math
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to