-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathmodel.py
105 lines (84 loc) · 3.53 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from torch import nn
from modules import ConvSC, Inception
def stride_generator(N, reverse=False):
    """Return ``N`` strides following the alternating pattern 1, 2, 1, 2, ...

    Args:
        N: Number of strides to produce. Zero or a negative value yields an
            empty list.
        reverse: When true, return the same ``N`` strides in reversed order.

    Returns:
        A new list of ``N`` integers, each either 1 or 2.
    """
    # Build the pattern on demand so the result is never silently capped at
    # the length of a fixed-size lookup table.
    strides = [1 if i % 2 == 0 else 2 for i in range(N)]
    return strides[::-1] if reverse else strides
class Encoder(nn.Module):
    """Chain of ``ConvSC`` layers applied one after another.

    The first layer maps ``C_in`` channels to ``C_hid``; every later layer
    keeps ``C_hid`` channels. The per-layer ``stride`` arguments come from
    ``stride_generator(N_S)``, so ``N_S`` is the number of layers.
    """

    def __init__(self, C_in, C_hid, N_S):
        super().__init__()
        strides = stride_generator(N_S)
        # Layers are created in application order: input layer first.
        layers = [ConvSC(C_in, C_hid, stride=strides[0])]
        for stride in strides[1:]:
            layers.append(ConvSC(C_hid, C_hid, stride=stride))
        self.enc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through every layer.

        The original author's note gives ``(B*4, 3, 128, 128)`` as an example
        input shape.

        Returns:
            A tuple ``(latent, enc1)``: the output of the last layer and the
            output of the first layer (the two are the same tensor when there
            is only one layer).
        """
        enc1 = self.enc[0](x)
        latent = enc1
        # Feed the first layer's output through the remaining layers in order.
        for layer in list(self.enc)[1:]:
            latent = layer(latent)
        return latent, enc1
class Decoder(nn.Module):
    """Chain of transposed ``ConvSC`` layers followed by a 1x1 convolution.

    Strides come from ``stride_generator(N_S, reverse=True)``. All layers but
    the last map ``C_hid`` to ``C_hid`` channels. The last layer takes
    ``2 * C_hid`` input channels because ``forward`` concatenates a skip
    tensor onto the hidden state along the channel axis before calling it.
    ``readout`` is a kernel-size-1 ``nn.Conv2d`` from ``C_hid`` to ``C_out``.
    """

    def __init__(self, C_hid, C_out, N_S):
        super().__init__()
        strides = stride_generator(N_S, reverse=True)
        layers = []
        for stride in strides[:-1]:
            layers.append(ConvSC(C_hid, C_hid, stride=stride, transpose=True))
        # Final layer consumes the hidden state concatenated with the skip.
        layers.append(ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True))
        self.dec = nn.Sequential(*layers)
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        """Decode ``hid`` using the skip tensor ``enc1``.

        ``enc1`` defaults to ``None`` in the signature, but it is passed
        straight to ``torch.cat`` together with the hidden state, so callers
        need to supply a tensor.

        Returns:
            The output of ``readout`` applied to the last layer's result.
        """
        layers = list(self.dec)
        for layer in layers[:-1]:
            hid = layer(hid)
        merged = torch.cat([hid, enc1], dim=1)
        return self.readout(layers[-1](merged))
class Mid_Xnet(nn.Module):
    """Encoder/decoder stack of ``Inception`` blocks with skip connections.

    ``forward`` folds the ``T`` and ``C`` axes of a 5-D input together, runs
    ``N_T`` encoder blocks, then ``N_T`` decoder blocks. Every decoder block
    after the first receives the previous output concatenated on the channel
    axis with an encoder output, taken in reverse order.

    Args:
        channel_in: Channel count after folding, i.e. ``T * C`` of the input
            given to ``forward``; also the output channel count of the last
            decoder block.
        channel_hid: Channel count between blocks. ``channel_hid // 2`` is
            passed as the second positional argument of every ``Inception``.
        N_T: Number of encoder blocks and of decoder blocks used by
            ``forward``. NOTE(review): the construction and the final reshape
            only line up for ``N_T >= 2`` -- confirm no caller uses less.
        incep_ker: Value forwarded to each ``Inception`` as ``incep_ker``.
            ``None`` selects ``[3, 5, 7, 11]``.
        groups: Value forwarded to each ``Inception`` as ``groups``.
    """

    def __init__(self, channel_in, channel_hid, N_T, incep_ker=None, groups=8):
        super(Mid_Xnet, self).__init__()
        # Resolve the default here rather than in the signature so that a
        # single list object is not shared between every instantiation.
        if incep_ker is None:
            incep_ker = [3, 5, 7, 11]

        self.N_T = N_T

        def make_block(c_in, c_out):
            return Inception(c_in, channel_hid // 2, c_out,
                             incep_ker=incep_ker, groups=groups)

        # Encoder: channel_in -> channel_hid, then channel_hid -> channel_hid.
        enc_layers = [make_block(channel_in, channel_hid)]
        enc_layers.extend(
            make_block(channel_hid, channel_hid) for _ in range(1, N_T - 1)
        )
        enc_layers.append(make_block(channel_hid, channel_hid))

        # Decoder: blocks after the first take 2 * channel_hid inputs because
        # forward() concatenates a skip tensor; the last maps to channel_in.
        dec_layers = [make_block(channel_hid, channel_hid)]
        dec_layers.extend(
            make_block(2 * channel_hid, channel_hid) for _ in range(1, N_T - 1)
        )
        dec_layers.append(make_block(2 * channel_hid, channel_in))

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Transform a ``(B, T, C, H, W)`` tensor into one of the same shape."""
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)

        # encoder: keep every output except the last one as a skip tensor
        skips = []
        for i in range(self.N_T):
            z = self.enc[i](z)
            if i < self.N_T - 1:
                skips.append(z)

        # decoder: skips are consumed from the most recent to the earliest
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))

        return z.reshape(B, T, C, H, W)
