-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
45 lines (35 loc) · 1.46 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from src.corruptions import init_corruption
from src.datasets import init_dataset
from src.experiments import init_experiment
from src.modules import init_module
from src.optimizers import init_optimizer
from src.utils.sacred import sacred_run
def init_and_run(experiment, modules, corruption, corruption_test, datasets, optimizers, _run, _log, _seed):
    """Build every experiment component from its config dict and run the experiment.

    Args:
        experiment: kwargs forwarded to ``init_experiment``.
        modules: mapping of module name -> kwargs for ``init_module``.
        corruption: kwargs for the corruption applied to training splits.
        corruption_test: kwargs for the corruption applied to 'test'/'val' splits.
        datasets: mapping of split name -> kwargs for ``init_dataset``.
        optimizers: mapping of optimizer name -> kwargs for ``init_optimizer``
            (each optimizer also receives the initialized modules dict).
        _run, _log, _seed: injected by the sacred framework — TODO confirm
            against ``src.utils.sacred.sacred_run``.
    """
    # Let cudnn autotune conv algorithms; beneficial when input shapes are fixed.
    torch.backends.cudnn.benchmark = True
    torch.set_num_threads(2)  # cap intra-op CPU threads

    # Separate corruptions: evaluation splits use their own corruption config.
    corr = init_corruption(**corruption)
    corr_test = init_corruption(**corruption_test)

    # Datasets keyed by split name; 'test'/'val' get the eval-time corruption.
    dsets = {
        split: init_dataset(corr_test if split in ('test', 'val') else corr, **cfg)
        for split, cfg in datasets.items()
    }

    # Modules first, since optimizers are built over the module dict.
    mods = {name: init_module(**cfg) for name, cfg in modules.items()}
    optims = {name: init_optimizer(mods, **cfg) for name, cfg in optimizers.items()}

    # Assemble the experiment from all initialized parts and run it.
    init_experiment(sacred_run=_run, seed=_seed,
                    corruption=corr,
                    **dsets, **mods, **optims,
                    **experiment).run()
if __name__ == '__main__':
    # Hand the entry point to the sacred wrapper (src.utils.sacred), which
    # invokes init_and_run with the resolved configuration.
    sacred_run(init_and_run)