-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path5_infer.py
61 lines (51 loc) · 1.83 KB
/
5_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import csv
import argparse
import numpy as np
import pickle as pkl
from os import remove
from os.path import join
from keras.models import load_model
from keras_self_attention import SeqWeightedAttention
from keras_ordered_neurons import ONLSTM
from helpers import load_embeddings_dict
from helpers import map_sentence, f1
def parse_args():
    """Parse the command-line options for inference."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', default='data_dir')
    parser.add_argument('--embeddings-type', default='elmo', choices=['elmo', 'bert'])
    parser.add_argument('--model-path', default='checkpoints/epoch100.h5')
    parser.add_argument('--threshold', default=0.5, type=float)
    return parser.parse_args()


def load_test_data(path, embeddings_dict):
    """Read the processed test CSV and embed both sentences of every row.

    Each row is read as ``(sentence_a, sentence_b, pair_id)``; the two
    sentences are passed through ``map_sentence`` with ``embeddings_dict`` and
    the third column is parsed as an int.

    Returns:
        A list of ``(mapped_a, mapped_b, pair_id)`` tuples, in file order.
    """
    data = []
    # The csv module expects files it parses to be opened with newline=''.
    with open(path, 'r', newline='') as file:
        for idx, row in enumerate(csv.reader(file)):
            print('Prepare Data: %s' % (idx + 1), end='\r')
            data.append((
                map_sentence(row[0], embeddings_dict),
                map_sentence(row[1], embeddings_dict),
                int(row[2]),
            ))
    # Terminate the '\r' progress line so the final count stays visible.
    print()
    return data


def predict_scores(model, data):
    """Lazily yield the model's squeezed output for each example in ``data``.

    Each example is scored as its own batch of one (presumably because the
    mapped sentences can differ in length between examples - unverified).
    """
    for idx, (sentence_a, sentence_b, _) in enumerate(data):
        print('Predicting Example: %s' % (idx + 1), end='\r')
        yield model.predict([[np.array(sentence_a)], [np.array(sentence_b)]]).squeeze()


def write_submission(path, pair_ids, scores, threshold):
    """Write a submission CSV with one binary prediction per question pair.

    A pair is labelled 1 when its score is >= ``threshold``, otherwise 0.
    ``pair_ids`` and ``scores`` are consumed in lockstep, so ``scores`` may be
    a lazy iterable and each row is written as soon as its score is produced.

    Args:
        path: Output file path; any existing file is overwritten.
        pair_ids: Identifiers written to the ``QuestionPairID`` column.
        scores: Scores comparable with ``threshold``, one per pair id.
        threshold: Decision threshold for a positive prediction.
    """
    # Mode 'w' truncates an existing file, so no separate delete is needed.
    # newline='' stops csv from producing '\r\r\n' (blank rows) on Windows.
    with open(path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['QuestionPairID', 'Prediction'])
        for pair_id, score in zip(pair_ids, scores):
            writer.writerow([pair_id, 1 if score >= threshold else 0])


def main():
    """Score every pair in ``test_processed.csv`` and write ``submit.csv``."""
    args = parse_args()
    embeddings_dict = load_embeddings_dict(
        join(args.data_dir, '%s_dict.pkl' % args.embeddings_type))
    # Custom layers/metrics used when the model was saved must be supplied so
    # Keras can deserialize the checkpoint.
    model = load_model(
        filepath=args.model_path,
        custom_objects={
            'f1': f1,
            'SeqWeightedAttention': SeqWeightedAttention,
            'ONLSTM': ONLSTM,
        },
    )
    data = load_test_data(join(args.data_dir, 'test_processed.csv'), embeddings_dict)
    write_submission(
        join(args.data_dir, 'submit.csv'),
        [example[2] for example in data],
        predict_scores(model, data),
        args.threshold,
    )
    # Terminate the '\r' prediction progress line.
    print()


if __name__ == '__main__':
    main()