@@ -20,6 +20,7 @@
 @dataclass
 class GANModelOption(ModelOption):
     network: NetworkOption = MISSING
+    network_weight: str = ""
     discriminator: NetworkOption = MISSING
     optimizer_g: OptimizerOption = MISSING
     optimizer_d: OptimizerOption = MISSING
@@ -35,6 +36,7 @@ class GANModel(Model):
     def __init__(
         self,
         generator: nn.Module,
+        generator_weight: str,
         discriminator: nn.Module,
         optimizer_g: Optimizer,
         optimizer_d: Optimizer,
@@ -54,6 +56,9 @@ def __init__(
         self.criterion_g = criterion_g
         self.criterion_d = criterion_d

+        if generator_weight != "":
+            self.generator.load_state_dict(torch.load(generator_weight))
+
         if torch.cuda.is_available():
             print("GPU is enabled")
             self.device = torch.device("cuda:0")
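The new loading step above passes the checkpoint path straight to torch.load, which restores tensors to whatever device they were saved from. A minimal sketch of the same idea with an explicit map_location, assuming the model is moved to its target device afterwards (load_generator_weights is an illustrative helper, not part of this repo):

import torch
import torch.nn as nn

def load_generator_weights(generator: nn.Module, path: str) -> None:
    # Illustrative helper: skip loading when the option is left at its "" default,
    # and load the state dict onto CPU so the caller controls device placement.
    if path:
        state_dict = torch.load(path, map_location="cpu")
        generator.load_state_dict(state_dict)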
@@ -108,7 +113,7 @@ def train(
             mixed_state1 = state1[shuffled_indices(batch_size)]

             same = self.discriminator(torch.cat([state1, state2], dim=1))
-            diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
+            # diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))

             loss_g_basic = self.criterion(
                 y,
@@ -120,7 +125,7 @@ def train(
             # If diff == zeros, the discriminator has seen through the pair as coming from different videos, so the state encoder's loss is at its maximum
             loss_g_adv = self.criterion_g(
                 same, torch.zeros_like(same)
-            ) + self.criterion_g(diff, torch.ones_like(diff))
+            )  # + self.criterion_g(diff, torch.ones_like(diff))

             loss_g = loss_g_basic + adv_ratio * loss_g_adv
             loss_g.backward()
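With the mixed-pair term commented out, the generator's adversarial objective reduces to a single term. A small sketch of the resulting computation in isolation, assuming criterion_g is a BCE-style loss (the real criterion comes from create_loss and is not shown in this diff):

import torch
import torch.nn as nn

criterion_g = nn.BCEWithLogitsLoss()   # stand-in for the configured criterion_g

same = torch.randn(8, 1)               # discriminator outputs for genuine (state1, state2) pairs
loss_g_adv = criterion_g(same, torch.zeros_like(same))
# Only the "same" pairs are pushed toward the "different video" label (0);
# the shuffled-pair term that used ones_like(diff) is no longer part of loss_g_adv.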
@@ -139,9 +144,9 @@ def train(
             diff = self.discriminator(
                 torch.cat([state1.detach(), mixed_state1.detach()], dim=1)
             )
-            loss_d_adv = self.criterion_d(
-                same, torch.ones_like(same)
-            ) + self.criterion_d(diff, torch.zeros_like(diff))
+            loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
+            loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
+            loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2
             loss_d_adv.backward()
             self.optimizer_d.step()

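The discriminator loss is now split into two named terms and averaged, so the logged loss_d_adv stays on the same scale as each component. A minimal sketch with a stand-in BCE-style criterion (the random tensors only illustrate the shapes):

import torch
import torch.nn as nn

criterion_d = nn.BCEWithLogitsLoss()   # stand-in for the configured criterion_d

same = torch.randn(8, 1)               # outputs for genuine (state1, state2) pairs
diff = torch.randn(8, 1)               # outputs for shuffled (state1, mixed_state1) pairs

loss_d_adv_same = criterion_d(same, torch.ones_like(same))
loss_d_adv_diff = criterion_d(diff, torch.zeros_like(diff))
loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2   # mean keeps the combined loss comparable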
@@ -152,6 +157,8 @@ def train(
                 f"Epoch: {epoch + 1}, "
                 f"Batch: {idx}, "
                 f"Loss D Adv: {loss_d_adv.item():.6f}, "
+                f"Loss D Adv (same): {loss_d_adv_same.item():.6f}, "
+                f"Loss D Adv (diff): {loss_d_adv_diff.item():.6f}, "
                 f"Loss G: {loss_g.item():.6f}, "
                 f"Loss G Adv: {loss_g_adv.item():.6f}, "
                 f"Loss G Basic: {loss_g_basic.item():.6f}, "
@@ -194,7 +201,7 @@ def train(
             mixed_state1 = state1[shuffled_indices(batch_size)]

             same = self.discriminator(torch.cat([state1, state2], dim=1))
-            diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))
+            # diff = self.discriminator(torch.cat([state1, mixed_state1], dim=1))

             y = y.detach().clone()
             loss_g_basic = self.criterion(
@@ -205,12 +212,12 @@ def train(
             )
             loss_g_adv = self.criterion_g(
                 same, torch.zeros_like(same)
-            ) + self.criterion_g(diff, torch.ones_like(diff))
+            )  # + self.criterion_g(diff, torch.ones_like(diff))

             loss_g = loss_g_basic + adv_ratio * loss_g_adv
-            loss_d_adv = self.criterion_d(
-                same, torch.ones_like(same)
-            ) + self.criterion_d(diff, torch.zeros_like(diff))
+            loss_d_adv_same = self.criterion_d(same, torch.ones_like(same))
+            loss_d_adv_diff = self.criterion_d(diff, torch.zeros_like(diff))
+            loss_d_adv = (loss_d_adv_same + loss_d_adv_diff) / 2

             total_val_loss_g += loss_g.item()
             total_val_loss_g_basic += loss_g_basic.item()
@@ -272,6 +279,9 @@ def train(
                 }
             )

+            with open(result_dir / "training_history.json", "w") as f:
+                json.dump(training_history, f, indent=2)
+
             if epoch % 10 == 0:
                 data = next(iter(val_loader))

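Writing training_history inside the epoch loop means the JSON file is refreshed after every epoch, so an interrupted run still leaves the history on disk, and indent=2 keeps it readable. A self-contained sketch of the pattern (the directory, field names, and values are illustrative):

import json
from pathlib import Path

result_dir = Path("results/example_run")          # illustrative location
result_dir.mkdir(parents=True, exist_ok=True)

training_history = {"loss_g": [], "val_loss_g": []}
for epoch in range(3):
    training_history["loss_g"].append(1.0 / (epoch + 1))       # placeholder values
    training_history["val_loss_g"].append(1.5 / (epoch + 1))
    # Overwrite the file each epoch instead of only once after training finishes.
    with open(result_dir / "training_history.json", "w") as f:
        json.dump(training_history, f, indent=2)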
@@ -295,9 +305,6 @@ def train(
                     f"epoch_{epoch}",
                 )

-        with open(result_dir / "training_history.json", "w") as f:
-            json.dump(training_history, f)
-
         return least_val_loss_g

@@ -353,6 +360,7 @@ def create_gan_model(
     criterion_d = create_loss(opt.loss_d)
     return GANModel(
         generator,
+        opt.network_weight,
         discriminator,
         optimizer_g,
         optimizer_d,
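For context, one way the new option could be supplied; the OmegaConf usage is an assumption based on the MISSING defaults in GANModelOption, and the path and the create_gan_model call shape are illustrative rather than taken from this repo:

from omegaconf import OmegaConf

# Hypothetical override: an empty network_weight (the default) means "train from scratch",
# while a non-empty path makes GANModel load those weights into the generator.
opt = OmegaConf.merge(
    OmegaConf.structured(GANModelOption),
    {"network_weight": "results/pretrain/generator.pth"},
)
model = create_gan_model(opt)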