Repository: incubator-singa
Updated Branches:
  refs/heads/master 8aac80e42 -> f8cd7e384


SINGA-344 Add a GAN example


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b1610d75
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b1610d75
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b1610d75

Branch: refs/heads/master
Commit: b1610d7576cd58cbc0c989af540c6c64c501585c
Parents: 2224d5f
Author: huangwentao <[email protected]>
Authored: Fri Aug 24 10:16:37 2018 +0800
Committer: huangwentao <[email protected]>
Committed: Fri Aug 24 10:16:37 2018 +0800

----------------------------------------------------------------------
 examples/gan/download_mnist.py |  28 +++++
 examples/gan/lsgan.py          | 213 ++++++++++++++++++++++++++++++++++++
 examples/gan/utils.py          |  67 ++++++++++++
 examples/gan/vanilla.py        | 207 +++++++++++++++++++++++++++++++++++
 4 files changed, 515 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/download_mnist.py
----------------------------------------------------------------------
diff --git a/examples/gan/download_mnist.py b/examples/gan/download_mnist.py
new file mode 100644
index 0000000..b042a7c
--- /dev/null
+++ b/examples/gan/download_mnist.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#     http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# 
+
+import argparse
+from utils import download_data
+
if __name__ == '__main__':
    # Command-line entry point: fetch the pre-processed MNIST archive from
    # `url` and store it at `gzfile`; download_data deals with the case where
    # the file is already present.
    cli = argparse.ArgumentParser(
        description='download the pre-processed MNIST dataset')
    cli.add_argument('gzfile', type=str, help='the dataset path')
    cli.add_argument('url', type=str, help='dataset url')
    opts = cli.parse_args()
    download_data(opts.gzfile, opts.url)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/lsgan.py
