You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@singa.apache.org by wa...@apache.org on 2018/08/27 14:31:16 UTC
[1/2] incubator-singa git commit: SINGA-344 Add a GAN example
Repository: incubator-singa
Updated Branches:
refs/heads/master 8aac80e42 -> f8cd7e384
SINGA-344 Add a GAN example
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b1610d75
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b1610d75
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b1610d75
Branch: refs/heads/master
Commit: b1610d7576cd58cbc0c989af540c6c64c501585c
Parents: 2224d5f
Author: huangwentao <we...@ruc.edu.cn>
Authored: Fri Aug 24 10:16:37 2018 +0800
Committer: huangwentao <we...@ruc.edu.cn>
Committed: Fri Aug 24 10:16:37 2018 +0800
----------------------------------------------------------------------
examples/gan/download_mnist.py | 28 +++++
examples/gan/lsgan.py | 213 ++++++++++++++++++++++++++++++++++++
examples/gan/utils.py | 67 ++++++++++++
examples/gan/vanilla.py | 207 +++++++++++++++++++++++++++++++++++
4 files changed, 515 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/download_mnist.py
----------------------------------------------------------------------
diff --git a/examples/gan/download_mnist.py b/examples/gan/download_mnist.py
new file mode 100644
index 0000000..b042a7c
--- /dev/null
+++ b/examples/gan/download_mnist.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+from utils import download_data
+
if __name__ == '__main__':
    # Usage: download_mnist.py <gzfile> <url>
    # Fetches <url> into the local file <gzfile> via utils.download_data.
    cli_parser = argparse.ArgumentParser(
        description='download the pre-processed MNIST dataset')
    cli_parser.add_argument('gzfile', type=str, help='the dataset path')
    cli_parser.add_argument('url', type=str, help='dataset url')
    cli = cli_parser.parse_args()
    target_path, source_url = cli.gzfile, cli.url
    download_data(target_path, source_url)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/lsgan.py
