forked from cvlab-stonybrook/PaperEdge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
110 lines (90 loc) · 3.53 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- encoding: utf-8 -*-
import os
import glob
import argparse
import copy
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from networks.paperedge import GlobalWarper, LocalWarper, WarperUtil
# Global OpenCV runtime settings, applied once at import time.
# OpenCL acceleration is switched off so OpenCV stays on the CPU path.
cv2.ocl.setUseOpenCL(False)
# A thread count of 0 disables OpenCV's internal threading, presumably to avoid
# oversubscribing cores alongside PyTorch -- TODO confirm intent.
cv2.setNumThreads(0)
def load_img(img_path, size=256):
    """Load an image file as a square, RGB, float32 tensor for the warp networks.

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        size: Side length (pixels) of the square output. Defaults to 256,
            the resolution ``unwarp_img`` assumes for its network input.

    Returns:
        A ``torch.float32`` tensor of shape ``(3, size, size)`` in RGB channel
        order, with pixel values divided by 255.

    Raises:
        OSError: if OpenCV cannot read or decode ``img_path``.
    """
    im = cv2.imread(img_path)
    if im is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail here with a message naming the file.
        raise OSError(f"could not read image: {img_path}")
    im = im.astype(np.float32) / 255.0
    # OpenCV decodes to BGR; reorder channels to RGB for the networks.
    im = im[:, :, (2, 1, 0)]
    im = cv2.resize(im, (size, size), interpolation=cv2.INTER_AREA)
    # HWC -> CHW, the layout torch conv layers expect.
    return torch.from_numpy(np.transpose(im, (2, 0, 1)))
def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU device.

    Returns:
        ``torch.device("cuda")`` if ``torch.cuda.is_available()``, else
        ``torch.device("cpu")``. (The old ``-> str`` annotation was wrong:
        this has always returned a ``torch.device``.)
    """
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def unwarp_img(
    netG: torch.nn.Module,
    netL: torch.nn.Module,
    img_path: str,
    *,
    warp_util: "WarperUtil | None" = None,
) -> "tuple[np.ndarray, np.ndarray]":
    """Unwarp one document image with the global (netG) and local (netL) warpers.

    The deformation fields are predicted on the 256x256 RGB thumbnail from
    ``load_img``. They are then upsampled and applied to the full-resolution
    image, which is read again with ``cv2.imread`` and kept in OpenCV's BGR
    order so the results can go straight to ``cv2.imwrite``.

    Args:
        netG: Global warper; its output is post-processed by
            ``warp_util.global_post_warp(..., 64)``.
        netL: Local warper, applied to the globally unwarped thumbnail.
        img_path: Path of the image to process.
        warp_util: ``WarperUtil`` instance used for ``global_post_warp``.
            If omitted, the module-level ``warpUtil`` (created by the
            ``__main__`` section) is used when present. Otherwise a fresh
            ``WarperUtil(64)`` is built, matching what ``__main__`` constructs.

    Returns:
        ``(tmp_y, ls_y)``: two float32 numpy arrays of shape (H, W, 3) in BGR
        order at the input image's resolution. ``tmp_y`` has only the global
        field applied; ``ls_y`` additionally has the local field applied.

    Raises:
        OSError: if the image cannot be read (from ``load_img`` or the
            full-resolution read below).
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    param = next(netG.parameters(), None)
    device = param.device if param is not None else get_device()

    if warp_util is None:
        # Backward compatibility: callers used to rely on the global
        # ``warpUtil`` defined only when this file runs as a script.
        warp_util = globals().get('warpUtil')
        if warp_util is None:
            warp_util = WarperUtil(64).to(device)

    with torch.no_grad():
        # --- Predict the fields on the low-resolution thumbnail. ---
        x = load_img(img_path).unsqueeze(0).to(device)
        d = netG(x)  # d_E, the edge-based deformation field
        # F.interpolate returns a new tensor, so gs_d needs no defensive copy.
        gs_d = warp_util.global_post_warp(d, 64)

        d = F.interpolate(gs_d, size=256, mode='bilinear', align_corners=True)
        y0 = F.grid_sample(x, d.permute(0, 2, 3, 1), align_corners=True)
        ls_d = netL(y0)
        ls_d = F.interpolate(ls_d, size=256, mode='bilinear', align_corners=True)
        ls_d = ls_d.clamp(-1.0, 1.0)

        # --- Apply both fields to the full-resolution image. ---
        raw = cv2.imread(img_path)
        if raw is None:
            raise OSError(f"could not read image: {img_path}")
        # No BGR->RGB swap here: channel order does not affect the geometry,
        # and BGR output is what cv2.imwrite expects.
        im = torch.from_numpy(np.transpose(raw.astype(np.float32) / 255.0, (2, 0, 1)))
        im = im.unsqueeze(0).to(device)
        full_size = (im.size(2), im.size(3))

        gs_d = F.interpolate(gs_d, full_size, mode='bilinear', align_corners=True)
        gs_y = F.grid_sample(im, gs_d.permute(0, 2, 3, 1), align_corners=True)

        ls_d = F.interpolate(ls_d, full_size, mode='bilinear', align_corners=True)
        ls_y = F.grid_sample(gs_y, ls_d.permute(0, 2, 3, 1), align_corners=True)

    # Index the batch axis explicitly: .squeeze() would also drop a spatial
    # axis for 1-pixel-high or 1-pixel-wide images.
    tmp_y = gs_y[0].permute(1, 2, 0).cpu().numpy()
    ls_y = ls_y[0].permute(1, 2, 0).cpu().numpy()
    return tmp_y, ls_y