----------------------------------------------------------------------
diff --git a/examples/gan/lsgan.py b/examples/gan/lsgan.py
new file mode 100644
index 0000000..dc6582c
--- /dev/null
+++ b/examples/gan/lsgan.py
@@ -0,0 +1,213 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import device
+from singa import initializer
+from singa import layer
+from singa import loss
+from singa import net as ffnet
+from singa import optimizer
+from singa import tensor
+
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+from utils import load_data
+from utils import print_log
+
class LSGAN():
    """Least-squares GAN with fully connected generator and discriminator.

    The generator maps ``noise_size``-dim noise drawn uniformly from [-1, 1]
    to a flattened ``rows * cols * channels`` image through
    Dense -> ReLU -> Dense -> Sigmoid.  The discriminator is
    Dense -> ReLU -> Dense(1) with no output activation.  Both are trained
    with a squared-error loss: the discriminator towards 1 for real images and
    -1 for generated ones, the generator (through ``combined_net``, which
    chains the generator's and discriminator's layers) towards 0.

    Args:
        dev: SINGA device that networks and tensors are moved to.
        rows, cols, channels: image geometry.
        noise_size: dimension of the generator's input noise.
        hidden_size: width of the hidden Dense layer of both networks.
        batch: nominal batch size; every step uses ``batch // 2`` real images
            and ``batch // 2`` noise vectors.
        interval: save sample images and log losses every ``interval``
            iterations.
        learning_rate: learning rate of both Adam optimizers.
        epochs: number of training iterations (each draws a random batch).
        d_steps: discriminator steps per iteration.
        g_steps: generator steps per iteration.
        dataset_filepath: gzipped MNIST pickle passed to ``load_data``.
        file_dir: directory that sample images are written to (created on
            demand).
    """

    def __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100,
                 hidden_size=128, batch=128, interval=1000,
                 learning_rate=0.001, epochs=1000000, d_steps=3, g_steps=1,
                 dataset_filepath='mnist.pkl.gz', file_dir='lsgan_images/'):
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        # Half of the nominal batch is real data, half is generated data.
        self.batch_size = self.batch // 2
        self.interval = interval
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.d_steps = d_steps
        self.g_steps = g_steps
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir

        # Generator: noise -> Dense -> ReLU -> Dense -> Sigmoid.  Its own loss
        # is never optimised directly; generator updates go via combined_net.
        self.g_w0_specs = {'init': 'xavier'}
        self.g_b0_specs = {'init': 'constant', 'value': 0}
        self.g_w1_specs = {'init': 'xavier'}
        self.g_b1_specs = {'init': 'constant', 'value': 0}
        self.gen_net = ffnet.FeedForwardNet(loss.SquaredError())
        self.gen_net_fc_0 = layer.Dense(
            name='g_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.g_w0_specs, b_specs=self.g_b0_specs,
            input_sample_shape=(self.noise_size,))
        self.gen_net_relu_0 = layer.Activation(
            name='g_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.gen_net_fc_1 = layer.Dense(
            name='g_fc_1', num_output=self.feature_size, use_bias=True,
            W_specs=self.g_w1_specs, b_specs=self.g_b1_specs,
            input_sample_shape=(self.hidden_size,))
        # NOTE: named 'g_relu_1' although it is the sigmoid output layer; the
        # layer name is kept unchanged.
        self.gen_net_sigmoid_1 = layer.Activation(
            name='g_relu_1', mode='sigmoid',
            input_sample_shape=(self.feature_size,))
        for lyr in (self.gen_net_fc_0, self.gen_net_relu_0,
                    self.gen_net_fc_1, self.gen_net_sigmoid_1):
            self.gen_net.add(lyr)
        self._init_params(self.gen_net)
        self.gen_net.to_device(self.dev)

        # Discriminator: image -> Dense -> ReLU -> Dense(1), raw score output.
        self.d_w0_specs = {'init': 'xavier'}
        self.d_b0_specs = {'init': 'constant', 'value': 0}
        self.d_w1_specs = {'init': 'xavier'}
        self.d_b1_specs = {'init': 'constant', 'value': 0}
        self.dis_net = ffnet.FeedForwardNet(loss.SquaredError())
        self.dis_net_fc_0 = layer.Dense(
            name='d_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.d_w0_specs, b_specs=self.d_b0_specs,
            input_sample_shape=(self.feature_size,))
        self.dis_net_relu_0 = layer.Activation(
            name='d_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.dis_net_fc_1 = layer.Dense(
            name='d_fc_1', num_output=1, use_bias=True,
            W_specs=self.d_w1_specs, b_specs=self.d_b1_specs,
            input_sample_shape=(self.hidden_size,))
        for lyr in (self.dis_net_fc_0, self.dis_net_relu_0,
                    self.dis_net_fc_1):
            self.dis_net.add(lyr)
        self._init_params(self.dis_net)
        self.dis_net.to_device(self.dev)

        # Generator followed by discriminator.  The layer objects are shared,
        # so updating them through this net also updates gen_net/dis_net.
        self.combined_net = ffnet.FeedForwardNet(loss.SquaredError())
        for lyr in self.gen_net.layers:
            self.combined_net.add(lyr)
        for lyr in self.dis_net.layers:
            self.combined_net.add(lyr)
        self.combined_net.to_device(self.dev)

    @staticmethod
    def _init_params(net):
        """Initialise every parameter of *net* from its filler spec and log it."""
        for p, specs in zip(net.param_values(), net.param_specs()):
            filler = specs.filler
            if filler.type == 'gaussian':
                p.gaussian(filler.mean, filler.std)
            elif filler.type == 'xavier':
                initializer.xavier(p)
            else:
                # Any other filler (the constant bias specs above) is
                # zero-filled; a configured non-zero value would be ignored.
                p.set_value(0)
            print(specs.name, filler.type, p.l1())

    def _sample_noise(self, n):
        """Return an (n, noise_size) tensor of U(-1, 1) noise on self.dev."""
        noise = tensor.Tensor((n, self.noise_size))
        noise.uniform(-1, 1)
        noise.to_device(self.dev)
        return noise

    def _constant_target(self, value):
        """Return a (batch_size, 1) tensor filled with *value* on self.dev."""
        target = tensor.Tensor((self.batch_size, 1))
        target.set_value(value)
        target.to_device(self.dev)
        return target

    def _apply_grads(self, opt, net, grads, epoch):
        """Apply *grads* to *net*'s parameters with *opt*, pairing by position.

        zip() stops at the shorter sequence, so gradients beyond net's own
        parameters (e.g. the discriminator part of combined_net) are dropped.
        """
        for name, value, grad in zip(net.param_names(), net.param_values(),
                                     grads):
            opt.apply_with_lr(epoch, self.learning_rate, grad, value,
                              str(name), epoch)

    def train(self):
        """Run the adversarial training loop for ``self.epochs`` iterations.

        Each iteration does ``d_steps`` discriminator steps (an update on
        ``batch_size`` real images with target 1, then one on as many
        generated images with target -1), followed by ``g_steps`` generator
        steps that train ``combined_net`` on fresh noise with target 0 while
        only the generator's parameters are updated.  Every ``interval``
        iterations a sample grid is saved and the latest losses are logged.
        """
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        opt_0 = optimizer.Adam(lr=self.learning_rate)  # discriminator
        opt_1 = optimizer.Adam(lr=self.learning_rate)  # generator (combined)
        for p, specs in zip(self.dis_net.param_names(),
                            self.dis_net.param_specs()):
            opt_0.register(p, specs)
        for p, specs in zip(self.gen_net.param_names(),
                            self.gen_net.param_specs()):
            opt_1.register(p, specs)

        # The least-squares targets never change, so build them once.  Their
        # size comes from batch_size, not from a real batch, so the generator
        # step also works when d_steps == 0.
        real_target = self._constant_target(1.0)
        fake_target = self._constant_target(-1.0)
        gen_target = self._constant_target(0.0)

        # Pre-bound so the periodic log works even if d_steps/g_steps is 0.
        d_loss = None
        g_loss = None
        for epoch in range(self.epochs):
            for _ in range(self.d_steps):
                idx = np.random.randint(0, train_data.shape[0],
                                        self.batch_size)
                real_imgs = tensor.from_numpy(train_data[idx])
                real_imgs.to_device(self.dev)
                fake_imgs = self.gen_net.forward(
                    flag=False, x=self._sample_noise(self.batch_size))
                grads, (d_loss_real, _) = self.dis_net.train(real_imgs,
                                                             real_target)
                self._apply_grads(opt_0, self.dis_net, grads, epoch)
                grads, (d_loss_fake, _) = self.dis_net.train(fake_imgs,
                                                             fake_target)
                self._apply_grads(opt_0, self.dis_net, grads, epoch)
                d_loss = d_loss_real + d_loss_fake

            for _ in range(self.g_steps):
                noise = self._sample_noise(self.batch_size)
                grads, (g_loss, _) = self.combined_net.train(noise,
                                                             gen_target)
                # Relies on combined_net returning gradients in layer order
                # (generator layers were added first); only those are applied,
                # so the discriminator stays frozen during this step.
                self._apply_grads(opt_1, self.gen_net, grads, epoch)

            if epoch % self.interval == 0:
                self.save_image(epoch)
                print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(
                    epoch, g_loss, d_loss))

    def save_image(self, epoch):
        """Save a 5x5 grid of generator samples as ``<file_dir>/<epoch>.png``.

        Only channel 0 of each sample is drawn, in grayscale.
        """
        rows = 5
        cols = 5
        # Only rows * cols samples are displayed, so only that many are drawn.
        noise = self._sample_noise(rows * cols)
        gen_imgs = tensor.to_numpy(self.gen_net.forward(flag=False, x=noise))
        show_imgs = np.reshape(
            gen_imgs,
            (gen_imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(rows, cols)
        # axs.flat walks the grid row by row, matching the sample order.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(show_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')
        if not os.path.isdir(self.file_dir):
            os.makedirs(self.file_dir)
        fig.savefig(os.path.join(self.file_dir, '{}.png'.format(epoch)))
        plt.close(fig)
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    # Optional positional argument: defaults to the archive name that
    # download_mnist.py users typically save to the working directory.
    parser.add_argument('filepath', type=str, nargs='?',
                        default='mnist.pkl.gz', help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
        layer.engine = 'cudnn'
    else:
        print('Using CPU')
        dev = device.get_default_device()
        layer.engine = 'singacpp'

    file_dir = 'lsgan_images/'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)

    # Use the dataset path given on the command line instead of ignoring it.
    dataset_filepath = args.filepath
    lsgan = LSGAN(dev, rows=28, cols=28, channels=1, noise_size=100,
                  hidden_size=128, batch=128, interval=1000,
                  learning_rate=0.001, epochs=1000000, d_steps=3, g_steps=1,
                  dataset_filepath=dataset_filepath, file_dir=file_dir)
    lsgan.train()
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/utils.py
----------------------------------------------------------------------
diff --git a/examples/gan/utils.py b/examples/gan/utils.py
new file mode 100644
index 0000000..050d184
--- /dev/null
+++ b/examples/gan/utils.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#     http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# 
+
+import gzip
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import pickle
+import sys
+import time
+
+try:
+       import urllib.request as ul_request
+except ImportError:
+       import urllib as ul_request
+
def print_log(s):
    """Print *s* to stdout, prefixed by the current local time in brackets."""
    stamp = time.ctime()
    line = '[{}]{}'.format(stamp, s)
    print(line)
+
def load_data(filepath):
    """Load the gzipped, pickled MNIST splits stored at *filepath*.

    The pickle must contain ``(train_set, valid_set, test_set)``, where each
    set is a ``(data, labels)`` pair of NumPy arrays.

    Returns:
        ``(traindata, trainlabel, validdata, validlabel, testdata,
        testlabel)``, every array cast to ``np.float32``.

    Security: unpickling can execute arbitrary code, so only load dataset
    files that come from a trusted source.
    """
    # On Python 3, encoding='bytes' lets pickles written by Python 2 (such as
    # NumPy arrays) load; Python 2's pickle.load has no such keyword, and this
    # module otherwise keeps Python 2 working (see the urllib import fallback).
    load_kwargs = {'encoding': 'bytes'} if sys.version_info[0] >= 3 else {}
    with gzip.open(filepath, 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f, **load_kwargs)
    traindata = train_set[0].astype(np.float32)
    validdata = valid_set[0].astype(np.float32)
    testdata = test_set[0].astype(np.float32)
    trainlabel = train_set[1].astype(np.float32)
    validlabel = valid_set[1].astype(np.float32)
    testlabel = test_set[1].astype(np.float32)
    return traindata, trainlabel, validdata, validlabel, testdata, testlabel
+
def download_data(gzfile, url):
    """Download *url* to the local path *gzfile* unless it already exists.

    The payload is written to ``gzfile + '.part'`` and renamed to *gzfile*
    only after the transfer completes, so an interrupted download is never
    mistaken for a finished file on the next run.  If *gzfile* is already
    present, the function prints a notice and returns; it does not terminate
    the interpreter, so it is safe to call from other code.
    """
    if os.path.exists(gzfile):
        print('Downloaded already!')
        return
    print('Downloading data %s' % (url))
    partial = gzfile + '.part'
    try:
        ul_request.urlretrieve(url, partial)
    except BaseException:
        # Includes KeyboardInterrupt: drop the incomplete file, then re-raise.
        if os.path.exists(partial):
            os.remove(partial)
        raise
    os.rename(partial, gzfile)
    print('Finished!')
+
def show_images(filepath, rows=5, cols=5):
    """Display pickled images from *filepath* in a ``rows`` x ``cols`` grid.

    The file must hold a pickled 4-D array indexed as ``imgs[i, :, :, c]``;
    only channel 0 of each image is drawn, in grayscale.  Grid cells beyond
    the number of stored images are left blank instead of raising.

    Security: unpickling can execute arbitrary code, so only open trusted
    files.
    """
    with open(filepath, 'rb') as f:
        imgs = pickle.load(f)
    # squeeze=False keeps axs two-dimensional even for a 1-row/1-column grid.
    _, axs = plt.subplots(rows, cols, squeeze=False)
    # axs.flat visits the grid row by row, matching the original fill order.
    for cnt, ax in enumerate(axs.flat):
        if cnt < len(imgs):
            ax.imshow(imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    plt.show()
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/vanilla.py
----------------------------------------------------------------------
diff --git a/examples/gan/vanilla.py b/examples/gan/vanilla.py
new file mode 100644
index 0000000..ce5e048
--- /dev/null
+++ b/examples/gan/vanilla.py
@@ -0,0 +1,207 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import device
+from singa import initializer
+from singa import layer
+from singa import loss
+from singa import net as ffnet
+from singa import optimizer
+from singa import tensor
+
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+from utils import load_data
+from utils import print_log
+
class VANILLA():
    """Vanilla GAN with fully connected generator and discriminator.

    The generator maps ``noise_size``-dim noise drawn uniformly from [-1, 1]
    to a flattened ``rows * cols * channels`` image through
    Dense -> ReLU -> Dense -> Sigmoid.  The discriminator is
    Dense -> ReLU -> Dense(1) without an output activation and is trained
    with ``loss.SigmoidCrossEntropy`` (presumably applying the sigmoid to
    these raw scores) towards label 1 for real images and 0 for generated
    ones.  The generator is trained through ``combined_net``, which chains
    both networks' layers, towards label 1.

    Args:
        dev: SINGA device that networks and tensors are moved to.
        rows, cols, channels: image geometry.
        noise_size: dimension of the generator's input noise.
        hidden_size: width of the hidden Dense layer of both networks.
        batch: nominal batch size; every step uses ``batch // 2`` real images
            and ``batch // 2`` noise vectors.
        interval: save sample images and log losses every ``interval``
            iterations.
        learning_rate: learning rate of both Adam optimizers.
        epochs: number of training iterations (each draws a random batch).
        dataset_filepath: gzipped MNIST pickle passed to ``load_data``.
        file_dir: directory that sample images are written to (created on
            demand).
    """

    def __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100,
                 hidden_size=128, batch=128, interval=1000,
                 learning_rate=0.001, epochs=1000000,
                 dataset_filepath='mnist.pkl.gz', file_dir='vanilla_images/'):
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        # Half of the nominal batch is real data, half is generated data.
        self.batch_size = self.batch // 2
        self.interval = interval
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir

        # Generator: noise -> Dense -> ReLU -> Dense -> Sigmoid.  Its own loss
        # is never optimised directly; generator updates go via combined_net.
        self.g_w0_specs = {'init': 'xavier'}
        self.g_b0_specs = {'init': 'constant', 'value': 0}
        self.g_w1_specs = {'init': 'xavier'}
        self.g_b1_specs = {'init': 'constant', 'value': 0}
        self.gen_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy())
        self.gen_net_fc_0 = layer.Dense(
            name='g_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.g_w0_specs, b_specs=self.g_b0_specs,
            input_sample_shape=(self.noise_size,))
        self.gen_net_relu_0 = layer.Activation(
            name='g_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.gen_net_fc_1 = layer.Dense(
            name='g_fc_1', num_output=self.feature_size, use_bias=True,
            W_specs=self.g_w1_specs, b_specs=self.g_b1_specs,
            input_sample_shape=(self.hidden_size,))
        # NOTE: named 'g_relu_1' although it is the sigmoid output layer; the
        # layer name is kept unchanged.
        self.gen_net_sigmoid_1 = layer.Activation(
            name='g_relu_1', mode='sigmoid',
            input_sample_shape=(self.feature_size,))
        for lyr in (self.gen_net_fc_0, self.gen_net_relu_0,
                    self.gen_net_fc_1, self.gen_net_sigmoid_1):
            self.gen_net.add(lyr)
        self._init_params(self.gen_net)
        self.gen_net.to_device(self.dev)

        # Discriminator: image -> Dense -> ReLU -> Dense(1), raw score output.
        self.d_w0_specs = {'init': 'xavier'}
        self.d_b0_specs = {'init': 'constant', 'value': 0}
        self.d_w1_specs = {'init': 'xavier'}
        self.d_b1_specs = {'init': 'constant', 'value': 0}
        self.dis_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy())
        self.dis_net_fc_0 = layer.Dense(
            name='d_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.d_w0_specs, b_specs=self.d_b0_specs,
            input_sample_shape=(self.feature_size,))
        self.dis_net_relu_0 = layer.Activation(
            name='d_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.dis_net_fc_1 = layer.Dense(
            name='d_fc_1', num_output=1, use_bias=True,
            W_specs=self.d_w1_specs, b_specs=self.d_b1_specs,
            input_sample_shape=(self.hidden_size,))
        for lyr in (self.dis_net_fc_0, self.dis_net_relu_0,
                    self.dis_net_fc_1):
            self.dis_net.add(lyr)
        self._init_params(self.dis_net)
        self.dis_net.to_device(self.dev)

        # Generator followed by discriminator.  The layer objects are shared,
        # so updating them through this net also updates gen_net/dis_net.
        self.combined_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy())
        for lyr in self.gen_net.layers:
            self.combined_net.add(lyr)
        for lyr in self.dis_net.layers:
            self.combined_net.add(lyr)
        self.combined_net.to_device(self.dev)

    @staticmethod
    def _init_params(net):
        """Initialise every parameter of *net* from its filler spec and log it."""
        for p, specs in zip(net.param_values(), net.param_specs()):
            filler = specs.filler
            if filler.type == 'gaussian':
                p.gaussian(filler.mean, filler.std)
            elif filler.type == 'xavier':
                initializer.xavier(p)
            else:
                # Any other filler (the constant bias specs above) is
                # zero-filled; a configured non-zero value would be ignored.
                p.set_value(0)
            print(specs.name, filler.type, p.l1())

    def _sample_noise(self, n):
        """Return an (n, noise_size) tensor of U(-1, 1) noise on self.dev."""
        noise = tensor.Tensor((n, self.noise_size))
        noise.uniform(-1, 1)
        noise.to_device(self.dev)
        return noise

    def _constant_target(self, value):
        """Return a (batch_size, 1) tensor filled with *value* on self.dev."""
        target = tensor.Tensor((self.batch_size, 1))
        target.set_value(value)
        target.to_device(self.dev)
        return target

    def _apply_grads(self, opt, net, grads, epoch):
        """Apply *grads* to *net*'s parameters with *opt*, pairing by position.

        zip() stops at the shorter sequence, so gradients beyond net's own
        parameters (e.g. the discriminator part of combined_net) are dropped.
        """
        for name, value, grad in zip(net.param_names(), net.param_values(),
                                     grads):
            opt.apply_with_lr(epoch, self.learning_rate, grad, value,
                              str(name), epoch)

    def train(self):
        """Run the adversarial training loop for ``self.epochs`` iterations.

        Each iteration updates the discriminator once on ``batch_size`` real
        images (label 1) and once on as many generated images (label 0), then
        trains ``combined_net`` on fresh noise with label 1 while only the
        generator's parameters are updated.  Every ``interval`` iterations a
        sample grid is saved and the losses are logged.
        """
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        opt_0 = optimizer.Adam(lr=self.learning_rate)  # discriminator
        opt_1 = optimizer.Adam(lr=self.learning_rate)  # generator (combined)
        for p, specs in zip(self.dis_net.param_names(),
                            self.dis_net.param_specs()):
            opt_0.register(p, specs)
        for p, specs in zip(self.gen_net.param_names(),
                            self.gen_net.param_specs()):
            opt_1.register(p, specs)

        # The labels never change, so build them once instead of per step.
        real_labels = self._constant_target(1.0)
        fake_labels = self._constant_target(0.0)

        for epoch in range(self.epochs):
            idx = np.random.randint(0, train_data.shape[0], self.batch_size)
            real_imgs = tensor.from_numpy(train_data[idx])
            real_imgs.to_device(self.dev)
            fake_imgs = self.gen_net.forward(
                flag=False, x=self._sample_noise(self.batch_size))
            grads, (d_loss_real, _) = self.dis_net.train(real_imgs,
                                                         real_labels)
            self._apply_grads(opt_0, self.dis_net, grads, epoch)
            grads, (d_loss_fake, _) = self.dis_net.train(fake_imgs,
                                                         fake_labels)
            self._apply_grads(opt_0, self.dis_net, grads, epoch)
            d_loss = d_loss_real + d_loss_fake

            # Generator step: fresh noise, labelled as real.
            noise = self._sample_noise(self.batch_size)
            grads, (g_loss, _) = self.combined_net.train(noise, real_labels)
            # Relies on combined_net returning gradients in layer order
            # (generator layers were added first); only those are applied,
            # so the discriminator stays frozen during this step.
            self._apply_grads(opt_1, self.gen_net, grads, epoch)

            if epoch % self.interval == 0:
                self.save_image(epoch)
                print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(
                    epoch, g_loss, d_loss))

    def save_image(self, epoch):
        """Save a 5x5 grid of generator samples as ``<file_dir>/<epoch>.png``.

        Only channel 0 of each sample is drawn, in grayscale.
        """
        rows = 5
        cols = 5
        # Only rows * cols samples are displayed, so only that many are drawn.
        noise = self._sample_noise(rows * cols)
        gen_imgs = tensor.to_numpy(self.gen_net.forward(flag=False, x=noise))
        show_imgs = np.reshape(
            gen_imgs,
            (gen_imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(rows, cols)
        # axs.flat walks the grid row by row, matching the sample order.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(show_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')
        if not os.path.isdir(self.file_dir):
            os.makedirs(self.file_dir)
        fig.savefig(os.path.join(self.file_dir, '{}.png'.format(epoch)))
        plt.close(fig)
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    # Optional positional argument: defaults to the archive name that
    # download_mnist.py users typically save to the working directory.
    parser.add_argument('filepath', type=str, nargs='?',
                        default='mnist.pkl.gz', help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
        layer.engine = 'cudnn'
    else:
        print('Using CPU')
        dev = device.get_default_device()
        layer.engine = 'singacpp'

    file_dir = 'vanilla_images/'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)

    # Use the dataset path given on the command line instead of ignoring it.
    dataset_filepath = args.filepath
    vanilla = VANILLA(dev, rows=28, cols=28, channels=1, noise_size=100,
                      hidden_size=128, batch=128, interval=1000,
                      learning_rate=0.001, epochs=1000000,
                      dataset_filepath=dataset_filepath, file_dir=file_dir)
    vanilla.train()
\ No newline at end of file

Reply via email to