generator_model.py
import torch
from torch import nn

# Building blocks provided by the companion custom_layers module.
from custom_layers import (
    ConvolutionLayer,
    EqualizedConv2d,
    EqualizedConvTranspose2d,
    PixelNormalization,
    upscale,
)


class Generator(nn.Module):
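    """Progressively grown generator.

    The spatial resolution doubles with each `step`; during the fade-in
    phase (0 <= alpha < 1) the RGB output of the previous resolution is
    blended with the output of the newly added resolution.
    """
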
def __init__(self,
input_code_dim=128,
in_channel=128,
pixel_norm=True,
tanh=True):
super(Generator, self).__init__()
self.input_dim = input_code_dim
self.tanh = tanh
self.input_layer = nn.Sequential(
EqualizedConvTranspose2d(input_code_dim, in_channel, 4, 1, 0),
PixelNormalization(),
nn.LeakyReLU(0.1)
)
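        # Per-resolution convolution blocks; the channel count is halved
        # for the later, higher-resolution stages.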
self.progression = nn.ModuleList(
[
ConvolutionLayer(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm),
ConvolutionLayer(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm),
ConvolutionLayer(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm),
ConvolutionLayer(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm),
ConvolutionLayer(in_channel, in_channel // 2, 3, 1, pixel_norm=pixel_norm),
ConvolutionLayer(in_channel // 2, in_channel // 4, 3, 1, pixel_norm=pixel_norm),
ConvolutionLayer(in_channel // 4, in_channel // 4, 3, 1, pixel_norm=pixel_norm),
]
)
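        # 1x1 convolutions mapping each stage's feature maps to 3 RGB channels.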
self.to_rgb = nn.ModuleList(
[
EqualizedConv2d(in_channel, 3, 1),
EqualizedConv2d(in_channel, 3, 1),
EqualizedConv2d(in_channel, 3, 1),
EqualizedConv2d(in_channel // 2, 3, 1),
EqualizedConv2d(in_channel // 4, 3, 1),
EqualizedConv2d(in_channel // 4, 3, 1)
]
)
        self.max_step = 6

    # noinspection PyUnusedLocal
    # noinspection PyMethodMayBeStatic
    def progress(self, feat, module):
        # Upsample the feature map, then refine it with the next progression block.
        out = upscale(feat)
        out = module(out)
        return out

    def output(self, feat1, feat2, module1, module2, alpha):
        if 0 <= alpha < 1:
            # Fade-in: blend the upscaled RGB image of the previous stage
            # with the RGB image of the current stage.
            skip_rgb = upscale(module1(feat1))
            out = (1 - alpha) * skip_rgb + alpha * module2(feat2)
        else:
            out = module2(feat2)
if self.tanh:
return torch.tanh(out)
return out

    def forward(self, input, step=0, alpha=-1):
        step = min(self.max_step, step)
        outputs = [None for _ in range(self.max_step + 1)]

        # Base resolution: project the latent code to a 4x4 feature map.
        outputs[0] = self.input_layer(input.view(-1, self.input_dim, 1, 1))
        outputs[0] = self.progression[0](outputs[0])
        outputs[1] = self.progress(outputs[0], self.progression[1])
        if step == 1:
            if self.tanh:
                return torch.tanh(self.to_rgb[0](outputs[1]))
            return self.to_rgb[0](outputs[1])

        # Grow the image resolution by resolution up to the requested step.
        current = 2
        while current <= step:
            outputs[current] = self.progress(outputs[current - 1], self.progression[current])
            current += 1
return self.output(
outputs[step - 1],
outputs[step],
self.to_rgb[step - 2],
self.to_rgb[step - 1], alpha
)
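

if __name__ == "__main__":
    # Illustrative smoke test (a sketch): it assumes custom_layers is
    # importable and that ConvolutionLayer's fourth argument is a
    # size-preserving padding of 1, in which case step s should yield a
    # (N, 3, 4 * 2**s, 4 * 2**s) image in [-1, 1] (tanh=True by default).
    generator = Generator(input_code_dim=128, in_channel=128)
    latent = torch.randn(4, 128)

    # alpha in [0, 1) fades in the newest resolution; any other alpha uses
    # only the current resolution's RGB head.
    for step, alpha in [(1, -1), (2, 0.5), (6, 1.0)]:
        image = generator(latent, step=step, alpha=alpha)
        print(f"step={step} alpha={alpha} -> {tuple(image.shape)}")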