----------------------------------------------------------------------
diff --git a/examples/gan/lsgan.py b/examples/gan/lsgan.py
new file mode 100644
index 0000000..dc6582c
--- /dev/null
+++ b/examples/gan/lsgan.py
@@ -0,0 +1,213 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import device
+from singa import initializer
+from singa import layer
+from singa import loss
+from singa import net as ffnet
+from singa import optimizer
+from singa import tensor
+
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+from utils import load_data
+from utils import print_log
+
class LSGAN(object):
    """Least-squares GAN with fully-connected generator and discriminator.

    * Generator (``gen_net``): noise of size ``noise_size`` -> Dense(hidden_size)
      -> ReLU -> Dense(rows * cols * channels) -> sigmoid.
    * Discriminator (``dis_net``): flattened image -> Dense(hidden_size) -> ReLU
      -> Dense(1), with no output activation.
    * ``combined_net`` chains the generator's layer objects followed by the
      discriminator's; training it back-propagates the discriminator's score
      into the generator.

    All three nets use ``loss.SquaredError``. ``train()`` regresses the
    discriminator's score to 1 for real images and to -1 for generated ones,
    and trains the generator so the discriminator scores its samples as 0.

    Args:
        dev: SINGA device on which nets and tensors are placed.
        rows, cols, channels: image geometry; images are handled flattened to
            ``rows * cols * channels`` features.
        noise_size: dimension of the generator's uniform(-1, 1) input noise.
        hidden_size: width of the single hidden layer of both nets.
        batch: ``batch // 2`` samples are used per training step.
        interval: every ``interval`` epochs a sample grid is saved and the
            losses are logged.
        learning_rate: Adam learning rate for both optimizers.
        epochs: number of training iterations; each one samples a single
            random batch (it is not a full pass over the data).
        d_steps, g_steps: discriminator / generator updates per epoch.
        dataset_filepath: gzipped pickle read by ``utils.load_data``.
        file_dir: directory for saved sample images; expected to exist.
    """

    def __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100, hidden_size=128, batch=128,
                 interval=1000, learning_rate=0.001, epochs=1000000, d_steps=3, g_steps=1,
                 dataset_filepath='mnist.pkl.gz', file_dir='lsgan_images/'):
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        self.batch_size = self.batch // 2
        self.interval = interval
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.d_steps = d_steps
        self.g_steps = g_steps
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir

        # Generator: noise -> hidden ReLU layer -> sigmoid output.
        self.g_w0_specs = {'init': 'xavier',}
        self.g_b0_specs = {'init': 'constant', 'value': 0,}
        self.g_w1_specs = {'init': 'xavier',}
        self.g_b1_specs = {'init': 'constant', 'value': 0,}
        self.gen_net = ffnet.FeedForwardNet(loss.SquaredError(),)
        self.gen_net_fc_0 = layer.Dense(
            name='g_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.g_w0_specs, b_specs=self.g_b0_specs,
            input_sample_shape=(self.noise_size,))
        self.gen_net_relu_0 = layer.Activation(
            name='g_relu_0', mode='relu', input_sample_shape=(self.hidden_size,))
        self.gen_net_fc_1 = layer.Dense(
            name='g_fc_1', num_output=self.feature_size, use_bias=True,
            W_specs=self.g_w1_specs, b_specs=self.g_b1_specs,
            input_sample_shape=(self.hidden_size,))
        self.gen_net_sigmoid_1 = layer.Activation(
            name='g_sigmoid_1', mode='sigmoid', input_sample_shape=(self.feature_size,))
        self.gen_net.add(self.gen_net_fc_0)
        self.gen_net.add(self.gen_net_relu_0)
        self.gen_net.add(self.gen_net_fc_1)
        self.gen_net.add(self.gen_net_sigmoid_1)
        self._init_params(self.gen_net)
        self.gen_net.to_device(self.dev)

        # Discriminator: flattened image -> hidden ReLU layer -> one raw score.
        self.d_w0_specs = {'init': 'xavier',}
        self.d_b0_specs = {'init': 'constant', 'value': 0,}
        self.d_w1_specs = {'init': 'xavier',}
        self.d_b1_specs = {'init': 'constant', 'value': 0,}
        self.dis_net = ffnet.FeedForwardNet(loss.SquaredError(),)
        self.dis_net_fc_0 = layer.Dense(
            name='d_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.d_w0_specs, b_specs=self.d_b0_specs,
            input_sample_shape=(self.feature_size,))
        self.dis_net_relu_0 = layer.Activation(
            name='d_relu_0', mode='relu', input_sample_shape=(self.hidden_size,))
        self.dis_net_fc_1 = layer.Dense(
            name='d_fc_1', num_output=1, use_bias=True,
            W_specs=self.d_w1_specs, b_specs=self.d_b1_specs,
            input_sample_shape=(self.hidden_size,))
        self.dis_net.add(self.dis_net_fc_0)
        self.dis_net.add(self.dis_net_relu_0)
        self.dis_net.add(self.dis_net_fc_1)
        self._init_params(self.dis_net)
        self.dis_net.to_device(self.dev)

        # The combined net reuses the very same layer objects (generator first,
        # then discriminator), so its gradients reach the generator's weights.
        self.combined_net = ffnet.FeedForwardNet(loss.SquaredError(), )
        for l in self.gen_net.layers:
            self.combined_net.add(l)
        for l in self.dis_net.layers:
            self.combined_net.add(l)
        self.combined_net.to_device(self.dev)

    def _init_params(self, net):
        """Fill ``net``'s parameters according to their filler specs and log them.

        'gaussian' and 'xavier' fillers are honoured; any other filler type
        (e.g. the constant-0 bias specs used here) is set to 0.
        """
        for p, specs in zip(net.param_values(), net.param_specs()):
            filler = specs.filler
            if filler.type == 'gaussian':
                p.gaussian(filler.mean, filler.std)
            elif filler.type == 'xavier':
                initializer.xavier(p)
            else:
                p.set_value(0)
            print(specs.name, filler.type, p.l1())

    def _noise(self, n):
        """Return an ``(n, noise_size)`` tensor of uniform(-1, 1) noise on ``self.dev``."""
        noise = tensor.Tensor((n, self.noise_size))
        noise.uniform(-1, 1)
        noise.to_device(self.dev)
        return noise

    def _targets(self, n, value):
        """Return an ``(n, 1)`` regression-target tensor filled with ``value`` on ``self.dev``."""
        target = tensor.Tensor((n, 1))
        target.set_value(value)
        target.to_device(self.dev)
        return target

    def _apply_grads(self, opt, net, grads, epoch):
        """Update ``net``'s parameters with ``grads`` via ``opt`` (``epoch`` doubles as the step)."""
        for s, p, g in zip(net.param_names(), net.param_values(), grads):
            opt.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)

    def train(self):
        """Alternate discriminator and generator updates for ``self.epochs`` iterations."""
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        opt_0 = optimizer.Adam(lr=self.learning_rate)  # optimizer for discriminator
        opt_1 = optimizer.Adam(lr=self.learning_rate)  # optimizer for generator, aka the combined model
        for (p, specs) in zip(self.dis_net.param_names(), self.dis_net.param_specs()):
            opt_0.register(p, specs)
        for (p, specs) in zip(self.gen_net.param_names(), self.gen_net.param_specs()):
            opt_1.register(p, specs)

        # Pre-bound so the periodic log line cannot raise NameError when
        # d_steps or g_steps is 0.
        d_loss = None
        g_loss = None
        for epoch in range(self.epochs):
            for _ in range(self.d_steps):
                # Sample a random (with replacement) batch of real images.
                idx = np.random.randint(0, train_data.shape[0], self.batch_size)
                real_imgs = tensor.from_numpy(train_data[idx])
                real_imgs.to_device(self.dev)
                fake_imgs = self.gen_net.forward(flag=False, x=self._noise(self.batch_size))
                # Least-squares targets: real -> 1, fake -> -1.
                grads, (d_loss_real, _) = self.dis_net.train(
                    real_imgs, self._targets(self.batch_size, 1.0))
                self._apply_grads(opt_0, self.dis_net, grads, epoch)
                grads, (d_loss_fake, _) = self.dis_net.train(
                    fake_imgs, self._targets(self.batch_size, -1.0))
                self._apply_grads(opt_0, self.dis_net, grads, epoch)
                d_loss = d_loss_real + d_loss_fake

            for _ in range(self.g_steps):
                # Generator target is 0. The batch size comes from
                # self.batch_size rather than from the last real batch, so
                # this also works when d_steps == 0.
                grads, (g_loss, _) = self.combined_net.train(
                    self._noise(self.batch_size), self._targets(self.batch_size, 0.0))
                # zip() inside _apply_grads stops after the generator's
                # parameters, so only the generator is updated.
                # NOTE(review): relies on combined_net.train() returning grads
                # in parameter order (generator layers first) - confirm.
                self._apply_grads(opt_1, self.gen_net, grads, epoch)

            if epoch % self.interval == 0:
                self.save_image(epoch)
                print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(epoch, g_loss, d_loss))

    def save_image(self, epoch):
        """Save a 5x5 grid of generator samples as ``<file_dir>/<epoch>.png``.

        Only the first channel of each sample is drawn, in grayscale.
        """
        grid_rows = 5
        grid_cols = 5
        # One noise vector per displayed image (previously rows*cols*channels
        # samples were generated, more than the grid shows when channels > 1).
        gen_imgs = self.gen_net.forward(flag=False, x=self._noise(grid_rows * grid_cols))
        gen_imgs = tensor.to_numpy(gen_imgs)
        show_imgs = np.reshape(gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(grid_rows, grid_cols)
        for cnt in range(grid_rows * grid_cols):
            ax = axs[cnt // grid_cols, cnt % grid_cols]
            ax.imshow(show_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig(os.path.join(self.file_dir, '{}.png'.format(epoch)))
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    parser.add_argument('filepath', type=str, help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    # The engine must be chosen before any layer is constructed.
    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
        layer.engine = 'cudnn'
    else:
        print('Using CPU')
        dev = device.get_default_device()
        layer.engine = 'singacpp'

    # Sample grids are written here by LSGAN.save_image(); create it up front.
    file_dir = 'lsgan_images/'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)

    rows = 28
    cols = 28
    channels = 1
    noise_size = 100
    hidden_size = 128
    batch = 128
    interval = 1000
    learning_rate = 0.001
    epochs = 1000000
    d_steps = 3
    g_steps = 1
    # Honour the dataset path given on the command line.
    dataset_filepath = args.filepath
    lsgan = LSGAN(dev, rows=rows, cols=cols, channels=channels,
                  noise_size=noise_size, hidden_size=hidden_size, batch=batch,
                  interval=interval, learning_rate=learning_rate, epochs=epochs,
                  d_steps=d_steps, g_steps=g_steps,
                  dataset_filepath=dataset_filepath, file_dir=file_dir)
    lsgan.train()
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/utils.py
----------------------------------------------------------------------
diff --git a/examples/gan/utils.py b/examples/gan/utils.py
new file mode 100644
index 0000000..050d184
--- /dev/null
+++ b/examples/gan/utils.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import gzip
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import pickle
+import sys
+import time
+
+try:
+ import urllib.request as ul_request
+except ImportError:
+ import urllib as ul_request
+
def print_log(s):
    """Print ``s`` to stdout, prefixed by the current local time in brackets.

    Example line: ``[Fri Aug 24 10:16:37 2018]some message``.
    """
    timestamp = time.ctime()
    line = ''.join(('[', timestamp, ']', format(s)))
    print(line)
+
def load_data(filepath):
    """Load a pre-processed MNIST-style dataset from a gzipped pickle.

    The pickle must hold a ``(train_set, valid_set, test_set)`` tuple, where
    each set is an ``(images, labels)`` pair of arrays or array-likes.

    Args:
        filepath: path to the ``.pkl.gz`` file.

    Returns:
        ``(traindata, trainlabel, validdata, validlabel, testdata, testlabel)``,
        each converted to a ``np.float32`` numpy array.

    Note:
        Unpickling can execute arbitrary code; only load files from a
        trusted source.
    """
    with gzip.open(filepath, 'rb') as f:
        if sys.version_info[0] >= 3:
            # encoding='bytes' lets Python 3 read pickles written by Python 2.
            train_set, valid_set, test_set = pickle.load(f, encoding='bytes')
        else:
            # Python 2's pickle.load() has no ``encoding`` parameter and
            # would raise TypeError if one were passed.
            train_set, valid_set, test_set = pickle.load(f)

    def _to_float32(values):
        # asarray also accepts plain sequences, unlike ndarray.astype.
        return np.asarray(values, dtype=np.float32)

    traindata = _to_float32(train_set[0])
    validdata = _to_float32(valid_set[0])
    testdata = _to_float32(test_set[0])
    trainlabel = _to_float32(train_set[1])
    validlabel = _to_float32(valid_set[1])
    testlabel = _to_float32(test_set[1])
    return traindata, trainlabel, validdata, validlabel, testdata, testlabel
+
def download_data(gzfile, url):
    """Download ``url`` to the local path ``gzfile`` unless that file exists.

    The payload is first written to ``gzfile + '.part'`` and renamed to
    ``gzfile`` only after the transfer completes. An interrupted download
    therefore never leaves a truncated ``gzfile`` that a later run would
    mistake for a finished one.

    Args:
        gzfile: destination path.
        url: source URL passed to ``urlretrieve``.

    Returns:
        None. If ``gzfile`` already exists, a message is printed and the
        function returns; it no longer terminates the interpreter, so it is
        safe to call from other scripts.

    Raises:
        Whatever ``urlretrieve`` raises on failure (e.g. ``URLError``).
    """
    if os.path.exists(gzfile):
        print('Downloaded already!')
        return
    print('Downloading data %s' % (url))
    partial = gzfile + '.part'
    try:
        ul_request.urlretrieve(url, partial)
        os.rename(partial, gzfile)
    finally:
        # Only present if the download or the rename failed.
        if os.path.exists(partial):
            os.remove(partial)
    print('Finished!')
+
def show_images(filepath, rows=5, cols=5):
    """Display a ``rows`` x ``cols`` grid of images loaded from a pickle file.

    Args:
        filepath: path to an (uncompressed) pickle holding an array indexable
            as ``imgs[i, :, :, 0]``, presumably shaped ``(N, H, W, C)``.
            Only the first channel of each image is drawn, in grayscale, and
            at least ``rows * cols`` images are expected.
        rows: number of grid rows (default 5, the previously fixed value).
        cols: number of grid columns (default 5).

    Note:
        Unpickling can execute arbitrary code; only open trusted files.
    """
    with open(filepath, 'rb') as f:
        imgs = pickle.load(f)
    # squeeze=False keeps axs two-dimensional even when rows or cols is 1.
    _, axs = plt.subplots(rows, cols, squeeze=False)
    for cnt in range(rows * cols):
        ax = axs[cnt // cols, cnt % cols]
        ax.imshow(imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    plt.show()
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/vanilla.py
----------------------------------------------------------------------
diff --git a/examples/gan/vanilla.py b/examples/gan/vanilla.py
new file mode 100644
index 0000000..ce5e048
--- /dev/null
+++ b/examples/gan/vanilla.py
@@ -0,0 +1,207 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import device
+from singa import initializer
+from singa import layer
+from singa import loss
+from singa import net as ffnet
+from singa import optimizer
+from singa import tensor
+
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+from utils import load_data
+from utils import print_log
+
class VANILLA(object):
    """Original (non-saturating) GAN with fully-connected networks.

    * Generator (``gen_net``): noise of size ``noise_size`` -> Dense(hidden_size)
      -> ReLU -> Dense(rows * cols * channels) -> sigmoid.
    * Discriminator (``dis_net``): flattened image -> Dense(hidden_size) -> ReLU
      -> Dense(1) producing a logit.
    * ``combined_net`` chains the generator's layer objects followed by the
      discriminator's; training it back-propagates the discriminator's
      verdict into the generator.

    All three nets use ``loss.SigmoidCrossEntropy``. ``train()`` labels real
    images 1 and generated images 0 for the discriminator, and trains the
    generator to make the discriminator label its samples 1.

    Args:
        dev: SINGA device on which nets and tensors are placed.
        rows, cols, channels: image geometry; images are handled flattened to
            ``rows * cols * channels`` features.
        noise_size: dimension of the generator's uniform(-1, 1) input noise.
        hidden_size: width of the single hidden layer of both nets.
        batch: ``batch // 2`` samples are used per training step.
        interval: every ``interval`` epochs a sample grid is saved and the
            losses are logged.
        learning_rate: Adam learning rate for both optimizers.
        epochs: number of training iterations; each one samples a single
            random batch (it is not a full pass over the data).
        dataset_filepath: gzipped pickle read by ``utils.load_data``.
        file_dir: directory for saved sample images; expected to exist.
    """

    def __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100, hidden_size=128, batch=128,
                 interval=1000, learning_rate=0.001, epochs=1000000, dataset_filepath='mnist.pkl.gz',
                 file_dir='vanilla_images/'):
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        self.batch_size = self.batch // 2
        self.interval = interval
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir

        # Generator: noise -> hidden ReLU layer -> sigmoid output.
        self.g_w0_specs = {'init': 'xavier',}
        self.g_b0_specs = {'init': 'constant', 'value': 0,}
        self.g_w1_specs = {'init': 'xavier',}
        self.g_b1_specs = {'init': 'constant', 'value': 0,}
        self.gen_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy(),)
        self.gen_net_fc_0 = layer.Dense(
            name='g_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.g_w0_specs, b_specs=self.g_b0_specs,
            input_sample_shape=(self.noise_size,))
        self.gen_net_relu_0 = layer.Activation(
            name='g_relu_0', mode='relu', input_sample_shape=(self.hidden_size,))
        self.gen_net_fc_1 = layer.Dense(
            name='g_fc_1', num_output=self.feature_size, use_bias=True,
            W_specs=self.g_w1_specs, b_specs=self.g_b1_specs,
            input_sample_shape=(self.hidden_size,))
        self.gen_net_sigmoid_1 = layer.Activation(
            name='g_sigmoid_1', mode='sigmoid', input_sample_shape=(self.feature_size,))
        self.gen_net.add(self.gen_net_fc_0)
        self.gen_net.add(self.gen_net_relu_0)
        self.gen_net.add(self.gen_net_fc_1)
        self.gen_net.add(self.gen_net_sigmoid_1)
        self._init_params(self.gen_net)
        self.gen_net.to_device(self.dev)

        # Discriminator: flattened image -> hidden ReLU layer -> one logit
        # (the sigmoid is applied inside the loss).
        self.d_w0_specs = {'init': 'xavier',}
        self.d_b0_specs = {'init': 'constant', 'value': 0,}
        self.d_w1_specs = {'init': 'xavier',}
        self.d_b1_specs = {'init': 'constant', 'value': 0,}
        self.dis_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy(),)
        self.dis_net_fc_0 = layer.Dense(
            name='d_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.d_w0_specs, b_specs=self.d_b0_specs,
            input_sample_shape=(self.feature_size,))
        self.dis_net_relu_0 = layer.Activation(
            name='d_relu_0', mode='relu', input_sample_shape=(self.hidden_size,))
        self.dis_net_fc_1 = layer.Dense(
            name='d_fc_1', num_output=1, use_bias=True,
            W_specs=self.d_w1_specs, b_specs=self.d_b1_specs,
            input_sample_shape=(self.hidden_size,))
        self.dis_net.add(self.dis_net_fc_0)
        self.dis_net.add(self.dis_net_relu_0)
        self.dis_net.add(self.dis_net_fc_1)
        self._init_params(self.dis_net)
        self.dis_net.to_device(self.dev)

        # The combined net reuses the very same layer objects (generator first,
        # then discriminator), so its gradients reach the generator's weights.
        self.combined_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy(), )
        for l in self.gen_net.layers:
            self.combined_net.add(l)
        for l in self.dis_net.layers:
            self.combined_net.add(l)
        self.combined_net.to_device(self.dev)

    def _init_params(self, net):
        """Fill ``net``'s parameters according to their filler specs and log them.

        'gaussian' and 'xavier' fillers are honoured; any other filler type
        (e.g. the constant-0 bias specs used here) is set to 0.
        """
        for p, specs in zip(net.param_values(), net.param_specs()):
            filler = specs.filler
            if filler.type == 'gaussian':
                p.gaussian(filler.mean, filler.std)
            elif filler.type == 'xavier':
                initializer.xavier(p)
            else:
                p.set_value(0)
            print(specs.name, filler.type, p.l1())

    def _noise(self, n):
        """Return an ``(n, noise_size)`` tensor of uniform(-1, 1) noise on ``self.dev``."""
        noise = tensor.Tensor((n, self.noise_size))
        noise.uniform(-1, 1)
        noise.to_device(self.dev)
        return noise

    def _labels(self, n, value):
        """Return an ``(n, 1)`` label tensor filled with ``value`` on ``self.dev``."""
        labels = tensor.Tensor((n, 1))
        labels.set_value(value)
        labels.to_device(self.dev)
        return labels

    def _apply_grads(self, opt, net, grads, epoch):
        """Update ``net``'s parameters with ``grads`` via ``opt`` (``epoch`` doubles as the step)."""
        for s, p, g in zip(net.param_names(), net.param_values(), grads):
            opt.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)

    def train(self):
        """Run one discriminator and one generator update per epoch."""
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        opt_0 = optimizer.Adam(lr=self.learning_rate)  # optimizer for discriminator
        opt_1 = optimizer.Adam(lr=self.learning_rate)  # optimizer for generator, aka the combined model
        for (p, specs) in zip(self.dis_net.param_names(), self.dis_net.param_specs()):
            opt_0.register(p, specs)
        for (p, specs) in zip(self.gen_net.param_names(), self.gen_net.param_specs()):
            opt_1.register(p, specs)

        for epoch in range(self.epochs):
            # Sample a random (with replacement) batch of real images.
            idx = np.random.randint(0, train_data.shape[0], self.batch_size)
            real_imgs = tensor.from_numpy(train_data[idx])
            real_imgs.to_device(self.dev)
            fake_imgs = self.gen_net.forward(flag=False, x=self._noise(self.batch_size))
            real_labels = self._labels(self.batch_size, 1.0)
            fake_labels = self._labels(self.batch_size, 0.0)
            grads, (d_loss_real, _) = self.dis_net.train(real_imgs, real_labels)
            self._apply_grads(opt_0, self.dis_net, grads, epoch)
            grads, (d_loss_fake, _) = self.dis_net.train(fake_imgs, fake_labels)
            self._apply_grads(opt_0, self.dis_net, grads, epoch)
            d_loss = d_loss_real + d_loss_fake

            # Generator step: ask the discriminator to call fresh samples real.
            grads, (g_loss, _) = self.combined_net.train(
                self._noise(self.batch_size), self._labels(self.batch_size, 1.0))
            # zip() inside _apply_grads stops after the generator's parameters,
            # so only the generator is updated.
            # NOTE(review): relies on combined_net.train() returning grads in
            # parameter order (generator layers first) - confirm.
            self._apply_grads(opt_1, self.gen_net, grads, epoch)

            if epoch % self.interval == 0:
                self.save_image(epoch)
                print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(epoch, g_loss, d_loss))

    def save_image(self, epoch):
        """Save a 5x5 grid of generator samples as ``<file_dir>/<epoch>.png``.

        Only the first channel of each sample is drawn, in grayscale.
        """
        grid_rows = 5
        grid_cols = 5
        # One noise vector per displayed image (previously rows*cols*channels
        # samples were generated, more than the grid shows when channels > 1).
        gen_imgs = self.gen_net.forward(flag=False, x=self._noise(grid_rows * grid_cols))
        gen_imgs = tensor.to_numpy(gen_imgs)
        show_imgs = np.reshape(gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(grid_rows, grid_cols)
        for cnt in range(grid_rows * grid_cols):
            ax = axs[cnt // grid_cols, cnt % grid_cols]
            ax.imshow(show_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig(os.path.join(self.file_dir, '{}.png'.format(epoch)))
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    parser.add_argument('filepath', type=str, help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    # The engine must be chosen before any layer is constructed.
    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
        layer.engine = 'cudnn'
    else:
        print('Using CPU')
        dev = device.get_default_device()
        layer.engine = 'singacpp'

    # Sample grids are written here by VANILLA.save_image(); create it up front.
    file_dir = 'vanilla_images/'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)

    rows = 28
    cols = 28
    channels = 1
    noise_size = 100
    hidden_size = 128
    batch = 128
    interval = 1000
    learning_rate = 0.001
    epochs = 1000000
    # Honour the dataset path given on the command line.
    dataset_filepath = args.filepath
    vanilla = VANILLA(dev, rows=rows, cols=cols, channels=channels,
                      noise_size=noise_size, hidden_size=hidden_size, batch=batch,
                      interval=interval, learning_rate=learning_rate, epochs=epochs,
                      dataset_filepath=dataset_filepath, file_dir=file_dir)
    vanilla.train()
\ No newline at end of file
[2/2] incubator-singa git commit: Merge branch 'pr410'
Posted by wa...@apache.org.
Merge branch 'pr410'
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f8cd7e38
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f8cd7e38
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f8cd7e38
Branch: refs/heads/master
Commit: f8cd7e3846a0eb016f2d511f826c18730eeda4cc
Parents: 8aac80e b1610d7
Author: Wang Wei <wa...@gmail.com>
Authored: Mon Aug 27 22:30:56 2018 +0800
Committer: Wang Wei <wa...@gmail.com>
Committed: Mon Aug 27 22:30:56 2018 +0800
----------------------------------------------------------------------
examples/gan/download_mnist.py | 28 +++++
examples/gan/lsgan.py | 213 ++++++++++++++++++++++++++++++++++++
examples/gan/utils.py | 67 ++++++++++++
examples/gan/vanilla.py | 207 +++++++++++++++++++++++++++++++++++
4 files changed, 515 insertions(+)
----------------------------------------------------------------------