-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathOneCycleLR.py
90 lines (63 loc) · 3.1 KB
/
OneCycleLR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Implement One Cycle Policy Algorithm in the Keras Callback Class
# from https://www.kaggle.com/robotdreams/one-cycle-policy-with-keras
# converted to run with tf.keras on tf2
import tensorflow as tf
from tensorflow import keras
from sklearn.metrics import log_loss, roc_auc_score, accuracy_score
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import *
import numpy as np
class CyclicLR(keras.callbacks.Callback):
    """One Cycle Policy scheduler for learning rate and (optionally) momentum.

    Implements the one-cycle schedule as a Keras callback: the learning rate
    ramps linearly from ``base_lr`` to ``max_lr`` and back over one triangular
    cycle (2 * ``step_size`` iterations), then decays toward ``base_lr / 100``
    in a final "annihilation" phase. When ``cyclical_momentum`` is truthy,
    momentum moves in the opposite direction (``max_m`` down to ``base_m`` and
    back), and is held at ``max_m`` during the annihilation phase.

    Per-iteration values are recorded in ``self.history`` under the keys
    ``'lr'``, ``'iterations'``, ``'momentum'`` (if cyclical), plus any keys
    present in the batch ``logs``.

    Parameters
    ----------
    base_lr : float
        Lower learning-rate bound of the cycle.
    max_lr : float
        Upper learning-rate bound of the cycle.
    step_size : int
        Number of training iterations in half a cycle.
    base_m : float
        Lower momentum bound of the cycle.
    max_m : float
        Upper momentum bound of the cycle.
    cyclical_momentum : bool
        If truthy, also schedule the optimizer's momentum. The optimizer must
        expose a ``momentum`` attribute (e.g. SGD) — NOTE(review): this is
        assumed from the original code's use of ``optimizer.momentum``; it
        will fail for optimizers without that attribute.
    """

    def __init__(self, base_lr, max_lr, step_size, base_m, max_m, cyclical_momentum):
        # Fix: original never called super().__init__(), skipping Keras
        # Callback base-class initialization.
        super().__init__()
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.base_m = base_m
        self.max_m = max_m
        self.cyclical_momentum = cyclical_momentum
        self.step_size = step_size
        # Iteration counters: clr_iterations drives the schedule position,
        # trn_iterations is the total logged count (cm_iterations is kept for
        # backward compatibility; the schedule reads clr_iterations only).
        self.clr_iterations = 0.
        self.cm_iterations = 0.
        self.trn_iterations = 0.
        self.history = {}

    def clr(self):
        """Return the learning rate for the current ``clr_iterations``."""
        cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size))
        # Position within the current half-cycle, mapped to [0, 1] via (1 - x).
        x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
        if cycle == 2:
            # Annihilation phase: decay from base_lr toward base_lr / 100.
            return self.base_lr - (self.base_lr - self.base_lr / 100) * np.maximum(0, (1 - x))
        # Triangular phase: linear ramp between base_lr and max_lr.
        return self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x))

    def cm(self):
        """Return the momentum for the current ``clr_iterations``."""
        cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size))
        if cycle == 2:
            # Annihilation phase: momentum is held at its maximum.
            # (Fix: original computed an unused local here.)
            return self.max_m
        # Triangular phase: momentum mirrors the LR ramp (high when LR is low).
        x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
        return self.max_m - (self.max_m - self.base_m) * np.maximum(0, (1 - x))

    def on_train_begin(self, logs=None):
        """Initialize optimizer LR/momentum at the start of training.

        Fix: ``logs`` default changed from the mutable ``{}`` to ``None``.
        """
        logs = logs or {}
        if self.clr_iterations == 0:
            K.set_value(self.model.optimizer.lr, self.base_lr)
        else:
            # Resuming mid-schedule: pick up where the cycle left off.
            K.set_value(self.model.optimizer.lr, self.clr())
        if self.cyclical_momentum:
            # Fix: original branched on clr_iterations here but both branches
            # called cm() identically; collapsed to a single call.
            K.set_value(self.model.optimizer.momentum, self.cm())

    def on_batch_begin(self, batch, logs=None):
        """Advance the schedule one iteration, log history, update optimizer."""
        logs = logs or {}
        self.trn_iterations += 1
        self.clr_iterations += 1
        # Record the values that were in effect for the previous batch.
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        self.history.setdefault('iterations', []).append(self.trn_iterations)
        if self.cyclical_momentum:
            self.history.setdefault('momentum', []).append(K.get_value(self.model.optimizer.momentum))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)
        # Apply the new values for the upcoming batch.
        K.set_value(self.model.optimizer.lr, self.clr())
        if self.cyclical_momentum:
            K.set_value(self.model.optimizer.momentum, self.cm())