-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathoptimize_nasbench201.py
68 lines (53 loc) · 1.91 KB
/
optimize_nasbench201.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import sys
import numpy as np
from util.utils import get_logger
from optimizer import opts
from util.experiment_helper import apply_knowledge_augmentation, exist_file
from targets.nasbench201.api import ConstraintChoices, DatasetChoices, NASBench201
from util.utils import get_args_from_parser, get_filename_from_args
# Constraint sets selectable from the command line. ``args.constraint_mode``
# is used as a plain index into this list:
#   0 -> runtime and size_in_mb, 1 -> size_in_mb only, 2 -> runtime only.
constraint_choices = [
    [ConstraintChoices.runtime, ConstraintChoices.size_in_mb],
    [ConstraintChoices.size_in_mb],
    [ConstraintChoices.runtime],
]


def main():
    """Run one constrained-optimization experiment on NASBench201.

    Parses the command-line arguments and derives the result file name.
    Exits early when ``exist_file(file_name, args.max_evals)`` is true.
    Otherwise it builds the benchmark and the optimizer chosen by
    ``args.opt_name``, applies knowledge augmentation, and runs the
    optimization loop.
    """
    args = get_args_from_parser(DatasetChoices, opts=opts)
    constraints = constraint_choices[args.constraint_mode]
    file_name = get_filename_from_args('nasbench201', constraints, args)
    # NOTE(review): presumably true when results for this configuration with
    # max_evals evaluations already exist -- confirm in util.experiment_helper.
    if exist_file(file_name, args.max_evals):
        print('Skip the optimization')
        sys.exit()

    logger = get_logger(file_name=file_name, logger_name=file_name)
    # The experiment id doubles as the seed for both benchmark and optimizer.
    seed = args.exp_id
    bm = NASBench201(
        dataset=getattr(DatasetChoices, args.dataset),
        feasible_domain_ratio=args.feasible_domain,
        # Reuse the constraint set that was used to build the file name so
        # the two can never disagree.
        constraints=constraints,
        seed=seed,
    )

    # When ``args.constraint`` is falsy, every constraint threshold handed to
    # the optimizer is replaced by np.inf.
    opt_constraints = {
        k: v if args.constraint else np.inf
        for k, v in bm.constraints.items()
    }
    kwargs = dict(
        obj_func=bm.objective_func,
        config_space=bm.config_space,
        resultfile=file_name,
        max_evals=args.max_evals,
        constraints=opt_constraints,
        seed=seed,
    )
    if args.opt_name == 'hm':
        # The 'hm' optimizer additionally receives the path of its HyperMapper
        # JSON file.
        kwargs.update(hypermapper_json='targets/nasbench201/hypermapper.json')
    if args.naive:
        kwargs.update(naive_mode=True)
    opt = opts[args.opt_name](**kwargs)

    # NOTE(review): this receives the raw bm.constraints even when
    # args.constraint is off (the optimizer got np.inf thresholds above) --
    # confirm that is intended.
    apply_knowledge_augmentation(
        args=args,
        opt=opt,
        logger=logger,
        cheap_obj=bm.cheap_objective_func,
        cheap_metrics=bm.cheap_metrics,
        constraints=bm.constraints,
        config_space=bm.config_space,
        file_name=file_name,
        seed=seed,
    )
    opt.optimize(logger)


if __name__ == '__main__':
    main()