if __name__ == '__main__':
    # File extensions cv2.imread can decode (per the OpenCV imread docs).
    # Other directory entries are skipped instead of crashing the run.
    IMAGE_EXTS = {
        '.bmp', '.dib', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp',
        '.avif', '.pbm', '.pgm', '.ppm', '.pxm', '.pnm', '.pfm', '.sr',
        '.ras', '.tiff', '.tif', '.exr', '.hdr', '.pic',
    }

    parser = argparse.ArgumentParser(
        description='Unwarp document images with PaperEdge.')
    parser.add_argument('--Enet_ckpt', type=str,
                        default='models/G_w_checkpoint_13820.pt',
                        help="checkpoint whose 'G' entry holds the GlobalWarper weights")
    parser.add_argument('--Tnet_ckpt', type=str,
                        default='models/L_w_checkpoint_27640.pt',
                        help="checkpoint whose 'L' entry holds the LocalWarper weights")
    parser.add_argument('--img_source', type=str, default='images/3.jpg',
                        help='an image file, or a directory of images')
    parser.add_argument('--out_dir', type=str, default='output',
                        help='directory that receives the netG_* / netL_* results')
    args = parser.parse_args()

    img_source = args.img_source
    dst_dir = args.out_dir
    if not os.path.exists(img_source):
        parser.error(f'--img_source does not exist: {img_source}')
    Path(dst_dir).mkdir(parents=True, exist_ok=True)

    # Use the GPU when present; fall back to the CPU instead of crashing.
    device = get_device()

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # map_location lets GPU-saved checkpoints load on CPU-only machines.
    netG = GlobalWarper().to(device)
    netG.load_state_dict(torch.load(args.Enet_ckpt, map_location=device)['G'])
    netG.eval()
    netL = LocalWarper().to(device)
    netL.load_state_dict(torch.load(args.Tnet_ckpt, map_location=device)['L'])
    netL.eval()
    # Keep this a module-level name: unwarp_img() looks up the global ``warpUtil``.
    warpUtil = WarperUtil(64).to(device)

    if Path(img_source).is_dir():
        # glob's '*' skips dotfiles; also drop subdirectories and non-images,
        # and sort for a deterministic processing order.
        img_paths = sorted(
            p for p in glob.glob(os.path.join(img_source, '*'))
            if os.path.isfile(p) and os.path.splitext(p)[1].lower() in IMAGE_EXTS
        )
    else:
        img_paths = [img_source]

    for img_path in img_paths:
        tmp_y, ls_y = unwarp_img(netG, netL, img_path)
        name = os.path.basename(img_path)
        # tmp_y: global warp only; ls_y: global + local warp.
        for prefix, result in (('netG', tmp_y), ('netL', ls_y)):
            out_path = os.path.join(dst_dir, f"{prefix}_{name}")
            # cv2.imwrite reports failure (e.g. an unwritable path) via its
            # return value rather than by raising.
            if not cv2.imwrite(out_path, result * 255.):
                print(f'warning: failed to write {out_path}')