-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtest.py
38 lines (28 loc) · 915 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import models.models as models
import dataloaders.dataloaders as dataloaders
import utils.utils as utils
import config
import time
import torch
#--- read options ---#
# Parse the options in test mode (train=False).
opt = config.read_arguments(train=False)

#--- create dataloader ---#
# get_dataloaders returns a pair; only the second element (validation) is used.
_, dataloader_val = dataloaders.get_dataloaders(opt)

#--- create utils ---#
image_saver = utils.results_saver(opt)

#--- create models ---#
model = models.DP_GAN_model(opt)
model = models.put_on_multi_gpus(model, opt)
model.eval()

# Query CUDA availability once instead of on every iteration.
use_cuda = torch.cuda.is_available()
total_time = 0.0
num_batches = 0

#--- iterate over validation set ---#
# This script never calls backward, so autograd bookkeeping is disabled for
# the whole loop to avoid needlessly retaining activations.
with torch.no_grad():
    for data_i in dataloader_val:
        _, label = models.preprocess_input(opt, data_i)

        # Drain any GPU work still pending from preprocessing so that it is
        # not attributed to the generator's forward pass.
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()  # monotonic, high-resolution clock

        generated = model(None, label, "generate", None)

        # CUDA kernels launch asynchronously; wait for them to finish before
        # stopping the timer.
        if use_cuda:
            torch.cuda.synchronize()
        total_time += time.perf_counter() - start
        num_batches += 1

        # Saving results is deliberately kept outside the timed region.
        image_saver(label, generated, data_i["name"])

# Average generation time per batch. Counting batches ourselves avoids a
# ZeroDivisionError on an empty loader and works for loaders without __len__.
if num_batches:
    print("Avg time: ", total_time / num_batches)
else:
    print("Avg time: ", "n/a (validation dataloader yielded no batches)")