@@ -182,6 +182,7 @@ def __init__(
         max_grad_norm = None,
         group_wd_params = True,
         warmup_steps = 1,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
@@ -233,8 +234,11 @@ def __init__(
             **self.optim_kwargs,
             **kwargs
         )
-
-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
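For context, here is a minimal sketch (separate from the trainer) of how a `CosineAnnealingLR` schedule behaves when stepped through pytorch-warmup's dampening context, which is the pattern this trainer uses for its scheduler updates. The toy model, optimizer, learning rate and step counts below are illustrative assumptions, not the trainer's own values.

```python
# Minimal sketch, not the trainer itself: cosine decay stepped under
# pytorch-warmup's dampening context. The toy model/optimizer/step counts
# are illustrative; only the scheduler/warmup interplay mirrors the change above.
import torch
import pytorch_warmup as warmup
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)

cosine_decay_max_steps = 1000
scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 10)

for _ in range(cosine_decay_max_steps):
    optimizer.step()
    # the lr is dampened by the warmup factor inside the context; the cosine
    # schedule advances one step when the context exits
    with warmup_scheduler.dampening():
        scheduler.step()
```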
@@ -271,6 +275,7 @@ def save(self, path, overwrite = True, **kwargs):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
@@ -317,7 +322,9 @@ def load(self, path_or_state, overwrite_lr = True, strict = True):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device = self.device) * loaded_obj['step'].to(self.device))
+
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])

         # set warmupstep
         if exists(self.warmup_scheduler):
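A small, self-contained sketch of why the scheduler state is now checkpointed alongside the optimizer: without the new `scheduler` entry, a resumed `CosineAnnealingLR` would restart its decay from step 0 and the learning rate would jump back up. The toy objects below are hypothetical; only the `state_dict()` / `load_state_dict()` round trip mirrors the `save()` / `load()` changes.

```python
# Illustrative only: checkpointing scheduler state so cosine decay resumes
# from the same point instead of restarting at the initial lr.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max = 100)

for _ in range(50):
    optimizer.step()
    scheduler.step()

# mirrors the new 'scheduler' entry in save_obj
checkpoint = dict(
    optimizer = optimizer.state_dict(),
    scheduler = scheduler.state_dict(),
)

# mirrors load(): restore both, so last_epoch (and therefore the lr) continues
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
assert scheduler.last_epoch == 50  # decay picks up where it left off
```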