class SimVP(nn.Module):
    """Composition of ``Encoder``, ``Mid_Xnet`` and ``Decoder``.

    Args:
        shape_in: Sequence ``(T, C, H, W)``; only ``T`` and ``C`` are used at
            construction time.
        hid_S: Hidden channel count given to ``Encoder`` and ``Decoder``.
        hid_T: Hidden channel count given to ``Mid_Xnet``.
        N_S: Layer-count argument given to ``Encoder`` and ``Decoder``.
        N_T: Block-count argument given to ``Mid_Xnet``.
        incep_ker: Forwarded to ``Mid_Xnet``. ``None`` selects
            ``[3, 5, 7, 11]``.
        groups: Forwarded to ``Mid_Xnet``.
    """

    def __init__(self, shape_in, hid_S=16, hid_T=256, N_S=4, N_T=8, incep_ker=None, groups=8):
        super(SimVP, self).__init__()
        # Resolve the default here rather than in the signature so that a
        # single list object is not shared between every instantiation.
        if incep_ker is None:
            incep_ker = [3, 5, 7, 11]
        T, C, H, W = shape_in
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        """Map a ``(B, T, C, H, W)`` tensor to a tensor of the same shape."""
        B, T, C, H, W = x_raw.shape
        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # batch, are accepted; it still returns a view whenever one exists.
        x = x_raw.reshape(B * T, C, H, W)

        # The encoder returns its final output plus a skip tensor that is
        # handed unchanged to the decoder.
        embed, skip = self.enc(x)
        _, C_, H_, W_ = embed.shape

        z = embed.reshape(B, T, C_, H_, W_)
        hid = self.hid(z)
        hid = hid.reshape(B * T, C_, H_, W_)

        Y = self.dec(hid, skip)
        return Y.reshape(B, T, C, H, W)