-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
75 lines (61 loc) · 2.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import List
import hydra
from omegaconf import DictConfig
from pytorch_lightning.callbacks.base import Callback
import torch
import pytorch_lightning as pl
import argparse
from utils.custom_logger import CustomLogger
from data_loader.data_module import DataModule
from helper import load_system
from helper import config_init, refine_args
from helper.helper_slurm import run_cluster
def main(params: DictConfig, LightningSystem: pl.LightningModule, *args,
         **kwargs):
    """Build data, model, logger and trainer from ``params``, then fit or test.

    Args:
        params: Hydra config. It is first passed through ``config_init`` and
            ``refine_args``. Afterwards this function reads ``data``,
            ``ckpt``, ``load_base``, ``load_flexible``,
            ``logger.{save_dir,name,version}``, ``test``,
            ``disable_logfile``, ``trainer`` and the optional ``callbacks``.
        LightningSystem: LightningModule class. It is instantiated as
            ``LightningSystem(params, datamodule)``.
        *args, **kwargs: Accepted but not used by this function.

    Returns:
        The value of ``trainer.test(model)`` when ``params.test`` is truthy,
        otherwise the value of ``trainer.fit(model)``.
    """
    params = config_init(params)
    params = refine_args(params)

    datamodule = DataModule(params.data)

    # Init PyTorch Lightning model ⚡
    model = LightningSystem(params, datamodule)

    # Optional checkpoint restore. The literal string 'none' is treated like
    # a missing checkpoint (e.g. when set from the command line).
    if params.ckpt is not None and params.ckpt != 'none':
        if params.load_base:
            model.load_base(params.ckpt)
        else:
            # Deserialize onto CPU regardless of the device the checkpoint
            # was saved from; Lightning moves the model later.
            state_dict = torch.load(params.ckpt,
                                    map_location="cpu")['state_dict']
            model.load_state_dict(state_dict,
                                  strict=not params.load_flexible)

    logger = CustomLogger(save_dir=params.logger.save_dir,
                          name=params.logger.name,
                          version=params.logger.version,
                          test=params.test,
                          disable_logfile=params.disable_logfile)

    # Init PyTorch Lightning callbacks ⚡
    # Tolerate an absent or null `callbacks` section, and skip individual
    # entries that were disabled by overriding them to null.
    callbacks: List[Callback] = []
    callbacks_conf = params["callbacks"] if "callbacks" in params else None
    if callbacks_conf is not None:
        for callback_conf in callbacks_conf.values():
            if callback_conf is None:
                continue
            callbacks.append(hydra.utils.instantiate(callback_conf))

    trainer = pl.Trainer.from_argparse_args(
        argparse.Namespace(**params.trainer),
        logger=logger,
        callbacks=callbacks,
        # Testing reuses the validation batch limit.
        limit_test_batches=params.trainer.limit_val_batches)

    if params.test:
        return trainer.test(model)
    return trainer.fit(model)
@hydra.main(config_name="config", config_path="configs")
def hydra_main(cfg: DictConfig):
    """Hydra entry point: resolve the Lightning system and dispatch a launcher.

    ``cfg.launcher.name`` selects how the run is executed:

    * ``"local"``: call :func:`main` in this process.
    * ``"slurm"``: hand :func:`main` to ``run_cluster``.
    * ``"submitit_eval"``: delegate to ``submitit_eval_main``.

    Args:
        cfg: Composed Hydra config. ``cfg.system_name`` is passed to
            ``load_system`` to obtain the LightningModule class.

    Raises:
        ValueError: If ``cfg.launcher.name`` is not one of the names above.
            Previously an unknown launcher silently did nothing.
    """
    lt_system = load_system(cfg.system_name)
    launcher = cfg.launcher.name
    if launcher == "local":
        # Run training/testing directly in the current process.
        main(cfg, lt_system)
    elif launcher == "slurm":
        # Submit the job to SLURM.
        run_cluster(cfg, main, lt_system)
    elif launcher == "submitit_eval":
        # Imported lazily so this module is only needed for this launcher.
        from helper.helper_submitit_eval import submitit_eval_main
        submitit_eval_main(cfg, lt_system)
    else:
        raise ValueError(
            f"Unknown launcher {launcher!r}; expected one of "
            "'local', 'slurm', 'submitit_eval'")
if __name__ == "__main__":
    # Script entry: Hydra composes the config (plus CLI overrides) and
    # passes it to hydra_main.
    hydra_main()