trainlm.py
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from nltk.translate.bleu_score import sentence_bleu

from model.lm import languageModel
from utils.argsLM import getArgs
from utils.dataset import CustomDataset
from utils.process import encode, decode, getData
from utils.tools import getVocab

args = getArgs()
best_bleu = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE: the vocabulary and the dev set are taken from the couplet corpus,
# while the training data comes from the poem corpus.
word2id, id2word = getVocab(path="./resources/couplet/vocabs")
data = getData(dataPathIn='./resources/poem/train/in.txt', dataPathOut='./resources/poem/train/out.txt',
               word2id=word2id, id2word=id2word)
dataDev = getData(dataPathIn='./resources/couplet/test/in.txt', dataPathOut='./resources/couplet/test/out.txt',
                  word2id=word2id, id2word=id2word)
# data = getData(dataPathIn='./resources/couplet/test/in.txt', dataPathOut='./resources/couplet/test/out.txt',
#                word2id=word2id, id2word=id2word)
trainDataset = CustomDataset(data=data, word2id=word2id, id2word=id2word, device=device)
devDataset = CustomDataset(data=dataDev, word2id=word2id, id2word=id2word, device=device)
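
# CustomDataset.collate_fn is assumed (utils/dataset.py is not shown here) to
# pad each batch and return (encoder_inputs, decoder_inputs, labels) tensors
# already on `device`, matching the unpacking in train() and evaluate() below.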
dataloader = DataLoader(trainDataset, batch_size=args.train_batch_size, shuffle=True,
                        collate_fn=trainDataset.collate_fn)
dataloaderDev = DataLoader(devDataset, batch_size=args.eval_batch_size, shuffle=False,
                           collate_fn=devDataset.collate_fn)
model = languageModel(vocab_size=len(word2id), embedding_dim=args.embedding_dim, hidden_dim=args.hidden_dim,
                      num_layers=args.num_layers, dropout=args.dropout)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
# Warm up the learning rate over the first epoch, then decay it linearly to
# zero across all training steps (the scheduler is stepped once per batch).
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=len(dataloader),
                                            num_training_steps=args.num_train_epochs * len(dataloader))
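
# Interface assumed for languageModel, inferred from the call sites below
# (not verified against model/lm.py):
#   logits, loss = model(decoder_input_ids, label_ids)              # teacher-forced LM step
#   token_ids, token_probs = model.generate(prompt_ids, maxLen=64)  # decoding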


class AverageMeter:  # keeps a running average so tqdm can display live loss/metrics

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
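
# Usage sketch: `n` weights each update by batch size, so `avg` is a
# per-sample mean rather than a per-batch mean, e.g.
#   m = AverageMeter()
#   m.update(0.5, n=4)   # batch of 4 samples, mean loss 0.5
#   m.update(1.0, n=1)   # batch of 1 sample, loss 1.0
#   m.avg                # -> (0.5 * 4 + 1.0 * 1) / 5 = 0.6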


def train():
    # config = Config("resources/couplet/chars_sort.txt")
    global best_bleu
    print("start training...")
    for epoch in range(args.num_train_epochs):
        model.train()  # set mode to train
        losses = AverageMeter()
        clips = AverageMeter()
        optimizer.zero_grad()
        tk = tqdm(dataloader, total=len(dataloader), position=0, leave=True)
        for data in tk:
            # encoder inputs are unused here: the LM trains on decoderInput/labels alone
            inputs, decoderInput, labels = data
            logits, loss = model(decoderInput, labels)
            losses.update(loss.item(), logits.size(0))
            loss.backward()
            # clip_grad_norm_ returns the total gradient norm *before* clipping,
            # so clips.avg tracks how large the raw gradients are
            clip = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            clips.update(clip.item(), logits.size(0))
            tk.set_postfix(loss=losses.avg, clip=clips.avg)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
        dev_bleu = evaluate()
        if dev_bleu > best_bleu:
            best_bleu = dev_bleu
            # the file name encodes the hyperparameters; the pattern matches
            # the checkpoint loaded in __main__ below
            ckpt_name = (f"bsz{args.train_batch_size}ed{args.embedding_dim}"
                         f"tb{args.train_batch_size}hd{args.hidden_dim}"
                         f"bs{args.num_layers}lr{args.learning_rate}"
                         f"seq2seqPoem{best_bleu}.pt")
            torch.save(model.state_dict(), ckpt_name)
        # qualitative samples after every epoch
        generate("一醉一相逢")
        generate("一日一相逢")
        generate("你好")
        generate("我是")


def generate(text):
    model.eval()
    with torch.no_grad():
        ids = torch.tensor(encode(text, word2id)).unsqueeze(dim=0).to(device)
        predictList, prbList = model.generate(ids, maxLen=64)
        if len(predictList) < 5:
            # degenerate (too short) sample: retry from the original prompt string
            generate(text)
            return
        print(predictList)
        print(decode(predictList, id2word))
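
# encode/decode round-trip, as inferred from the usage above (utils.process
# is not shown, so the exact signatures are assumptions):
#   ids = encode("春风吹", word2id)    # str -> [int, ...]
#   text = decode(ids, id2word)       # [int, ...] -> str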


def evaluate():
    model.eval()
    bleuMeter = AverageMeter()
    tk = tqdm(dataloaderDev, total=len(dataloaderDev), position=0, leave=True)
    with torch.no_grad():
        for data in tk:
            inputs, decoderInput, labels = data
            predictList, prbList = model.generate(decoderInput, maxLen=64)
            predict = decode(predictList, id2word)
            target = decode(labels.cpu().squeeze().tolist(), id2word)
            print("----------")
            print(predict)
            print(target)
            if len(predict) == 0:
                bleuMeter.update(0.0, inputs.shape[0])
            else:
                bleuMeter.update(bleu_score(predict, target), inputs.shape[0])
            tk.set_postfix(bleu=bleuMeter.avg)
    return bleuMeter.avg


def bleu_score(predict, target):
    # character-level unigram BLEU; sentence_bleu expects (references, hypothesis),
    # with references given as a list of token lists
    predict = list(predict)
    target = list(target)
    return sentence_bleu([target], predict, weights=(1,))
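
# Worked example of the character-level unigram BLEU above:
#   predict = list("一日一相逢"), target = list("一醉一相逢")
#   4 of the 5 predicted characters are matched in the reference (clipped
#   counts), the lengths are equal so there is no brevity penalty, and
#   sentence_bleu([target], predict, weights=(1,)) == 4/5 = 0.8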


if __name__ == "__main__":
    model.load_state_dict(torch.load("checkpoint/bsz32ed256tb32hd512bs2lr0.001seq2seqPoem0.09252526045827876.pt",
                                     map_location=device))
    # generate("瑟批")
    # train()
    # sample prompts for qualitative inspection
    generate("您好北京")
    generate("北京理工大小")
    generate("春风吹")
    generate("空间环境")
    generate("人工智能")
    generate("人工智障")
    generate("自然语言处理")
    generate("计算机科学")