-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
157 lines (124 loc) · 5.6 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange, repeat, reduce
# B (b) batch dimension
# R (r) ratio dimension
# C (c) color dimension
# H (h) height dimension
# W (w) width dimension
def create_feature_extractor(net_type, name_layer):
    """Build a truncated CNN backbone ending at layer ``name_layer``.

    Args:
        net_type: backbone identifier; only ``'vgg16'`` is supported.
        name_layer: name (as reported by ``named_children()`` of the VGG16
            ``features`` module, i.e. a stringified index like ``'22'``) of
            the last layer to keep.

    Returns:
        Tuple ``(net, num_channel_feature, num_cascade)``:
        - ``net``: ``nn.Sequential`` of the kept layers,
        - ``num_channel_feature``: ``out_channels`` of the last Conv2d kept,
        - ``num_cascade``: number of MaxPool2d layers kept, i.e. how many
          2x downsamplings the extracted features have undergone.

    Raises:
        ValueError: if ``net_type`` is unsupported (previously this silently
            returned an empty network with 0 channels).
    """
    if net_type != 'vgg16':
        raise ValueError(f"unsupported net_type: {net_type!r}")
    net_list = []
    num_cascade = 0
    num_channel_feature = 0
    base_model = models.vgg16().features
    for name, layer in base_model.named_children():
        net_list.append(layer)
        if isinstance(layer, nn.MaxPool2d):
            num_cascade += 1
        if isinstance(layer, nn.Conv2d):
            num_channel_feature = layer.out_channels
        # NOTE(review): if name_layer never matches, the whole backbone is
        # kept — presumably intentional, but worth confirming.
        if name == name_layer:
            break
    net = nn.Sequential(*net_list)
    return net, num_channel_feature, num_cascade
class BiForwardHead(nn.Module):
    """Twin linear heads mapping a ratio embedding to (ARS_FTM, ARS_PWP).

    Given an embedding of shape (B, R, num_embedding_dim), produces:
    - ARS_FTM: (B, R, C, C) feature-transform matrix, C = num_channel_feature,
    - ARS_PWP: (B, R, num_channel_out, 1, 1) point-wise prediction weights.
    """

    def __init__(self, num_embedding_dim, num_channel_out, num_channel_feature):
        super().__init__()
        self.num_channel_feature = num_channel_feature
        # Flattened C x C transform matrix per (batch, ratio).
        self.ARS_FTM_head = nn.Linear(
            num_embedding_dim, num_channel_feature**2)
        # One scalar weight per output channel per (batch, ratio).
        self.ARS_PWP_head = nn.Linear(num_embedding_dim, num_channel_out)

    def forward(self, x):
        c = self.num_channel_feature
        ftm = self.ARS_FTM_head(x)
        # Unflatten the last axis into a square C x C matrix.
        ftm = rearrange(ftm, 'b r (c1 c2) -> b r c1 c2', c1=c)
        pwp = self.ARS_PWP_head(x)
        # Append two singleton spatial axes for later broadcasting.
        pwp = rearrange(pwp, 'b r cout -> b r cout () ()')
        return ftm, pwp
