train_translation_model.py
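
"""Train a Transformer translation model on a prepared text-pair dataset.

Reads hyperparameters and text files from the dataset directory given on the
command line, trains the model, and writes per-epoch losses to a CSV file.

Usage:
    python train_translation_model.py /path/to/dataset_dir
"""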
import csv
import json
from pathlib import Path

import torch
from torch import optim

from libs.text_pair_dataset import TextPairDataset
from libs.transformer import Transformer
from libs.translation_model_trainer import TransformerLRScheduler, TranslationModelTrainer
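

# NOTE: the Adam hyperparameters (betas=(0.9, 0.98), eps=1e-9) and
# warmup_steps=4000 match the settings used in "Attention Is All You Need";
# TransformerLRScheduler presumably implements the paper's warmup-then-decay
# schedule, scaled by the model dimension (transformer.n_dim).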
def get_instance(params: dict):
    transformer = Transformer(**params)
    optimizer = optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-9)
    lr_scheduler = TransformerLRScheduler(optimizer, transformer.n_dim, warmup_steps=4000)
    return transformer, optimizer, lr_scheduler
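

# Expected contents of the dataset directory (inferred from the paths below):
#   settings.json                              - "params", "training", and "min_freq" sections
#   src_train_texts.txt / tgt_train_texts.txt  - source/target training texts
#   src_val_texts.txt / tgt_val_texts.txt      - source/target validation texts
#   src_word_freqs.json / tgt_word_freqs.json  - word-frequency tables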
def main():
    import argparse

    parser = argparse.ArgumentParser(description="Train the translation model.")
    parser.add_argument("dataset_dir", help="Dataset root directory path", type=str)
    args = parser.parse_args()

    base_path = Path(args.dataset_dir).resolve()

    # Exit if the specified directory does not exist
    if not base_path.exists():
        print("Target directory does not exist.")
        return

    # Load the hyperparameter settings
    with (base_path / "settings.json").open("r") as f:
        settings = json.load(f)

    # Build the training dataset
    src_txt_file_path = base_path / "src_train_texts.txt"
    tgt_txt_file_path = base_path / "tgt_train_texts.txt"
    src_word_freqs_path = base_path / "src_word_freqs.json"
    tgt_word_freqs_path = base_path / "tgt_word_freqs.json"
    src_min_freq = settings["min_freq"]["source"]
    tgt_min_freq = settings["min_freq"]["target"]
    train_dataset = TextPairDataset.create(
        src_txt_file_path,
        tgt_txt_file_path,
        src_word_freqs_path,
        tgt_word_freqs_path,
        src_min_freq,
        tgt_min_freq,
    )
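
    # (Presumably TextPairDataset.create builds the source/target vocabularies
    # from the word-frequency files, dropping words that occur fewer than
    # min_freq times.)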

    # Build the validation dataset (sharing the word-frequency files with training)
    src_val_txt_file_path = base_path / "src_val_texts.txt"
    tgt_val_txt_file_path = base_path / "tgt_val_texts.txt"
    valid_dataset = TextPairDataset.create(
        src_val_txt_file_path,
        tgt_val_txt_file_path,
        src_word_freqs_path,
        tgt_word_freqs_path,
        src_min_freq,
        tgt_min_freq,
    )

    # Use the GPU if one is available
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    enc_vocab_size, dec_vocab_size = train_dataset.get_vocab_size()
    settings["params"]["enc_vocab_size"] = enc_vocab_size
    settings["params"]["dec_vocab_size"] = dec_vocab_size

    # Create the model, optimizer, and learning-rate scheduler
    model, optimizer, lr_scheduler = get_instance(settings["params"])

    # Directory where model checkpoints are saved
    save_path = base_path / "models"

    # Create the trainer and run training
    translation_model_trainer = TranslationModelTrainer(
        model, optimizer, lr_scheduler, device, train_dataset, valid_dataset, save_path
    )
    train_loss_list, valid_loss_list = translation_model_trainer.fit(**settings["training"])

    # Save the per-epoch losses to a CSV file
    # (newline="" avoids blank rows on Windows, as the csv module requires)
    with (save_path / "loss.csv").open("w", newline="") as f:
        header = ["epoch", "train_loss", "valid_loss"]
        csv_writer = csv.writer(f)
        csv_writer.writerow(header)
        for epoch, row in enumerate(zip(train_loss_list, valid_loss_list), start=1):
            csv_writer.writerow((epoch,) + row)


if __name__ == "__main__":
    main()