Skip to content

Commit 95d06b4

Browse files
committed Jul 15, 2024
update vr model to use pretrained
1 parent 11ddc82 commit 95d06b4

File tree

4 files changed

+81
-51
lines changed

4 files changed

+81
-51
lines changed
 

‎hrdae/models/gan_model.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
@dataclass
2121
class GANModelOption(ModelOption):
2222
network: NetworkOption = MISSING
23+
network_weight: str = ""
2324
discriminator: NetworkOption = MISSING
2425
optimizer_g: OptimizerOption = MISSING
2526
optimizer_d: OptimizerOption = MISSING
@@ -35,6 +36,7 @@ class GANModel(Model):
3536
def __init__(
3637
self,
3738
generator: nn.Module,
39+
generator_weight: str,
3840
discriminator: nn.Module,
3941
optimizer_g: Optimizer,
4042
optimizer_d: Optimizer,
@@ -54,6 +56,9 @@ def __init__(
5456
self.criterion_g = criterion_g
5557
self.criterion_d = criterion_d
5658

59+
if generator_weight != "":
60+
self.generator.load_state_dict(torch.load(generator_weight))
61+
5762
if torch.cuda.is_available():
5863
print("GPU is enabled")
5964
self.device = torch.device("cuda:0")
@@ -108,7 +113,7 @@ def train(
108113
mixed_state1 = state1[shuffled_indices(batch_size)]
109114

110115
same = self.discriminator(torch.cat([state1, state2], dim=1))
111-
diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
116+
# diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
112117

113118
loss_g_basic = self.criterion(
114119
y,
@@ -120,7 +125,7 @@ def train(
120125
# diff == zerosなら、異なるビデオと見破られたことになるため、state encoderのロスは最大となる
121126
loss_g_adv = self.criterion_g(
122127
same, torch.zeros_like(same)
123-
) + self.criterion_g(diff, torch.ones_like(diff))
128+
) # + self.criterion_g(diff, torch.ones_like(diff))
124129

125130
loss_g = loss_g_basic + adv_ratio * loss_g_adv
126131
loss_g.backward()
@@ -139,9 +144,9 @@ def train(
139144
diff = self.discriminator(
140145
torch.cat([state1.detach(), mixed_state1.detach()], dim=1)
141146
)
142-
loss_d_adv = self.criterion_d(
143-
same, torch.ones_like(same)
144-
) + self.criterion_d(diff, torch.zeros_like(diff))
147+
loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
148+
loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
149+
loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2
145150
loss_d_adv.backward()
146151
self.optimizer_d.step()
147152

@@ -152,6 +157,8 @@ def train(
152157
f"Epoch: {epoch+1}, "
153158
f"Batch: {idx}, "
154159
f"Loss D Adv: {loss_d_adv.item():.6f}, "
160+
f"Loss D Adv (same): {loss_d_adv_same.item():.6f}, "
161+
f"Loss D Adv (diff): {loss_d_adv_diff.item():.6f}, "
155162
f"Loss G: {loss_g.item():.6f}, "
156163
f"Loss G Adv: {loss_g_adv.item():.6f}, "
157164
f"Loss G Basic: {loss_g_basic.item():.6f}, "
@@ -194,7 +201,7 @@ def train(
194201
mixed_state1 = state1[shuffled_indices(batch_size)]
195202

196203
same = self.discriminator(torch.cat([state1, state2], dim=1))
197-
diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
204+
# diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
198205

199206
y = y.detach().clone()
200207
loss_g_basic = self.criterion(
@@ -205,12 +212,12 @@ def train(
205212
)
206213
loss_g_adv = self.criterion_g(
207214
same, torch.zeros_like(same)
208-
) + self.criterion_g(diff, torch.ones_like(diff))
215+
) # + self.criterion_g(diff, torch.ones_like(diff))
209216

210217
loss_g = loss_g_basic + adv_ratio * loss_g_adv
211-
loss_d_adv = self.criterion_d(
212-
same, torch.ones_like(same)
213-
) + self.criterion_d(diff, torch.zeros_like(diff))
218+
loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
219+
loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
220+
loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2
214221

215222
total_val_loss_g += loss_g.item()
216223
total_val_loss_g_basic += loss_g_basic.item()
@@ -272,6 +279,9 @@ def train(
272279
}
273280
)
274281

282+
with open(result_dir / "training_history.json", "w") as f:
283+
json.dump(training_history, f, indent=2)
284+
275285
if epoch % 10 == 0:
276286
data = next(iter(val_loader))
277287

@@ -295,9 +305,6 @@ def train(
295305
f"epoch_{epoch}",
296306
)
297307

298-
with open(result_dir / "training_history.json", "w") as f:
299-
json.dump(training_history, f)
300-
301308
return least_val_loss_g
302309

303310

@@ -353,6 +360,7 @@ def create_gan_model(
353360
criterion_d = create_loss(opt.loss_d)
354361
return GANModel(
355362
generator,
363+
opt.network_weight,
356364
discriminator,
357365
optimizer_g,
358366
optimizer_d,

‎hrdae/models/networks/discriminator.py

+27-24
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from torch import Tensor, nn
44

5-
from .modules import ConvModule2d, ConvModule3d
5+
from .modules import ConvModule3d
66
from .option import NetworkOption
77

88

@@ -11,12 +11,8 @@ class Discriminator2dOption(NetworkOption):
1111
in_channels: int = 8
1212
hidden_channels: int = 256
1313
image_size: list[int] = field(default_factory=lambda: [4, 4])
14-
conv_params: list[dict[str, list[int]]] = field(
15-
default_factory=lambda: [
16-
{"kernel_size": [3], "stride": [2], "padding": [1], "output_padding": [1]},
17-
]
18-
)
19-
debug_show_dim: bool = False
14+
dropout_rate: float = 0.5
15+
fc_layer: int = 3
2016

2117

2218
def create_discriminator2d(opt: Discriminator2dOption) -> nn.Module:
@@ -25,8 +21,8 @@ def create_discriminator2d(opt: Discriminator2dOption) -> nn.Module:
2521
out_channels=1,
2622
hidden_channels=opt.hidden_channels,
2723
image_size=opt.image_size,
28-
conv_params=opt.conv_params,
29-
debug_show_dim=opt.debug_show_dim,
24+
dropout_rate=opt.dropout_rate,
25+
fc_layer=opt.fc_layer,
3026
)
3127

3228

@@ -37,29 +33,27 @@ def __init__(
3733
out_channels: int,
3834
hidden_channels: int,
3935
image_size: list[int],
40-
conv_params: list[dict[str, list[int]]],
41-
debug_show_dim: bool,
36+
fc_layer: int,
37+
dropout_rate: float,
4238
) -> None:
4339
super().__init__()
44-
self.cnn = ConvModule2d(
45-
in_channels,
46-
hidden_channels,
47-
hidden_channels,
48-
conv_params,
49-
transpose=False,
50-
act_norm=False,
51-
debug_show_dim=debug_show_dim,
52-
)
5340
size = image_size[0] * image_size[1]
54-
self.bottleneck = nn.Sequential(
55-
nn.Linear(size * hidden_channels, hidden_channels),
41+
self.fc = nn.Sequential(
42+
nn.Linear(in_channels * size, hidden_channels),
43+
nn.BatchNorm1d(hidden_channels),
5644
nn.ReLU(),
45+
nn.Dropout1d(dropout_rate),
46+
*[
47+
nn.Linear(hidden_channels, hidden_channels),
48+
nn.BatchNorm1d(hidden_channels),
49+
nn.ReLU(),
50+
nn.Dropout1d(dropout_rate),
51+
] * fc_layer,
5752
nn.Linear(hidden_channels, out_channels),
5853
)
5954

6055
def forward(self, x: Tensor) -> Tensor:
61-
h = self.cnn(x)
62-
z = self.bottleneck(h.reshape(h.size(0), -1))
56+
z = self.fc(x.reshape(x.size(0), -1))
6357
return z
6458

6559

@@ -73,6 +67,7 @@ class Discriminator3dOption(NetworkOption):
7367
{"kernel_size": [3], "stride": [2], "padding": [1], "output_padding": [1]},
7468
]
7569
)
70+
dropout_rate: float = 0.5
7671
debug_show_dim: bool = False
7772

7873

@@ -83,6 +78,7 @@ def create_discriminator3d(opt: Discriminator3dOption) -> nn.Module:
8378
hidden_channels=opt.hidden_channels,
8479
image_size=opt.image_size,
8580
conv_params=opt.conv_params,
81+
dropout_rate=opt.dropout_rate,
8682
debug_show_dim=opt.debug_show_dim,
8783
)
8884

@@ -95,6 +91,7 @@ def __init__(
9591
hidden_channels: int,
9692
image_size: list[int],
9793
conv_params: list[dict[str, list[int]]],
94+
dropout_rate: float,
9895
debug_show_dim: bool,
9996
) -> None:
10097
super().__init__()
@@ -110,7 +107,13 @@ def __init__(
110107
size = image_size[0] * image_size[1] * image_size[2]
111108
self.bottleneck = nn.Sequential(
112109
nn.Linear(size * hidden_channels, hidden_channels),
110+
nn.BatchNorm1d(hidden_channels),
111+
nn.ReLU(),
112+
nn.Dropout1d(dropout_rate),
113+
nn.Linear(hidden_channels, hidden_channels),
114+
nn.BatchNorm1d(hidden_channels),
113115
nn.ReLU(),
116+
nn.Dropout1d(dropout_rate),
114117
nn.Linear(hidden_channels, out_channels),
115118
)
116119

‎hrdae/models/vr_model.py

+6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
@dataclass
2121
class VRModelOption(ModelOption):
2222
network: NetworkOption = MISSING
23+
network_weight: str = ""
2324
optimizer: OptimizerOption = MISSING
2425
scheduler: SchedulerOption = MISSING
2526
loss: dict[str, LossOption] = MISSING
@@ -31,6 +32,7 @@ class VRModel(Model):
3132
def __init__(
3233
self,
3334
network: nn.Module,
35+
network_weight: str,
3436
optimizer: Optimizer,
3537
scheduler: LRScheduler,
3638
criterion: nn.Module,
@@ -42,6 +44,9 @@ def __init__(
4244
self.criterion = criterion
4345
self.use_triplet = use_triplet
4446

47+
if network_weight != "":
48+
self.network.load_state_dict(torch.load(network_weight))
49+
4550
if torch.cuda.is_available():
4651
print("GPU is enabled")
4752
self.device = torch.device("cuda:0")
@@ -213,6 +218,7 @@ def create_vr_model(
213218
)
214219
return VRModel(
215220
network,
221+
opt.network_weight,
216222
optimizer,
217223
scheduler,
218224
criterion,

‎notebook/mmnist.ipynb

+27-14
Large diffs are not rendered by default.

0 commit comments

Comments (0)
Failed to load comments.