-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
executable file
·27 lines (21 loc) · 933 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import hydra, logging
import torch, glob, os
import numpy as np
from trainers import *
from models import *
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CometLogger, TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
# Module-level logger named after this module; run() uses it to log the config.
log = logging.getLogger(__name__)
@hydra.main(config_path="configs", config_name="Canonical_fields.yaml")
def run(cfg):
    """Build the logger, checkpoint callback, model and trainer from ``cfg``, then fit.

    Args:
        cfg: Configuration object supplied by ``hydra.main``. The fields read
            here are ``utils.seed``, ``logging.type``, ``logging.project``,
            ``callback.model_checkpoint.segmentation.args``,
            ``trainer_file.file``, ``trainer_file.type`` and ``trainer``.
    """
    seed_everything(cfg.utils.seed)

    # NOTE(review): eval() executes whatever text the config holds. That is
    # acceptable only while configs are trusted; consider an explicit
    # name -> class mapping if configs can come from outside the repo.
    logger_cls = eval(cfg.logging.type)
    train_logger = logger_cls(project=cfg.logging.project)

    log.info(cfg)
    # Presumably printed because Hydra may switch the working directory
    # for the run -- TODO confirm.
    print(os.getcwd())

    checkpoint_args = cfg.callback.model_checkpoint.segmentation.args
    checkpoint_callback = ModelCheckpoint(**checkpoint_args)

    # ``trainer_file.file`` names an object reachable from this module's
    # namespace (same eval() caveat as above); ``trainer_file.type`` is the
    # attribute on it that gets instantiated with the full config.
    trainer_source = eval(cfg.trainer_file.file)
    model_cls = getattr(trainer_source, cfg.trainer_file.type)
    model = model_cls(configs=cfg)

    trainer = Trainer(
        **cfg.trainer,
        callbacks=[checkpoint_callback],
        logger=train_logger,
    )
    trainer.fit(model)
if __name__ == "__main__":
    # Script entry point; run() is called without arguments, so ``cfg`` is
    # expected to be provided by the hydra.main decorator.
    run()