-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtest.py
104 lines (87 loc) · 3.02 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import os
from os import path
import copy
from tqdm import tqdm
import torch
from torch import nn
from gan_training import utils
from gan_training.checkpoints import CheckpointIO
from gan_training.distributions import get_ydist, get_zdist
from gan_training.eval import Evaluator
from gan_training.config import (
load_config, build_models
)
# Command-line interface: a positional path to the config file and a
# switch that disables CUDA even when a device is available.
parser = argparse.ArgumentParser(
    description='Test a trained GAN and create visualizations.',
)
parser.add_argument(
    'config',
    type=str,
    help='Path to config file.',
)
parser.add_argument(
    '--no-cuda',
    action='store_true',
    help='Do not use cuda.',
)
args = parser.parse_args()

# Load the configuration named on the command line.
config = load_config(args.config)

# CUDA is used only when it is available and the user has not opted out.
cuda_requested = not args.no_cuda
is_cuda = torch.cuda.is_available() and cuda_requested
# Shorthands for the configuration entries used by the rest of the script.
nlabels = config['data']['nlabels']
out_dir = config['training']['out_dir']
test_cfg = config['test']
batch_size = test_cfg['batch_size']
sample_size = test_cfg['sample_size']
sample_nrow = test_cfg['sample_nrow']

# Checkpoints are read from <out_dir>/chkpts; images are written below
# <out_dir>/test.
checkpoint_dir = path.join(out_dir, 'chkpts')
test_dir = path.join(out_dir, 'test')
img_dir = path.join(test_dir, 'img')
img_all_dir = path.join(test_dir, 'img_all')
# Create missing output directories. exist_ok=True turns the call into a
# no-op when the directory is already present, so there is no gap between
# an existence check and the creation in which another process could
# create the directory and make os.makedirs raise.
os.makedirs(img_dir, exist_ok=True)
os.makedirs(img_all_dir, exist_ok=True)
# Checkpoint handler rooted at the run's checkpoint directory.
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

# Run on the first CUDA device when CUDA is enabled, otherwise on the CPU.
if is_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build both networks from the configuration and show their structure.
generator, discriminator = build_models(config)
print(generator)
print(discriminator)

# Move each network to the selected device, then wrap it in
# nn.DataParallel so that multiple GPUs are used if possible.
generator = nn.DataParallel(generator.to(device))
discriminator = nn.DataParallel(discriminator.to(device))

# Register the wrapped networks with the checkpoint handler under the
# names 'generator' and 'discriminator'.
checkpoint_io.register_modules(
    generator=generator,
    discriminator=discriminator,
)
# Generator used for testing. By default this is the generator itself;
# when model averaging is enabled in the test configuration, it is a deep
# copy that is registered with the checkpoint handler under its own name.
generator_test = generator
if config['test']['use_model_average']:
    generator_test = copy.deepcopy(generator)
    checkpoint_io.register_modules(generator_test=generator_test)
# Label distribution and latent distribution, both created on `device`.
ydist = get_ydist(nlabels, device=device)
z_cfg = config['z_dist']
zdist = get_zdist(z_cfg['type'], z_cfg['dim'], device=device)

# Evaluator built around the test generator and the two distributions.
evaluator = Evaluator(
    generator_test, zdist, ydist,
    batch_size=batch_size, device=device,
)

# Load checkpoint if existent. The returned value is formatted with
# '%08d' further down to name the sample image file.
# NOTE(review): presumably this is the training iteration of the
# checkpoint -- confirm against CheckpointIO.load.
it = checkpoint_io.load('model.pt')
# Inception score, computed only when enabled in the test configuration.
if config['test']['compute_inception']:
    print('Computing inception score...')
    inception_mean, inception_std = evaluator.compute_inception_score()
    print('Inception score: %.4f +- %.4f' % (inception_mean, inception_std))

# Samples created without an explicit label, saved in img_all_dir under a
# file name derived from `it`.
print('Creating samples...')
ztest = zdist.sample((sample_size,))
x = evaluator.create_samples(ztest)
all_samples_file = path.join(img_all_dir, '%08d.png' % it)
utils.save_images(x, all_samples_file, nrow=sample_nrow)

# Per-label samples: one image file per label index, each created from
# the same latent samples `ztest`.
if config['test']['conditional_samples']:
    for y_inst in tqdm(range(nlabels)):
        x = evaluator.create_samples(ztest, y_inst)
        label_file = path.join(img_dir, '%04d.png' % y_inst)
        utils.save_images(x, label_file, nrow=sample_nrow)