# DiscoGAN.py -- DiscoGAN model: graph construction, training loop,
# checkpoint save/load and test-time image translation.
from ops import *
from utils import *
from glob import glob
import time
from tensorflow.contrib.data import batch_and_drop_remainder
class DiscoGAN(object) :
    """DiscoGAN image-to-image translation between domain A and domain B.

    Two generators and two discriminators are built in TensorFlow graph mode
    (session + placeholders).  Scope ``generator_B`` maps A -> B and scope
    ``generator_A`` maps B -> A; ``discriminator_A`` / ``discriminator_B``
    score images of the respective domain.
    """

    def __init__(self, sess, args):
        """Copy hyper-parameters from ``args`` and locate the training images.

        Args:
            sess: TensorFlow session used by train() / test().
            args: parsed options; every attribute read below (checkpoint_dir,
                result_dir, log_dir, sample_dir, dataset, augment_flag, epoch,
                iteration, gan_type, batch_size, print_freq, save_freq,
                img_size, img_ch, lr, ch, gan_w, cycle_w, n_dis) must exist.
        """
        self.model_name = 'DiscoGAN'
        self.sess = sess
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.sample_dir = args.sample_dir
        self.dataset_name = args.dataset
        self.augment_flag = args.augment_flag

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.gan_type = args.gan_type

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.img_size = args.img_size
        self.img_ch = args.img_ch
        self.init_lr = args.lr
        self.ch = args.ch

        # Weight
        self.gan_w = args.gan_w
        self.cycle_w = args.cycle_w

        # Generator (no generator-specific options)

        # Discriminator
        self.n_dis = args.n_dis

        # Samples are written to a per-model sub-directory (see model_dir).
        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
        self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
        # Size of the larger domain; used as the shuffle buffer in build_model.
        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)

        print()

        print("##### Generator #####")

        print()

        print("##### Discriminator #####")
        print("# Discriminator layer : ", self.n_dis)

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, x, is_training=True, reuse=False, scope="generator"):
        """Encoder-decoder generator.

        Four stride-2 convolutions (the channel count doubles at each of the
        three batch-normalised down-sampling steps) are followed by three
        stride-2 transposed convolutions that halve it again, a final stride-2
        transposed convolution to ``self.img_ch`` channels, and a tanh.

        Args:
            x: input image batch -- presumably NHWC (the test placeholder is
                [1, img_size, img_size, img_ch]); TODO confirm against ops.conv.
            is_training: forwarded to batch_norm.
            reuse: reuse the variables already created under ``scope``.
            scope: variable scope owning this generator's weights.

        Returns:
            The tanh-activated translated image batch.
        """
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse) :
            x = conv(x, channel, kernel=4, stride=2, pad=1, use_bias=False, scope='conv_0')
            x = lrelu(x, 0.2)

            # Down-Sampling
            for i in range(3) :
                x = conv(x, channel*2, kernel=4, stride=2, pad=1, use_bias=False, scope='conv_'+str(i+1))
                x = batch_norm(x, is_training, scope='down_bn_'+str(i+1))
                x = lrelu(x, 0.2)

                channel = channel * 2

            # Up-Sampling
            for i in range(3) :
                x = deconv(x, channel//2, kernel=4, stride=2, use_bias=False, scope='deconv_'+str(i+1))
                x = batch_norm(x, is_training, scope='up_bn_'+str(i+1))
                x = relu(x)

                channel = channel // 2

            # Emit as many channels as the input images (ImageData and the test
            # placeholder are both built with self.img_ch).  The discriminators
            # share weights between real and fake inputs, so a hard-coded 3 here
            # broke every non-RGB configuration.
            x = deconv(x, channels=self.img_ch, kernel=4, stride=2, use_bias=False, scope='G_logit')
            x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x, is_training=True, reuse=False, scope="discriminator"):
        """Convolutional discriminator returning raw (un-activated) logits.

        ``self.n_dis`` stride-2 convolutions (channel count doubling after the
        first, batch norm on all but the first) are followed by a stride-1,
        1-channel convolution that produces the logits.

        Args:
            x: image batch to score.
            is_training: forwarded to batch_norm.
            reuse: reuse the variables already created under ``scope``.
            scope: variable scope owning this discriminator's weights.
        """
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse) :
            x = conv(x, channel, kernel=4, stride=2, pad=1, use_bias=False, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, self.n_dis) :
                x = conv(x, channel*2, kernel=4, stride=2, pad=1, use_bias=False, scope='conv_'+str(i))
                x = batch_norm(x, is_training, scope='bn_'+str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channels=1, kernel=4, stride=1, use_bias=False, scope='D_logit')

            return x

    ##################################################################################
    # Model
    ##################################################################################

    def generate_a2b(self, x_A, is_training=True, reuse=False):
        """Translate domain-A images to domain B (scope ``generator_B``)."""
        x_ab = self.generator(x_A, is_training, reuse=reuse, scope='generator_B')

        return x_ab

    def generate_b2a(self, x_B, is_training=True, reuse=False):
        """Translate domain-B images to domain A (scope ``generator_A``)."""
        x_ba = self.generator(x_B, is_training, reuse=reuse, scope='generator_A')

        return x_ba

    def discriminate_real(self, x_A, x_B, is_training=True):
        """Score real images; this first call creates both discriminators' variables."""
        real_A_logit = self.discriminator(x_A, is_training, scope="discriminator_A")
        real_B_logit = self.discriminator(x_B, is_training, scope="discriminator_B")

        return real_A_logit, real_B_logit

    def discriminate_fake(self, x_ba, x_ab, is_training=True):
        """Score translated images, reusing the discriminators built by discriminate_real."""
        fake_A_logit = self.discriminator(x_ba, is_training, reuse=True, scope="discriminator_A")
        fake_B_logit = self.discriminator(x_ab, is_training, reuse=True, scope="discriminator_B")

        return fake_A_logit, fake_B_logit

    def build_model(self):
        """Build the training graph (input pipeline, losses, optimizers, summaries) and the test graph."""
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # Input Image
        Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

        # NOTE: op creation order is kept as-is on purpose -- with a graph-level
        # seed, unseeded shuffle ops derive their seed from creation order.
        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        trainA = trainA.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
        trainB = trainB.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()

        # Define Encoder, Generator, Discriminator
        x_ab = self.generate_a2b(self.domain_A)
        x_ba = self.generate_b2a(self.domain_B)

        # Cycle: translate back with the other generator (shared weights).
        x_aba = self.generate_b2a(x_ab, reuse=True)
        x_bab = self.generate_a2b(x_ba, reuse=True)

        real_A_logit, real_B_logit = self.discriminate_real(self.domain_A, self.domain_B)
        fake_A_logit, fake_B_logit = self.discriminate_fake(x_ba, x_ab)

        # Define Loss
        G_ad_loss_a = generator_loss(self.gan_type, fake_A_logit)
        G_ad_loss_b = generator_loss(self.gan_type, fake_B_logit)

        D_ad_loss_a = discriminator_loss(self.gan_type, real_A_logit, fake_A_logit)
        D_ad_loss_b = discriminator_loss(self.gan_type, real_B_logit, fake_B_logit)

        recon_loss_a = L1_loss(x_aba, self.domain_A) # reconstruction
        recon_loss_b = L1_loss(x_bab, self.domain_B) # reconstruction

        # Each generator's loss = its adversarial term + the cycle loss of the
        # domain it translates FROM.  Both are summed into Generator_loss and
        # minimised over all generator variables, so the split only affects the
        # per-generator summaries.
        Generator_A_loss = self.gan_w * G_ad_loss_a + \
                           self.cycle_w * recon_loss_b

        Generator_B_loss = self.gan_w * G_ad_loss_b + \
                           self.cycle_w * recon_loss_a

        Discriminator_A_loss = self.gan_w * D_ad_loss_a
        Discriminator_B_loss = self.gan_w * D_ad_loss_b

        self.Generator_loss = Generator_A_loss + Generator_B_loss
        self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss

        # Training: split variables by scope-name substring.
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)

        # Summary
        self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
        self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
        self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
        self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
        self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

        self.G_loss = tf.summary.merge([self.G_A_loss, self.G_B_loss, self.all_G_loss])
        self.D_loss = tf.summary.merge([self.D_A_loss, self.D_B_loss, self.all_D_loss])

        # Image
        self.fake_A = x_ba
        self.fake_B = x_ab

        self.real_A = self.domain_A
        self.real_B = self.domain_B

        # Test: single-image placeholder fed through both generators in inference mode.
        self.test_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_image')

        self.test_fake_A = self.generate_b2a(self.test_image, is_training=False, reuse=True)
        self.test_fake_B = self.generate_a2b(self.test_image, is_training=False, reuse=True)

    def train(self):
        """Alternate one discriminator and one generator Adam step per iteration.

        Resumes from the latest checkpoint under ``checkpoint_dir/model_dir``
        when one exists, logs loss summaries to ``log_dir/model_dir``, writes
        real_A / fake_B samples every ``print_freq`` iterations, checkpoints
        every ``save_freq`` iterations and once more after the last epoch.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = checkpoint_counter // self.iteration
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            # NOTE(review): counter starts at 1 and is incremented after every
            # step, so a checkpoint saved after k steps stores k + 1 and a resume
            # skips one batch index.  Kept for compatibility with existing
            # checkpoints.
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        lr = self.init_lr
        # No learning-rate schedule: the feed is the same for every step.
        train_feed_dict = {
            self.lr : lr
        }
        for epoch in range(start_epoch, self.epoch):
            for idx in range(start_batch_id, self.iteration):
                # Update D (this run pulls fresh batches from both input iterators)
                _, d_loss, summary_str = self.sess.run(
                    [self.D_optim, self.Discriminator_loss, self.D_loss],
                    feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update G; real_A and fake_B come from the same run, so the
                # saved samples belong to this generator step.
                batch_A_images, fake_B, _, g_loss, summary_str = self.sess.run(
                    [self.real_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss],
                    feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%6d/%6d] time: %4.4f d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                if np.mod(idx+1, self.print_freq) == 0 :
                    save_images(batch_A_images, [self.batch_size, 1],
                                './{}/real_A_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1))
                    save_images(fake_B, [self.batch_size, 1],
                                './{}/fake_B_{:02d}_{:06d}.jpg'.format(self.sample_dir, epoch, idx+1))

                if np.mod(idx+1, self.save_freq) == 0 :
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Per-run sub-directory name: ``<model_name>_<dataset>_<gan_type>``."""
        return "{}_{}_{}".format(self.model_name, self.dataset_name, self.gan_type)

    def save(self, checkpoint_dir, step):
        """Save all variables under ``checkpoint_dir/model_dir`` tagged with ``step``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    @staticmethod
    def _checkpoint_counter(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int (the saved step)."""
        import re
        # Raw string: "\d" in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on Python 3.12+).
        return int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint under ``checkpoint_dir/model_dir``.

        Requires ``self.saver`` to exist (train() / test() create it first).

        Returns:
            ``(True, step)`` on success, where ``step`` is parsed from the
            checkpoint file name; ``(False, 0)`` when no checkpoint is found.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print(" [*] Failed to find a checkpoint")
            return False, 0

        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        counter = self._checkpoint_counter(ckpt_name)
        print(" [*] Success to read {}".format(ckpt_name))
        return True, counter

    @staticmethod
    def _html_src(path, base_dir):
        """Return ``path`` as an HTML-escaped ``src`` value usable from a page in ``base_dir``."""
        import html
        if not os.path.isabs(path):
            # Relative to the directory holding index.html, so the page works
            # wherever result_dir lives (a fixed '../..' prefix only worked when
            # it sat exactly two levels below the working directory).
            path = os.path.relpath(path, base_dir)
        return html.escape(path, quote=True)

    def _translate_files(self, files, fake_op, domain, index):
        """Translate each file with ``fake_op``, save the result and append an HTML row."""
        import html
        for sample_file in files :
            print('Processing ' + domain + ' image: ' + sample_file)
            sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
            image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

            fake_img = self.sess.run(fake_op, feed_dict={self.test_image: sample_image})
            save_images(fake_img, [1, 1], image_path)

            index.write("<tr>")
            index.write("<td>%s</td>" % html.escape(os.path.basename(image_path)))
            index.write("<td><img src='%s' width='%d' height='%d'></td>" % (
                self._html_src(sample_file, self.result_dir), self.img_size, self.img_size))
            index.write("<td><img src='%s' width='%d' height='%d'></td>" % (
                self._html_src(image_path, self.result_dir), self.img_size, self.img_size))
            index.write("</tr>")

    def test(self):
        """Translate every test image and write the results plus an HTML index.

        testA images go through the A -> B generator and testB images through
        B -> A.  Outputs are saved in ``<result_dir>/<model_dir>`` under the
        input's basename; ``index.html`` there shows input/output pairs.
        """
        tf.global_variables_initializer().run()
        test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
        test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load :
            print(" [*] Load SUCCESS")
        else :
            print(" [!] Load failed...")

        # write html for visual comparison
        index_path = os.path.join(self.result_dir, 'index.html')
        # NOTE(review): A->B and B->A outputs share one directory keyed by the
        # input basename, so a testB file named like a testA file overwrites
        # the earlier result.
        with open(index_path, 'w') as index:
            index.write("<html><body><table><tr>")
            index.write("<th>name</th><th>input</th><th>output</th></tr>")

            self._translate_files(test_A_files, self.test_fake_B, 'A', index)  # A -> B
            self._translate_files(test_B_files, self.test_fake_A, 'B', index)  # B -> A

            index.write("</table></body></html>")