-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy path eval.py
executable file
·47 lines (31 loc) · 1.1 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from pathlib import Path
from data import read_image_tensor, write_image_tensor, ImageDataset
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
from train import data_path, model_save_path
# Input/output folders live under data_path (imported from train.py);
# you can overwrite data_path here.
output_dir = data_path/'output'
input_dir = data_path/'input'
# Change these depending on your hardware; they have to match the training settings.
# NOTE(review): float16 on device = 'cpu' may hit unsupported ops — confirm before switching.
device = 'cuda'
dtype = torch.float16

# The checkpoint is a whole pickled nn.Module (it is used directly as a model
# below), so loading it runs arbitrary pickle code: only load files you trust.
# map_location lets a checkpoint saved on one device be loaded onto `device`
# (e.g. device = 'cpu' on a machine without CUDA).
_generator_path = model_save_path/"generator.pt"
try:
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full module, so opt out explicitly.
    generator = torch.load(_generator_path, map_location=device, weights_only=False)
except TypeError:
    # PyTorch < 1.13 has no weights_only parameter.
    generator = torch.load(_generator_path, map_location=device)
generator.eval()
generator.to(device, dtype)
# Every regular file in input_dir, in a deterministic (sorted) order.
# Subdirectories and other non-file entries are skipped so the dataset never
# tries to read them as images.
file_paths = sorted(path for path in input_dir.iterdir() if path.is_file())
# TODO: support batch_size > 1 (presumably needs equally-sized images for the
# default collation — confirm against ImageDataset).
params = {'batch_size': 1,
          'num_workers': 8,
          # Pinned host memory only helps host-to-GPU copies.
          'pin_memory': torch.device(device).type == 'cuda'}
dataset = ImageDataset(file_paths)
loader = DataLoader(dataset, **params)
# Run the generator over every input batch and write one output image per input,
# reusing the input's file name inside output_dir.
# TODO multiprocess and asynchronous writing of files
# Make sure the destination directory exists before anything is written into it.
output_dir.mkdir(parents=True, exist_ok=True)
with torch.no_grad():
    for inputs, names in tqdm(loader):
        # non_blocking lets the copy from pinned host memory overlap with work
        # already queued on the device; it is a no-op for unpinned tensors.
        inputs = inputs.to(device, dtype, non_blocking=True)
        outputs = generator(inputs)
        # Drop the reference so the input batch can be freed before writing.
        del inputs
        for output, name in zip(outputs, names):
            write_image_tensor(output, output_dir/name)
        del outputs