train.py
import config
from ext import now
from model import load_model, save_model
from model import sgd, adaptive_sgd
from data import load_data, split_data, batchify_data
from torch import no_grad
from matplotlib.pyplot import plot, show

# choose the model variant once at module level, so main() and dev_loss()
# use the same make_model_higher / respond_to implementation
if config.attention_only:
    from model2 import make_model_higher, respond_to
else:
    from model import make_model_higher, respond_to

##

def main():
    # create a fresh model, or load the saved one (creating a new one if none is found)
    if config.fresh_model:
        save_model(make_model_higher())
        model = load_model()
        print('created model.', end=' ')
    else:
        model = load_model()
        if not model:
            save_model(make_model_higher())
            model = load_model()
            print('created model.', end=' ')
        else:
            print('loaded model.', end=' ')
    print(f'info: {config.creation_info}')
    # load data, split off a dev set, and resolve the batch size
    # (0 or None -> full batch; a fraction below 1 -> that share of the data)
    data = load_data(frames=not config.attention_only)
    data, data_dev = split_data(data)
    if not config.batch_size or config.batch_size >= len(data):
        config.batch_size = len(data)
        one_batch = True
    elif config.batch_size < 1:
        config.batch_size = int(len(data) * config.batch_size)
        one_batch = False
    else:
        one_batch = False
    print(f'hm data: {len(data)}, hm dev: {len(data_dev)}, bs: {config.batch_size}, lr: {config.learning_rate}, \ntraining started @ {now()}')
    data_losss, dev_losss = [], []
    if config.batch_size != len(data):
        data_losss.append(dev_loss(model, data))
    if config.dev_ratio:
        dev_losss.append(dev_loss(model, data_dev))
    if data_losss or dev_losss:
        print(f'initial loss(es): {data_losss[-1] if data_losss else ""} {dev_losss[-1] if dev_losss else ""}')
    for ep in range(config.hm_epochs):
        loss = 0
        for i, batch in enumerate(batchify_data(data, do_shuffle=not one_batch)):
            # print(f'\tbatch {i}, started @ {now()}', flush=True)
            batch_size = sum(len(sequence) for sequence in batch)
            loss += respond_to(model, batch)
            if config.optimizer == 'sgd':
                sgd(model, batch_size=batch_size)
            else:
                adaptive_sgd(model, batch_size=batch_size)
        # loss /= sum(len(sequence) for sequence in data)
        if not one_batch:
            loss = dev_loss(model, data)
        data_losss.append(loss)
        if config.dev_ratio:
            dev_losss.append(dev_loss(model, data_dev))
        print(f'epoch {ep}, loss {loss}, dev loss {dev_losss[-1] if config.dev_ratio else ""}, completed @ {now()}', flush=True)
        if config.ckp_per_ep and ((ep+1) % config.ckp_per_ep == 0):
            save_model(model, config.model_path + f'_ckp{ep}')
    # data_losss.append(dev_loss(model, data))
    # if config.dev_ratio:
    #     dev_losss.append(dev_loss(model, data_dev))
    print(f'training ended @ {now()} \nfinal losses: {data_losss[-1]}, {dev_losss[-1] if config.dev_ratio else ""}', flush=True)
    # plot the training loss curve, and the dev loss curve if one was tracked
    plot(data_losss)
    show()
    if config.dev_ratio:
        plot(dev_losss)
        show()
    # if input(f'Save model as {config.model_path}? (y/n): ').lower() == 'y':
    #     save_model(load_model(), config.model_path + '_prev')
    #     save_model(model)
    return model, [data_losss, dev_losss]

def dev_loss(model, batch):
    # evaluate without tracking gradients; respond_to returns (loss, _) when training_run=False
    with no_grad():
        loss, _ = respond_to(model, batch, training_run=False)
    return loss / sum(len(sequence) for sequence in batch)


if __name__ == '__main__':
    main()