class MetaLearner(nn.Module):
    """MLP that maps a ratio embedding to (ARS_FTM, ARS_PWP).

    ``num_layers`` hidden blocks of Linear -> ReLU -> Dropout, followed by a
    BiForwardHead producing the two ARS tensors.

    Fixes relative to the original:
    - ``nn.ReLu`` does not exist (AttributeError at construction); it must be
      ``nn.ReLU``.
    - ``[mods] * num_layers`` repeated the SAME module objects, silently
      tying the weights of all hidden layers together; layers are now built
      independently.
    """

    def __init__(self, num_embedding_dim, num_layers, num_channel_out, num_channel_feature, dropout_rate=0.5):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers += [
                nn.Linear(num_embedding_dim, num_embedding_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
        layers.append(BiForwardHead(
            num_embedding_dim, num_channel_out, num_channel_feature))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, R, num_embedding_dim) ratio embeddings.
        return self.net(x)
class DeconvBlock(nn.Module):
    """2x nearest-neighbour upsample followed by Conv3x3 -> BatchNorm -> ReLU.

    The 3x3 convolution uses padding 1, so the spatial size after the block
    is exactly double the input's.

    Fix relative to the original: ``nn.ReLu`` does not exist
    (AttributeError at construction); it must be ``nn.ReLU``.
    """

    def __init__(self, num_channel_in, num_channel_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(num_channel_in, num_channel_out, (3, 3), padding=(1, 1)),
            nn.BatchNorm2d(num_channel_out),
            nn.ReLU())

    def forward(self, x):
        # Nearest-neighbour 2x upsampling; torch-native equivalent of
        # einops repeat(x, 'br c h w -> br c (h f1) (w f2)', f1=2, f2=2).
        x = x.repeat_interleave(2, dim=-2).repeat_interleave(2, dim=-1)
        return self.net(x)
class Deconv(nn.Module):
    """Cascade of DeconvBlocks restoring spatial resolution.

    Each DeconvBlock doubles H and W, so ``num_cascade`` blocks undo the
    ``num_cascade`` MaxPool2d downsamplings counted by the feature extractor.
    The first block maps ``num_channel_feature -> num_channel_out``; the
    remaining ``num_cascade - 1`` blocks keep ``num_channel_out`` channels.

    Fix relative to the original: ``[DeconvBlock(...)] * (num_cascade - 1)``
    repeated the SAME module instance, tying the conv weights and BatchNorm
    statistics of all cascade stages; each stage is now an independent block.
    """

    def __init__(self, num_cascade, num_channel_out, num_channel_feature):
        super().__init__()
        self.net = nn.Sequential(
            DeconvBlock(num_channel_feature, num_channel_out),
            *[DeconvBlock(num_channel_out, num_channel_out)
              for _ in range(num_cascade - 1)])

    def forward(self, x):
        # x: (B*R, num_channel_feature, h, w)
        # -> (B*R, num_channel_out, h * 2**num_cascade, w * 2**num_cascade)
        return self.net(x)
class Mars(nn.Module):
    """Aspect-ratio-conditioned saliency/prediction model.

    For each requested ratio, a meta-learner turns an interpolated ratio
    embedding into a feature-transform matrix (ARS_FTM) and point-wise
    prediction weights (ARS_PWP). Pooled backbone features are transformed
    per ratio, replicated back to the feature-map resolution, upsampled by a
    deconv cascade, and reduced to a single map per ratio.

    Inputs to ``forward``: images ``x`` of shape (B, C, H, W) and ``ratio``
    of shape (B, R); output is (B, R, H', W') with values in (0, 1).
    """

    def __init__(self, net_type, name_layer, dropout_rate=0.2, dim_embedding=512, num_embedding=501, num_channel_out=96, num_meta_learner_hidden_layers=2):
        super().__init__()
        # Learnable interpolation nodes for the ratio embedding.
        self.ratio_embedding_nodes = nn.Parameter(
            torch.rand(num_embedding, dim_embedding))
        # Nodes are uniformly spaced in log-ratio over a span of 2*ln(2),
        # i.e. ratios in [1/2, 2] — TODO confirm the intended input range.
        self.embedding_interp_step = (2*math.log(2))/(num_embedding-1)
        # TODO: improve device handling (translated from original comment)
        self.feature_extractor, num_channel_feature, num_cascade = create_feature_extractor(
            net_type, name_layer)
        self.meta_learner = MetaLearner(
            dim_embedding, num_meta_learner_hidden_layers, num_channel_out, num_channel_feature, dropout_rate)
        self.deconv_layers = Deconv(
            num_cascade, num_channel_out, num_channel_feature)
        self.GAP = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x, ratio):
        # generate ARS_FTM and ARS_PWP given ratio
        ratio_embedding = self.get_ratio_embedding_batch(ratio)
        ARS_FTM, ARS_PWP = self.meta_learner(ratio_embedding)
        # get the features from image
        x = self.feature_extractor(x)
        h = x.shape[2]
        w = x.shape[3]
        x = self.GAP(x)
        # repeat to fill the ratio dimension
        r = ratio.shape[1]
        x = repeat(x, 'b c () () -> b r c () ()', r=r)
        # transform and add (residual connection around the ratio transform)
        x = x + self.ratio_transform(x, ARS_FTM)  # b*r*c*1*1
        # replicate to h*w*c
        x = repeat(x, 'b r c () () -> b r c h w', h=h, w=w)
        # deconv: fold the ratio dim into the batch dim for the conv stack
        x = rearrange(x, 'b r c h w -> (b r) c h w')
        x = self.deconv_layers(x)
        x = rearrange(x, '(b r) c h w -> b r c h w', r=r)
        # predict point-wise
        x = self.pixelwise_predict(x, ARS_PWP)
        # fix: F.sigmoid is deprecated; torch.sigmoid is the supported API
        return torch.sigmoid(x)

    def get_ratio_embedding(self, ratio):
        """Linearly interpolate the embedding for one scalar ratio.

        Works in log space: the two nodes bracketing log(ratio) are blended
        by their normalized distance. Returns shape (1, 1, dim_embedding).

        NOTE(review): for ratio == 0.5 the low index evaluates to -1, which
        wraps to the LAST node. This is self-consistent (the nodes are
        learnable, so the wrapped node simply serves as the lower boundary),
        but confirm training and inference share this indexing.
        """
        log_ratio = math.log(ratio)
        idx_low_node = math.floor(
            (log_ratio+math.log(2))/self.embedding_interp_step) - 1
        rate_high = (log_ratio - (idx_low_node+1) *
                     self.embedding_interp_step + math.log(2))/self.embedding_interp_step
        ratio_embedding = self.ratio_embedding_nodes[idx_low_node, :]*(
            1-rate_high)+self.ratio_embedding_nodes[idx_low_node+1, :]*rate_high
        ratio_embedding = rearrange(ratio_embedding, 'n -> () () n')
        return ratio_embedding

    def get_ratio_embedding_batch(self, batch_ratios):
        """Embed a (B, R) nested sequence of ratios into (B, R, dim)."""
        ratio_embedding = torch.cat([torch.cat([self.get_ratio_embedding(ratio)
                                                for ratio in ratios], dim=1) for ratios in batch_ratios], dim=0)
        return ratio_embedding

    @staticmethod
    def ratio_transform(x, ARS_FTM):
        """Right-multiply the pooled (1, C) feature row by the (C, C) ARS_FTM."""
        x = rearrange(x, 'b r c () () -> b r () c')
        x = torch.matmul(x, ARS_FTM)
        x = rearrange(x, 'b r () c -> b r c () ()')
        return x

    @staticmethod
    def pixelwise_predict(x, ARS_PWP):
        """Weight channels by ARS_PWP and sum over the channel axis."""
        x = x*ARS_PWP
        x = reduce(x, 'b r c h w -> b r h w', 'sum')
        return x