forked from hzwer/ECCV2022-RIFE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInterpolatorInterface.py
149 lines (136 loc) · 6.12 KB
/
InterpolatorInterface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import sys
import cv2
import torch
import argparse
from torch.nn import functional as F
import warnings
import numpy as np
# Directory that contains this module; the default weights folder and the
# model packages imported below are looked up relative to it.
INTERPOLATOR_ROOT = os.path.dirname(__file__)
# The same directory, anchored to the current working directory and resolved
# to a canonical path.
__location__ = os.path.realpath(os.path.join(os.getcwd(), INTERPOLATOR_ROOT))
# Make sibling packages importable, without adding a duplicate entry.
if INTERPOLATOR_ROOT not in sys.path:
    sys.path.append(INTERPOLATOR_ROOT)
class InterpolatorInterface:
    """Thin wrapper that loads a RIFE model and interpolates between two images.

    Construction picks the first model variant that can be imported and whose
    weights load from ``model_pth``; ``generate`` then produces in-between
    frames either as files on disk or as in-memory arrays.
    """

    # (module path, success message) pairs, tried in this order. Each module
    # is expected to expose a ``Model`` class with ``load_model``, ``eval``,
    # ``device`` and ``inference``.
    _MODEL_CANDIDATES = (
        ("model.RIFE_HDv2", "Loaded v2.x HD model."),
        ("train_log.RIFE_HDv3", "Loaded v3.x HD model."),
        ("model.RIFE_HD", "Loaded v1.x HD model"),
        ("model.RIFE", "Loaded ArXiv-RIFE model"),
    )

    def __init__(self, model_pth=os.path.join(INTERPOLATOR_ROOT, "train_log")):
        """Select a device and load the first usable model variant.

        Args:
            model_pth: Directory passed to ``Model.load_model``. Defaults to
                the ``train_log`` folder next to this module.

        Raises:
            Exception: Whatever the last candidate raised, when every
                candidate fails to import or to load its weights.
        """
        import importlib

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Inference only: autograd is switched off. Note that this call is not
        # scoped to this object, so it also affects code outside this class.
        torch.set_grad_enabled(False)
        if torch.cuda.is_available():
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True

        last = len(self._MODEL_CANDIDATES) - 1
        for position, (module_name, banner) in enumerate(self._MODEL_CANDIDATES):
            try:
                model = importlib.import_module(module_name).Model()
                model.load_model(model_pth, -1)
            except Exception:
                # Fall through to the next variant; only the final failure is
                # surfaced. ``Exception`` (not a bare ``except``) so that
                # KeyboardInterrupt / SystemExit are never swallowed.
                if position == last:
                    raise
                continue
            self.model = model
            print(banner)
            break

        self.model.eval()
        # NOTE(review): presumably moves the network onto the selected device;
        # confirm against the Model implementation.
        self.model.device()

    def generate(
        self,
        imgs: tuple[str, str] | None = None,
        exp: int = 4,
        ratio: float = 0,
        rthreshold: float = 0.02,
        rmaxcycles: int = 8,
        outputdir: str | None = None,
    ) -> list[str] | list[np.ndarray]:
        """Interpolate between two image files.

        Two modes are supported:

        * ``ratio`` is zero (default): the frame list is doubled ``exp`` times
          by inserting a midpoint between every adjacent pair, giving
          ``2 ** exp + 1`` frames including both inputs.
        * ``ratio`` is non-zero: a single frame approximating position
          ``ratio`` (0 = first image, 1 = second image) is found by bisection
          and returned between the two inputs (three frames total); ``exp``
          is ignored.

        Args:
            imgs: Paths of the first and second image. The EXR code path is
                used only when both paths end in ``.exr``.
            exp: Number of doubling passes in the default mode.
            ratio: Target position between the two inputs for ratio mode.
            rthreshold: Width of the acceptance window around ``ratio``; half
                of it is allowed on each side.
            rmaxcycles: Maximum number of bisection steps in ratio mode.
            outputdir: When given, frames are written there as
                ``img<N>.png`` / ``img<N>.exr``, with ``N`` continuing after
                the highest index already present. Created if missing.

        Returns:
            The list of written file paths when ``outputdir`` is given,
            otherwise a list of height x width x channel arrays cropped to
            the input size (8-bit for non-EXR input).

        Raises:
            TypeError: If ``imgs`` is ``None``.
            OSError: If either image cannot be read.
            ValueError: If ratio mode needs bisection but ``rmaxcycles < 1``.
        """
        if imgs is None:
            raise TypeError("imgs must be a pair of image paths, got None")

        first_path, second_path = imgs[0], imgs[1]
        is_exr = first_path.endswith('.exr') and second_path.endswith('.exr')

        img0 = self._load_frame(first_path, is_exr)
        img1 = self._load_frame(second_path, is_exr)

        # Pad bottom/right so both spatial dimensions become multiples of 32;
        # the original size is kept to crop the results back afterwards.
        # NOTE(review): presumably a requirement of the network - confirm.
        _, _, h, w = img0.shape
        ph = ((h - 1) // 32 + 1) * 32
        pw = ((w - 1) // 32 + 1) * 32
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)

        if ratio:
            middle = self._interpolate_at_ratio(img0, img1, ratio, rthreshold, rmaxcycles)
            img_list = [img0, middle, img1]
        else:
            img_list = self._interpolate_recursive(img0, img1, exp)

        if outputdir is None:
            return [self._to_image(frame, h, w, is_exr) for frame in img_list]

        # makedirs handles nested paths and an already-existing directory.
        os.makedirs(outputdir, exist_ok=True)
        next_index = self._next_index(outputdir, 'img')
        output_list = []
        for offset, frame in enumerate(img_list):
            image = self._to_image(frame, h, w, is_exr)
            if is_exr:
                new_img_pth = os.path.join(outputdir, 'img{}.exr'.format(next_index + offset))
                cv2.imwrite(new_img_pth, image, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
            else:
                new_img_pth = os.path.join(outputdir, 'img{}.png'.format(next_index + offset))
                cv2.imwrite(new_img_pth, image)
            output_list.append(new_img_pth)
        return output_list

    def _load_frame(self, path, is_exr):
        """Read one image file into a 1 x C x H x W tensor on ``self.device``."""
        if is_exr:
            flags = cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH
        else:
            flags = cv2.IMREAD_UNCHANGED
        image = cv2.imread(path, flags)
        if image is None:
            # cv2.imread signals a missing/undecodable file by returning None.
            raise OSError("Could not read image: {}".format(path))
        tensor = torch.tensor(image.transpose(2, 0, 1)).to(self.device)
        if not is_exr:
            # Non-EXR pixel values are scaled by 1/255; EXR data is used as read.
            tensor = tensor / 255.
        return tensor.unsqueeze(0)

    def _interpolate_at_ratio(self, img0, img1, ratio, rthreshold, rmaxcycles):
        """Bisect between the two frames until the midpoint is close to ``ratio``."""
        half_window = rthreshold / 2
        img0_ratio = 0.0
        img1_ratio = 1.0
        # Targets within the window of either endpoint need no inference.
        if ratio <= img0_ratio + half_window:
            return img0
        if ratio >= img1_ratio - half_window:
            return img1
        if rmaxcycles < 1:
            raise ValueError("rmaxcycles must be at least 1 to interpolate at a ratio")

        tmp_img0 = img0
        tmp_img1 = img1
        for _ in range(rmaxcycles):
            middle = self.model.inference(tmp_img0, tmp_img1)
            middle_ratio = (img0_ratio + img1_ratio) / 2
            if ratio - half_window <= middle_ratio <= ratio + half_window:
                break
            # Keep the half of the interval that still contains the target.
            if ratio > middle_ratio:
                tmp_img0 = middle
                img0_ratio = middle_ratio
            else:
                tmp_img1 = middle
                img1_ratio = middle_ratio
        return middle

    def _interpolate_recursive(self, img0, img1, exp):
        """Insert a midpoint between every adjacent pair, repeated ``exp`` times."""
        img_list = [img0, img1]
        for _ in range(exp):
            expanded = []
            for left, right in zip(img_list[:-1], img_list[1:]):
                expanded.append(left)
                expanded.append(self.model.inference(left, right))
            expanded.append(img1)
            img_list = expanded
        return img_list

    @staticmethod
    def _to_image(frame, height, width, is_exr):
        """Convert a 1 x C x H x W tensor to an H x W x C array cropped to the input size."""
        image = frame[0] if is_exr else (frame[0] * 255).byte()
        return image.cpu().numpy().transpose(1, 2, 0)[:height, :width]

    @staticmethod
    def _next_index(directory, prefix):
        """Return one more than the highest ``<prefix><digits>`` index in ``directory``."""
        indices = []
        for entry in os.listdir(directory):
            if not entry.startswith(prefix):
                continue
            name, _ = os.path.splitext(entry)
            suffix = name[len(prefix):]
            if suffix.isdigit():
                indices.append(int(suffix))
        return max(indices, default=-1) + 1