Skip to content

Commit 8f38339

Browse files
committed
give diffusion prior trainer cosine annealing lr too
1 parent 6b9b4b9 commit 8f38339

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

dalle2_pytorch/trainer.py

+9 -2
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,7 @@ def __init__(
182182
max_grad_norm = None,
183183
group_wd_params = True,
184184
warmup_steps = 1,
185+
cosine_decay_max_steps = None,
185186
**kwargs
186187
):
187188
super().__init__()
@@ -233,8 +234,11 @@ def __init__(
233234
**self.optim_kwargs,
234235
**kwargs
235236
)
236-
237-
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
237+
238+
if exists(cosine_decay_max_steps):
239+
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
240+
else:
241+
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
238242

239243
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
240244

@@ -271,6 +275,7 @@ def save(self, path, overwrite = True, **kwargs):
271275
# FIXME: LambdaLR can't be saved due to pickling issues
272276
save_obj = dict(
273277
optimizer = self.optimizer.state_dict(),
278+
scheduler = self.scheduler.state_dict(),
274279
warmup_scheduler = self.warmup_scheduler,
275280
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
276281
version = version.parse(__version__),
@@ -317,7 +322,9 @@ def load(self, path_or_state, overwrite_lr = True, strict = True):
317322
# unwrap the model when loading from checkpoint
318323
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
319324
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
325+
320326
self.optimizer.load_state_dict(loaded_obj['optimizer'])
327+
self.scheduler.load_state_dict(loaded_obj['scheduler'])
321328

322329
# set warmupstep
323330
if exists(self.warmup_scheduler):

dalle2_pytorch/version.py

+1 -1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.8.0'
1+
__version__ = '1.8.1'

0 commit comments

Comments (0)