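"""bert_evaluation.py

Evaluate BERT checkpoints on the STS-B development set. Each sentence is
wrapped in the prompt 'This sentence : " <s> " means [MASK] .', the hidden
state at the [MASK] position is taken as the sentence embedding, and the
Spearman correlation between the pairwise cosine similarities and the gold
scores is appended to an evaluation file, one line per checkpoint.
"""
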
import os
import string

import torch
from scipy.spatial.distance import cosine
from scipy.stats import spearmanr
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

PUNCTUATION = set(string.punctuation)


def calculate_cosine(tokenizer, model, texts):
    """Embed a pair of prompted sentences and return the cosine similarity
    of their [MASK]-token representations."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    inputs = {key: value.cuda() for key, value in inputs.items()}
    # Boolean mask marking the [MASK] position in each sequence.
    mask = inputs["input_ids"] == tokenizer.mask_token_id
    with torch.no_grad():
        embeddings = model(**inputs, output_hidden_states=True,
                           return_dict=True).last_hidden_state
    # Keep only the hidden state at each [MASK] position (one per sentence);
    # index on the GPU, then move the result to the CPU.
    embeddings = embeddings[mask].cpu().tolist()
    # scipy's `cosine` is a distance, so similarity = 1 - distance.
    return 1 - cosine(embeddings[0], embeddings[1])
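

# A minimal usage sketch (the sentences below are illustrative, not taken
# from the evaluation data): both inputs must contain exactly one [MASK]
# token, as produced by the prompt template in main() below.
#
#   score = calculate_cosine(
#       tokenizer=tokenizer, model=model,
#       texts=['This sentence : " A man is playing a guitar . " means [MASK] .',
#              'This sentence : " A man plays a guitar . " means [MASK] .'])
#   # `score` is the cosine similarity of the two [MASK] embeddings.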


def main(ckpt_path, evaluation_file):
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased").cuda()
    model.eval()
    # Make sure the tokenizer exposes a mask token (a no-op for
    # bert-base-uncased, which already defines [MASK]).
    tokenizer.add_special_tokens({"mask_token": tokenizer.mask_token})
    # Directory holding the STS-B development set.
    path = r"Files/STSB_dev_set"
    files = os.listdir(path)
    for ckpt in os.listdir(ckpt_path):
        ckpt_dir = os.path.join(ckpt_path, ckpt)
        if not os.path.isdir(ckpt_dir):
            continue
        # Delete everything except the weights to save disk space.
        for name in os.listdir(ckpt_dir):
            if name != "pytorch_model.bin":
                os.remove(os.path.join(ckpt_dir, name))
        params = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"),
                            map_location="cpu")
        # Checkpoint keys carry a 5-character prefix (e.g. "bert."); strip it
        # and keep only the keys that exist in the plain BertModel.
        state = model.state_dict()
        keys = [k for k in params if k[5:] in state]
        params = {k[5:]: params[k] for k in keys}
        # Zero-fill any parameters missing from the checkpoint (e.g. the
        # pooler) so that load_state_dict succeeds in strict mode.
        for name, tensor in state.items():
            if name not in params:
                params[name] = torch.zeros_like(tensor)
        model.load_state_dict(params)
        labels = []
        scores = []
        for file_name in files:
            if ".input." not in file_name:
                continue
            # Score each tab-separated sentence pair with the fixed prompt.
            with open(os.path.join(path, file_name), encoding="utf-8") as f:
                for line in tqdm(f):
                    texts = line.strip().split("\t")
                    # Make sure every sentence ends with punctuation.
                    texts = [t if t.strip()[-1] in PUNCTUATION else t + " ." for t in texts]
                    texts[0] = '''This sentence : " ''' + texts[0] + ''' " means [MASK] .'''
                    texts[1] = '''This sentence : " ''' + texts[1] + ''' " means [MASK] .'''
                    scores.append(calculate_cosine(tokenizer=tokenizer, model=model, texts=texts))
            # Read the matching gold-standard similarity scores.
            with open(os.path.join(path, file_name.replace(".input.", ".gs."))) as f:
                for line in f:
                    labels.append(float(line.strip()))
        # Append this checkpoint's Spearman correlation to the results file.
        result = {ckpt: str(spearmanr(labels, scores)[0])}
        with open(evaluation_file, "a") as f:
            f.write(str(result) + "\n")


if __name__ == "__main__":
    # Fill in the checkpoint directory and the output results file before running.
    ckpt_path = r""
    evaluation_file = r""
    main(ckpt_path=ckpt_path, evaluation_file=evaluation_file)
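
# How to run (a sketch; the layout below is inferred from the code, and the
# example paths are placeholders, not part of the original): point
# `ckpt_path` at a directory with one sub-directory per checkpoint, each
# containing a pytorch_model.bin whose keys carry a 5-character prefix
# (e.g. "bert."). One "{checkpoint: spearman}" line per checkpoint is
# appended to `evaluation_file`.
#
#   ckpt_path = r"checkpoints"      # e.g. checkpoints/step-1000/pytorch_model.bin
#   evaluation_file = r"results.txt"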