-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun-ibm1-VI.py
70 lines (51 loc) · 1.84 KB
/
run-ibm1-VI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from lib.IBM1 import IBM1
from lib.util import write_list, read_list, draw_weighted_alignment, plot_aer
from lib.aer_import import test
import matplotlib.pyplot as plt
import numpy as np
def _save_summaries(ibm, aers, save_path):
    """Write the collected ELBOs, AERs and NULL-alignment counts (and plots) under *save_path*.

    Args:
        ibm: the trained IBM1 instance; its ``elbos`` and ``null_generations``
            lists are written out.
        aers: AER values collected, one per training iteration.
        save_path: output directory prefix (expected to end with '/').
    """
    # NOTE(review): main() trains with epoch_VI(..., ELBO=False); confirm that
    # ibm.elbos is still populated in that mode before relying on these files.
    write_list(ibm.elbos, save_path + 'ELBOs')
    ibm.plot_elbos(save_path + 'ELBOs.pdf')
    write_list(aers, save_path + 'AERs')
    plot_aer(aers, save_path)
    write_list(ibm.null_generations, save_path + 'NULL-generations')


def main(iterations=5, save=False, max_sents=1000):
    """Train IBM1 with ``epoch_VI`` and report the validation AER after every epoch.

    Reads the Hansards training pair (``training/hansards.36.2.e`` /
    ``training/hansards.36.2.f``) with NULL and UNK handling enabled.  It then
    runs ``iterations`` calls to ``IBM1.epoch_VI``.  After each call it
    predicts alignments for the validation set, scores them with ``test``
    against ``validation/dev.wa.nonullalign``, and prints the AER and the
    latest NULL-alignment count.  Finally it tabulates translation
    probabilities for a few English words.

    Args:
        iterations: number of training epochs (previously hard-coded as ``T = 5``).
        save: when True, save the translation table after every epoch, and the
            ELBO/AER/NULL-generation summaries after training (previously the
            hard-coded local ``Save = False``).
        max_sents: maximum number of training sentence pairs passed to
            ``IBM1.read_data``.
    """
    ibm = IBM1()

    english_path = 'training/hansards.36.2.e'
    french_path = 'training/hansards.36.2.f'
    ibm.read_data(english_path, french_path, null=True, UNK=True,
                  max_sents=max_sents, test_repr=False)

    # The output directory does not depend on the iteration, so bind it once.
    # It used to be bound only inside the loop, which left it undefined for
    # the save step below whenever the loop ran zero times.
    save_path = 'prediction/validation/IBM1/VI/'

    aers = []
    for step in range(1, iterations + 1):
        print('Iteration {}'.format(step))
        # Per-iteration output locations.
        model_path = '../../models/IBM1/VI/{0}-'.format(step)
        alignment_path = save_path + 'prediction-{0}'.format(step)

        ibm.epoch_VI(alpha=1, log=True, ELBO=False)

        # Predict validation alignments with the current model and score them.
        ibm.predict_alignment('validation/dev.f',
                              'validation/dev.e',
                              alignment_path)
        aer = test('validation/dev.wa.nonullalign',
                   alignment_path)
        aers.append(aer)
        print('AER: {}'.format(aer))
        print('Total NULL alignments: {}'.format(ibm.null_generations[-1]))

        # TODO: drawing weighted alignments for sentence 21 is disabled
        # (previously noted as "not working properly"):
        # draw_weighted_alignment(ibm, alignment_path,
        #                         '../validation/dev.f',
        #                         '../validation/dev.e',
        #                         '../prediction/validation/sentence-draws/IBM1-sentence-21-iter-{}'.format(step),
        #                         sentence=21)

        if save:
            # Save this iteration's translation probabilities.
            ibm.save_t(model_path)

    if save:
        _save_summaries(ibm, aers, save_path)

    ibm.tabulate_t(english_words=['the', 'and', 'me', 'is', 'where', 'of', 'or', '-NULL-'], k=4)
if __name__ == "__main__":
    # Run training/evaluation only when executed as a script, not on import.
    main()