-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
62 lines (41 loc) · 1.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import configparser
import logging
import os
from data_loader import load_vocab, load_tokenizer, load_ontology, build_init_embeddings
from logger import set_logger
from train import train_model
def print_config(config):
    """Log every section and option of *config* as one readable INFO record.

    The record is emitted on the root logger and looks like::

        Config:
        <section>:
            <key>: <value>      (each option line is tab-indented)
        ...

    Args:
        config: A ``configparser.ConfigParser``. Values are read through
            ``config.items(section)``, so they are shown after interpolation.
    """
    lines = []
    for section in config.sections():
        lines.append(f"{section}:")
        lines.extend(f"\t{key}: {value}" for key, value in config.items(section))
    # Hand the payload to logging as an argument so the %-substitution is
    # performed by logging only when the record is actually emitted.
    logging.info("Config:\n%s", "\n".join(lines))
def main():
    """Entry point: parse CLI args, load the config, build vocabularies, and train.

    Command-line arguments:
        --name: required; used as a prefix for the model name and log file name.
        --config: path to the INI config file (default ``configs/config_MAG.ini``).

    Exits through ``parser.error`` (usage message, exit status 2) when the
    config file cannot be read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", required=True, help="Used as a prefix for model name and log file name")
    parser.add_argument("--config", default="configs/config_MAG.ini")
    args = parser.parse_args()

    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed. Fail fast with a clear message rather
    # than hitting a NoSectionError at config.set("Model", ...) further down.
    if not config.read(args.config):
        parser.error(f"could not read config file: {args.config}")

    # Ensure the output directories exist. exist_ok makes this a no-op when
    # they are already present and avoids the exists()/mkdir() race.
    os.makedirs("logs", exist_ok=True)
    os.makedirs("models", exist_ok=True)

    set_logger(args)
    # Store the run name in the config so it appears in the logged config dump
    # and travels with the config into train_model.
    config.set("Model", "name", args.name)
    print_config(config)

    # Load spaCy tokenizer for English language
    spacy_en = load_tokenizer()

    def tokenizer_src(text):
        """Tokenize source-sequence text (natural language) into token strings."""
        return [tok.text for tok in spacy_en.tokenizer(text)]

    # Build <label: ontology level> mapping
    label2level = load_ontology(config["Paths"]["ontology"])
    # Each vocab is an instance of torchtext.vocab.Vocab, which maps tokens/labels to their unique ids
    vocab_src, vocab_tgt = load_vocab(config, tokenizer_src, label2level)
    # Initialize token embeddings from GloVe
    build_init_embeddings(config, vocab_src)
    train_model(vocab_src, vocab_tgt, tokenizer_src, config)


if __name__ == '__main__':
    main()