@@ -9,7 +9,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
 from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -433,6 +433,7 @@ def __init__(
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
+        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -454,15 +455,15 @@ def __init__(
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
 
         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
 
         optimizers = []
         schedulers = []
         warmup_schedulers = []
 
-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
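
The hunk above broadcasts the new cosine_decay_max_steps option through the same per-unet plumbing as lr, wd, eps and warmup_steps: a scalar (or None) is expanded to one value per unet, while a tuple is passed through as-is. A minimal sketch of that behaviour, with cast_tuple re-implemented locally for illustration (the repo ships its own helper):

from functools import partial

def cast_tuple(val, length = 1):
    # pass tuples through untouched, broadcast everything else to `length` copies
    return val if isinstance(val, tuple) else ((val,) * length)

num_unets = 2
lr, warmup_steps, cosine_decay_max_steps = map(
    partial(cast_tuple, length = num_unets),
    (1e-4, (500, 1000), None)
)

print(lr)                      # (0.0001, 0.0001) - scalar broadcast to both unets
print(warmup_steps)            # (500, 1000)      - per-unet tuple passed through
print(cosine_decay_max_steps)  # (None, None)     - None means "keep a constant LR"
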
@@ -478,7 +479,11 @@ def __init__(
                 )
 
                 optimizers.append(optimizer)
-                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+                if exists(unet_cosine_decay_max_steps):
+                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+                else:
+                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 
                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                 warmup_schedulers.append(warmup_scheduler)
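
This is the core of the change: if a unet was given a cosine_decay_max_steps, its optimizer gets a CosineAnnealingLR schedule (learning rate annealed toward zero over T_max steps); otherwise the previous behaviour, a constant-factor LambdaLR, is kept. A standalone sketch of the same selection, assuming a plain Adam optimizer and re-implementing the repo's exists None-check helper:

from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

def exists(val):
    return val is not None

unet = nn.Linear(8, 8)                         # stand-in for a real unet
optimizer = Adam(unet.parameters(), lr = 1e-4)

cosine_decay_max_steps = 10000                 # or None to keep a constant LR

if exists(cosine_decay_max_steps):
    # cosine curve from the base LR down toward 0 over T_max scheduler steps
    scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
    # constant multiplier of 1.0, i.e. the LR never changes
    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

for _ in range(100):
    optimizer.step()
    scheduler.step()
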
@@ -558,9 +563,15 @@ def save(self, path, overwrite = True, **kwargs):
 
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'
+
             optimizer = getattr(self, optimizer_key)
-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            scheduler = getattr(self, scheduler_key)
+
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +592,18 @@ def load_state_dict(self, loaded_obj, only_model = False, strict = True):
 
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+
+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)
+
             warmup_scheduler = self.warmup_schedulers[ind]
-            if optimizer is not None:
+
+            if exists(optimizer):
                 optimizer.load_state_dict(loaded_obj[optimizer_key])
 
+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])
+
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
 
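
The save / load_state_dict hunks persist each scheduler next to its optimizer under a parallel sched{ind} key, so resuming training picks up the cosine schedule at the saved step. A self-contained sketch of that round-trip, assuming a single Adam optimizer and CosineAnnealingLR in place of the trainer's per-unet attributes:

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr = 1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max = 1000)

# save: store both optimizer and scheduler state, mirroring the optim0 / sched0 key naming
save_obj = {
    'optim0': optimizer.state_dict(),
    'sched0': scheduler.state_dict(),
}
torch.save(save_obj, 'checkpoint.pt')

# load: restore both so the learning-rate decay resumes where it left off
loaded = torch.load('checkpoint.pt')
optimizer.load_state_dict(loaded['optim0'])
scheduler.load_state_dict(loaded['sched0'])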