-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathoptimizers.py
103 lines (96 loc) · 2.45 KB
/
optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from pytorch_optimizer import create_optimizer
from pytorch_optimizer.optimizer import TRAC, Lookahead, OrthoGrad
def get_optimizer_profile(name="AdamW"):
    """Return the ``create_optimizer`` keyword arguments for a named profile.

    Parameters
    ----------
    name : str
        Case-insensitive profile name; must match a key of ``OPTIMIZER_PROFILES``.

    Returns
    -------
    dict
        A new dict with the profile's settings plus the shared ``wd_ban_list``.

    Raises
    ------
    KeyError
        If ``name`` does not correspond to any known profile.
    """
    lowercase_profiles = {k.lower(): v for k, v in OPTIMIZER_PROFILES.items()}
    profile = lowercase_profiles.get(name.lower())
    if profile is None:
        # Previously an unknown name fell through to `{**None}` and raised an
        # opaque TypeError; fail fast with an actionable message instead.
        known = ", ".join(sorted(OPTIMIZER_PROFILES))
        raise KeyError(f"Unknown optimizer profile {name!r}; choose one of: {known}")
    return {**profile, "wd_ban_list": WD_BAN_LIST}
def get_optimizer(model, trac=False, ortho=False, lookahead=False, *args, **kwargs):
    """Create an optimizer for *model* and optionally stack wrapper optimizers.

    Extra positional/keyword arguments are forwarded to
    ``pytorch_optimizer.create_optimizer``. Wrappers are applied in a fixed
    order: TRAC, then OrthoGrad, then Lookahead.
    """
    optimizer = create_optimizer(model, *args, **kwargs)
    # (enabled?, wrapper) pairs — applied in order so nesting matches the
    # original implementation exactly.
    wrapper_table = (
        (trac, lambda opt: TRAC(opt, num_coefs=128)),
        (ortho, OrthoGrad),
        (lookahead, lambda opt: Lookahead(opt, k=5, alpha=0.5, pullback_momentum="none")),
    )
    for enabled, wrap in wrapper_table:
        if enabled:
            optimizer = wrap(optimizer)
    return optimizer
# Most optimizer settings can be found here:
# https://pytorch-optimizers.readthedocs.io/en/latest/optimizer
# Each profile maps directly to keyword arguments for
# pytorch_optimizer.create_optimizer (see get_optimizer_profile).
OPTIMIZER_PROFILES = {
    "AdamG": {
        "optimizer_name": "AdamG",
        "lr": 1.0,
        "weight_decay": 0.1,
        "p": 0.5,
        "q": 0.24,
        "betas": (0.95, 0.999, 0.95),
    },
    "AdamW": {
        "optimizer_name": "AdamW",
        "lr": 1e-3,
        "weight_decay": 0.1,
        "betas": (0.9, 0.95),
    },
    "AdEMAMix": {
        "optimizer_name": "AdEMAMix",
        "lr": 0.001,
        "weight_decay": 0.1,
        "betas": (0.9, 0.95, 0.9999),
        "alpha": 5.0,
        "cautious": True,
    },
    "Grams": {
        "optimizer_name": "Grams",
        "lr": 0.001,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
    },
    "Lion": {
        "optimizer_name": "Lion",
        "lr": 0.000333,
        "weight_decay": 0.1,
        "betas": (0.9, 0.95),
        "r": 0.95,
        "use_gc": True,
        "adanorm": True,
        "cautious": True,
    },
    "Prodigy": {
        "optimizer_name": "Prodigy",
        "lr": 1.0,
        "weight_decay": 0.1,
        "betas": (0.9, 0.95),
        "bias_correction": True,
        "safeguard_warmup": False,
    },
    "SOAP": {
        "optimizer_name": "SOAP",
        "lr": 0.003,
        "weight_decay": 0.1,
        "betas": (0.95, 0.95),
        "precondition_frequency": 10,
        "max_precondition_dim": 10000,
        "normalize_gradient": False,
        "correct_bias": True,
        "precondition_1d": False,
        "merge_dims": False,
    },
}
# Names passed to create_optimizer as ``wd_ban_list`` (see
# get_optimizer_profile). Presumably parameter/module names whose weights are
# exempt from weight decay — biases, embeddings, and normalization layers —
# TODO confirm exact matching semantics against pytorch_optimizer docs.
WD_BAN_LIST = [
    "bias",
    "edge_embeddings",
    "spatial_embeddings",
    "Embedding",
    "BatchNorm",
    "BatchNorm1d",
    "BatchNorm2d",
    "BatchNorm3d",
    "GroupNorm",
    "LayerNorm",
    "RMSNorm",
    "InstanceNorm",
    "InstanceNorm1d",
    "InstanceNorm3d",
    "InstanceNorm2d",
    "PReLU",
    "SinLU",
    "NMDA",
]