-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcurriculum.py
34 lines (26 loc) · 1.11 KB
/
curriculum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import math
class Curriculum:
    """Step-based schedule for the problem dimension and the number of points.

    ``args.dims`` and ``args.points`` are schedule objects exposing ``start``,
    ``end``, ``inc`` and ``interval`` attributes. Each scheduled value begins
    at ``start`` and changes by ``inc`` on every call to :meth:`update` whose
    step count is a multiple of ``interval``. The value never moves past
    ``end``: with a non-negative ``inc``, ``end`` is an upper limit; with a
    negative ``inc``, ``end`` is a lower limit.
    """

    def __init__(self, args):
        # Clamp the starting values so they respect ``end`` from step 0
        # rather than jumping to it on the first update.
        self.n_dims_truncated = self._clamp(args.dims.start, args.dims)
        self.n_points = self._clamp(args.points.start, args.points)
        self.n_dims_schedule = args.dims
        self.n_points_schedule = args.points
        self.step_count = 0

    def update(self):
        """Advance the curriculum by one step and update both scheduled values."""
        self.step_count += 1
        self.n_dims_truncated = self.update_var(
            self.n_dims_truncated, self.n_dims_schedule
        )
        self.n_points = self.update_var(self.n_points, self.n_points_schedule)

    def update_var(self, var, schedule):
        """Return ``var`` after applying ``schedule`` at the current step.

        ``var`` changes by ``schedule.inc`` only when ``self.step_count`` is a
        multiple of ``schedule.interval``. The result is always clamped to
        ``schedule.end`` in the direction of ``schedule.inc``.

        Raises:
            ZeroDivisionError: if ``schedule.interval`` is 0.
        """
        if self.step_count % schedule.interval == 0:
            var += schedule.inc
        return self._clamp(var, schedule)

    @staticmethod
    def _clamp(var, schedule):
        """Limit ``var`` to ``schedule.end`` in the direction of ``schedule.inc``."""
        # Growing schedules stop at ``end`` from below; shrinking ones
        # (negative ``inc``) stop at ``end`` from above. Using ``min`` for
        # both would snap a shrinking value straight to ``end``.
        if schedule.inc < 0:
            return max(var, schedule.end)
        return min(var, schedule.end)
def get_final_var(init_var, total_steps, inc, n_steps, lim):
    """Return the final value of a scheduled variable after applying the curriculum.

    The variable starts at ``init_var`` and changes by ``inc`` once every
    ``n_steps`` steps, so it changes ``floor(total_steps / n_steps)`` times.
    The result never moves past ``lim``.

    Args:
        init_var: starting value of the variable.
        total_steps: number of steps taken.
        inc: change applied once every ``n_steps`` steps (may be negative).
        n_steps: number of steps between consecutive changes.
        lim: limit the value never moves past. It is an upper bound when
            ``inc >= 0`` and a lower bound when ``inc < 0``.

    Returns:
        The clamped value after ``total_steps`` steps.

    Raises:
        ZeroDivisionError: if ``n_steps`` is 0.
    """
    n_changes = math.floor(total_steps / n_steps)
    final_var = init_var + n_changes * inc
    # A shrinking schedule (negative inc) must be bounded from below;
    # min() alone would let it fall past ``lim``.
    if inc < 0:
        return max(final_var, lim)
    return min(final_var, lim)