Commit 6b9b4b9

add cosine annealing lr schedule
1 parent 44e09d5 commit 6b9b4b9

File tree: 2 files changed (+27, -8 lines)

dalle2_pytorch/trainer.py (+26, -7)
@@ -9,7 +9,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
 from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -433,6 +433,7 @@ def __init__(
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
+        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -454,15 +455,15 @@ def __init__(
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
 
         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
 
         optimizers = []
         schedulers = []
         warmup_schedulers = []
 
-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
@@ -478,7 +479,11 @@ def __init__(
                 )
 
                 optimizers.append(optimizer)
-                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+                if exists(unet_cosine_decay_max_steps):
+                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+                else:
+                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 
                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                 warmup_schedulers.append(warmup_scheduler)
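
For reference, torch's CosineAnnealingLR decays each parameter group's learning rate from its initial value toward eta_min (default 0) along a half cosine over T_max scheduler steps, which is what the new branch above opts into when cosine_decay_max_steps is given. A minimal standalone sketch of that behaviour, using a toy optimizer and made-up step counts purely for illustration:

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    # toy parameter and optimizer, purely for illustration
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr = 1e-4)

    # mirrors the new branch: cosine-anneal the lr toward 0 over T_max steps
    scheduler = CosineAnnealingLR(optimizer, T_max = 1000)

    for step in range(1000):
        optimizer.step()
        scheduler.step()

    print(scheduler.get_last_lr())  # ~[0.0] once T_max steps have elapsed

Without cosine_decay_max_steps, the fallback LambdaLR with a constant lr_lambda keeps the learning rate fixed, matching the previous behaviour.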
@@ -558,9 +563,15 @@ def save(self, path, overwrite = True, **kwargs):
 
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'
+
             optimizer = getattr(self, optimizer_key)
-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            scheduler = getattr(self, scheduler_key)
+
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +592,18 @@ def load_state_dict(self, loaded_obj, only_model = False, strict = True):
 
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+
+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)
+
             warmup_scheduler = self.warmup_schedulers[ind]
-            if optimizer is not None:
+
+            if exists(optimizer):
                 optimizer.load_state_dict(loaded_obj[optimizer_key])
 
+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])
+
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
 
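
Taken together, callers opt into cosine decay per unet by passing the new cosine_decay_max_steps keyword, which is cast to a tuple alongside lr, wd, eps and warmup_steps, and the per-unet scheduler state now round-trips through save and load_state_dict. A hedged usage sketch, assuming the enclosing class is DecoderTrainer from dalle2_pytorch.trainer and that `decoder` is a Decoder built elsewhere (e.g. as in the project README):

    from dalle2_pytorch.trainer import DecoderTrainer

    # `decoder` is assumed to be an already-constructed Decoder with two unets
    trainer = DecoderTrainer(
        decoder,
        lr = 1e-4,
        wd = 1e-2,
        warmup_steps = 1000,               # linear warmup per unet, as before
        cosine_decay_max_steps = 100_000   # new: cosine-anneal each unet's lr over 100k steps
    )

    # per-unet values also work, since every hyperparameter is cast to a tuple internally
    trainer = DecoderTrainer(
        decoder,
        lr = (1e-4, 5e-5),
        cosine_decay_max_steps = (100_000, 50_000)
    )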

dalle2_pytorch/version.py (+1, -1)

@@ -1 +1 @@
-__version__ = '1.7.0'
+__version__ = '1.8.0'
